tustvold commented on code in PR #4716:
URL: https://github.com/apache/arrow-rs/pull/4716#discussion_r1300068671
##########
arrow-ord/src/cmp.rs:
##########
@@ -141,51 +174,114 @@ fn compare_op(
let l_len = l.len();
let r_len = r.len();
- let l_nulls = l.logical_nulls();
- let r_nulls = r.logical_nulls();
- let (len, nulls) = match (l_s, r_s) {
- (true, true) | (false, false) => {
- if l_len != r_len {
- return Err(ArrowError::InvalidArgumentError(format!(
-                    "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
- )));
- }
- (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
- }
-        (true, false) => match l_nulls.map(|x| x.null_count() != 0).unwrap_or_default() {
- true => (r_len, Some(NullBuffer::new_null(r_len))),
- false => (r_len, r_nulls), // Left is scalar and not null
- },
-        (false, true) => match r_nulls.map(|x| x.null_count() != 0).unwrap_or_default() {
- true => (l_len, Some(NullBuffer::new_null(l_len))),
- false => (l_len, l_nulls), // Right is scalar and not null
- },
+ if l_len != r_len && !l_s && !r_s {
+ return Err(ArrowError::InvalidArgumentError(format!(
+            "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
+ )));
+ }
+
+ let len = match l_s {
+ true => r_len,
+ false => l_len,
};
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
let l_v = l.as_any_dictionary_opt();
let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+ let l_t = l.data_type();
let r_v = r.as_any_dictionary_opt();
let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+ let r_t = r.data_type();
+
+ if l_t != r_t || l_t.is_nested() {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Invalid comparison operation: {l_t} {op} {r_t}"
+ )));
+ }
+
+ // Defer computation as may not be necessary
+ let values = || -> BooleanBuffer {
+ let d = downcast_primitive_array! {
+            (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
+            (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
+            (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
+            (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
+            (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
+            (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
+            (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (Null, Null) => None,
+ _ => unreachable!(),
+ };
+ d.unwrap_or_else(|| BooleanBuffer::new_unset(len))
+ };
- let values = downcast_primitive_array! {
-        (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
-        (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
-        (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
-        (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
-        (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
-        (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
-        (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
-        (l_t, r_t) => return Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation: {l_t} {op} {r_t}"))),
- }.unwrap_or_else(|| {
- let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
- assert_eq!(count, len); // Sanity check
- BooleanBuffer::new_unset(len)
- });
-
- assert_eq!(values.len(), len); // Sanity check
- Ok(BooleanArray::new(values, nulls))
+ let l_nulls = l_nulls.filter(|n| n.null_count() > 0);
+ let r_nulls = r_nulls.filter(|n| n.null_count() > 0);
+ Ok(match (l_nulls, l_s, r_nulls, r_s) {
+ (Some(l), true, Some(r), true) | (Some(l), false, Some(r), false) => {
+ // Either both sides are scalar or neither side is scalar
+ match op {
+ Op::Distinct => {
+ let values = values();
+ let l = l.inner().bit_chunks().iter_padded();
+ let r = r.inner().bit_chunks().iter_padded();
+ let ne = values.bit_chunks().iter_padded();
+
+ let c = |((l, r), n)| ((l ^ r) | (l & r & n));
+ let buffer = l.zip(r).zip(ne).map(c).collect();
+ BooleanBuffer::new(buffer, 0, len).into()
+ }
+ Op::NotDistinct => {
+ let values = values();
+ let l = l.inner().bit_chunks().iter_padded();
+ let r = r.inner().bit_chunks().iter_padded();
+ let e = values.bit_chunks().iter_padded();
+
+ let c = |((l, r), e)| u64::not(l | r) | (l & r & e);
+ let buffer = l.zip(r).zip(e).map(c).collect();
+ BooleanBuffer::new(buffer, 0, len).into()
+ }
+                _ => BooleanArray::new(values(), NullBuffer::union(Some(&l), Some(&r))),
+ }
+ }
+ (Some(_), true, Some(a), false) | (Some(a), false, Some(_), true) => {
+ // Scalar is null, other side is non-scalar and nullable
+ match op {
+ Op::Distinct => a.into_inner().into(),
+ Op::NotDistinct => a.into_inner().not().into(),
+ _ => BooleanArray::new_null(len),
+ }
+ }
+        (Some(nulls), is_scalar, None, _) | (None, _, Some(nulls), is_scalar) => {
+ // Only one side is nullable
+ match is_scalar {
+ true => match op {
+ // Scalar is null, other side is not nullable
+ Op::Distinct => BooleanBuffer::new_set(len).into(),
+ Op::NotDistinct => BooleanBuffer::new_unset(len).into(),
+ _ => BooleanArray::new_null(len),
+ },
+ false => match op {
+ Op::Distinct => {
+ let values = values();
+ let l = nulls.inner().bit_chunks().iter_padded();
+ let ne = values.bit_chunks().iter_padded();
+ let c = |(l, n)| u64::not(l) | n;
+ let buffer = l.zip(ne).map(c).collect();
Review Comment:
BooleanBuffer does implement Not (https://doc.rust-lang.org/std/ops/trait.Not.html).
However, as you surmised, computing the mask as is done here is marginally
better for performance.
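
As an illustration (not part of the PR; the helper names are mine and it
assumes arrow_buffer's Not/BitOr impls for &BooleanBuffer), here is the
difference for the "only one side nullable, non-scalar, Distinct" case:

```rust
use arrow_buffer::{BooleanBuffer, Buffer};

// (a) Operator-based: `!valid` materialises an intermediate BooleanBuffer
// before it is OR-ed with the comparison result.
fn distinct_with_not(valid: &BooleanBuffer, ne: &BooleanBuffer) -> BooleanBuffer {
    &(!valid) | ne
}

// (b) Fused, as the PR does: negate and OR 64 lanes at a time in one pass,
// with no intermediate buffer.
fn distinct_fused(valid: &BooleanBuffer, ne: &BooleanBuffer) -> BooleanBuffer {
    let len = ne.len();
    let buffer: Buffer = valid
        .bit_chunks()
        .iter_padded()
        .zip(ne.bit_chunks().iter_padded())
        .map(|(v, n)| !v | n)
        .collect();
    BooleanBuffer::new(buffer, 0, len)
}
```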
Deferring the buffer allocated by values would result in non-trivial
additional codegen, and would likely confuse LLVM's already temperamental
vectorisation.
Reusing the buffer is potentially worth exploring. I suspect that in most
cases the performance gain would be marginal at best, but I haven't profiled
this.
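
For reference, the Distinct combination above, when both sides carry nulls,
uses the word-level identity (l_valid ^ r_valid) | (l_valid & r_valid & ne):
exactly one side null => distinct, both null => not distinct, both valid =>
fall back to the raw not-equal comparison. A standalone sketch (names and
test values are mine, purely illustrative):

```rust
// Word-level form of the Distinct combination used in the diff:
// distinct = (l_valid ^ r_valid) | (l_valid & r_valid & ne)
fn distinct_word(l_valid: u64, r_valid: u64, ne: u64) -> u64 {
    (l_valid ^ r_valid) | (l_valid & r_valid & ne)
}

fn main() {
    // Bit 0: both valid, unequal    -> distinct
    // Bit 1: left null, right valid -> distinct
    // Bit 2: both null              -> not distinct
    // Bit 3: both valid, equal      -> not distinct
    let (l_valid, r_valid, ne) = (0b1001, 0b1011, 0b0001);
    assert_eq!(distinct_word(l_valid, r_valid, ne), 0b0011);
}
```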