alamb commented on code in PR #4701:
URL: https://github.com/apache/arrow-rs/pull/4701#discussion_r1295681482
##########
arrow-ord/Cargo.toml:
##########
@@ -44,10 +44,3 @@ half = { version = "2.1", default-features = false, features
= ["num-traits"] }
[dev-dependencies]
rand = { version = "0.8", default-features = false, features = ["std",
"std_rng"] }
-
-[package.metadata.docs.rs]
-features = ["dyn_cmp_dict"]
Review Comment:
Is this an API change (if someone was using arrow-ord directly)?
I see that `dyn_cmp_dict` is still used in arrow / arrow-string.
##########
arrow-flight/src/sql/metadata/db_schemas.rs:
##########
@@ -129,7 +129,8 @@ impl GetDbSchemasBuilder {
}
if let Some(catalog_filter_name) = catalog_filter {
- filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?);
+ let scalar = StringArray::from_iter_values([catalog_filter_name]);
Review Comment:
As a user, I find this construction somewhat awkward for creating a simple scalar.
What would you think about adding convenience functions that would let
this code look like:
```rust
let scalar = Scalar::new_utf8(catalog_filter_name);
```
(This doesn't have to be part of this PR; I can add it as a follow-on ticket/PR.)
##########
arrow-ord/src/comparison.rs:
##########
@@ -821,1325 +750,211 @@ where
/// Note that totalOrder treats positive and negative zeros are different. If
it is necessary
/// to treat them as equal, please normalize zeros before calling this kernel.
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
+#[deprecated(note = "Use arrow_ord::cmp::neq")]
pub fn neq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray,
ArrowError>
where
T: num::ToPrimitive + std::fmt::Debug,
{
- match left.data_type() {
- DataType::Dictionary(key_type, _value_type) => {
- dyn_compare_scalar!(left, right, key_type, neq_dyn_scalar)
- }
- _ => dyn_compare_scalar!(left, right, neq_scalar),
- }
+ let right = make_primitive_scalar(left.data_type(), right)?;
Review Comment:
The fact that you also found it useful to write a function that makes a
`Scalar` from a Rust primitive value further justifies formalizing
scalar creation in functions such as `Scalar::new_primitive(right)`.
##########
arrow-ord/src/comparison.rs:
##########
@@ -5231,21 +3871,12 @@ mod tests {
.into_iter()
.map(Some)
.collect();
- #[cfg(feature = "simd")]
- let expected = BooleanArray::from(
- vec![Some(false), Some(false), Some(false), Some(false),
Some(false)],
- );
- #[cfg(not(feature = "simd"))]
+
let expected = BooleanArray::from(
vec![Some(true), Some(false), Some(false), Some(false),
Some(false)],
);
assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
- #[cfg(feature = "simd")]
- let expected = BooleanArray::from(
Review Comment:
It is quite strange to get a different (presumably wrong) answer when
using the `simd` feature flag 🤔
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
Review Comment:
While I applaud the brevity of single-character abbreviations, I would
personally find this code easier to read if it spelled out `scalar`, like:
```suggestion
let (l, l_scalar) = lhs.get();
let (r, r_scalar) = rhs.get();
```
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
+
+ assert_eq!(values.len(), len); // Sanity check
+ Ok(BooleanArray::new(values, nulls))
+}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
Review Comment:
I was confused by this name at first, as it shadows
https://docs.rs/arrow/latest/arrow/array/trait.AsArray.html#method.as_dictionary
Maybe we could call it `as_dyn_dictionary` or something similar 🤔
##########
arrow-ord/src/cmp.rs:
##########
@@ -174,21 +179,52 @@ fn compare_op(
(LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
(FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
(l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
- };
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
assert_eq!(values.len(), len); // Sanity check
Ok(BooleanArray::new(values, nulls))
}
-fn values(a: &dyn Array) -> (Option<Vec<usize>>, &dyn Array) {
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
downcast_dictionary_array! {
- a => {
- let v = a.values().as_ref();
- let v_len = v.len();
- let keys = a.keys().values().iter().map(|x|
x.as_usize().min(v_len)).collect();
- (Some(keys), v)
- }
- _ => (None, a)
+ a => Some(a),
+ _ => None
+ }
+}
+
+trait Dictionary: Array {
Review Comment:
I agree -- methods like "apply a unary function to the values, returning a
Dictionary of the same type" would be very helpful. With the right
documentation, I think this would also be less confusing.
One thought I had is whether there is a more general formulation that might be
helpful (for example, for DictionaryArrays, REE Arrays, and maybe
StringViewArrays), although perhaps each of those could just have its own functions.
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
+
+ assert_eq!(values.len(), len); // Sanity check
+ Ok(BooleanArray::new(values, nulls))
+}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
+ downcast_dictionary_array! {
+ a => Some(a),
+ _ => None
+ }
+}
+
+trait Dictionary: Array {
+ /// Returns the keys of this dictionary, clamped to be in the range
`0..values.len()`
+ ///
+ /// # Panic
+ ///
+ /// Panics if `values.len() == 0`
+ fn normalized_keys(&self) -> Vec<usize>;
+
+ /// Returns the values of this dictionary
+ fn values(&self) -> &ArrayRef;
+
+ /// Applies the `keys` of this dictionary to the provided array
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError>;
+}
+
+impl<K: ArrowDictionaryKeyType> Dictionary for DictionaryArray<K> {
+ fn normalized_keys(&self) -> Vec<usize> {
+ let v_len = self.values().len();
+ assert_ne!(v_len, 0);
+ let iter = self.keys().values().iter();
+ iter.map(|x| x.as_usize().min(v_len)).collect()
+ }
+
+ fn values(&self) -> &ArrayRef {
+ self.values()
+ }
+
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ take(array, self.keys(), None)
+ }
+}
+
+/// Perform a potentially vectored `op` on the provided `ArrayOrd`
+fn apply<T: ArrayOrd>(
+ op: Op,
+ l: T,
+ l_s: bool,
+ l_v: Option<&dyn Dictionary>,
+ r: T,
+ r_s: bool,
+ r_v: Option<&dyn Dictionary>,
+) -> Option<BooleanBuffer> {
+ if l.len() == 0 || r.len() == 0 {
+ return None; // Handle empty dictionaries
+ }
+
+ if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
+ // Not scalar and at least one side has a dictionary, need to perform
vectored comparison
+ let l_v = l_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..l.len()).collect());
+
+ let r_v = r_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..r.len()).collect());
+
+ assert_eq!(l_v.len(), r_v.len()); // Sanity check
+
+ Some(match op {
+ Op::Equal => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
+ Op::NotEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_eq),
+ Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
+ Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true,
T::is_lt),
+ Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false,
T::is_lt),
+ Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_lt),
+ })
+ } else {
+ let l_s = l_s.then(|| l_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+ let r_s = r_s.then(|| r_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+
+ let buffer = match op {
+ Op::Equal => apply_op(l, l_s, r, r_s, false, T::is_eq),
+ Op::NotEqual => apply_op(l, l_s, r, r_s, true, T::is_eq),
+ Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
+ Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
+ Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
+ Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
+ };
+
+ // If a side had a dictionary, and was not scalar, we need to
materialize this
+ Some(match (l_v, r_v) {
+ (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
+ (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
+ _ => buffer,
+ })
+ }
+}
+
+/// Perform a take operation on `buffer` with the given dictionary
+fn take_bits(v: &dyn Dictionary, buffer: BooleanBuffer) -> BooleanBuffer {
+ let array = v.take(&BooleanArray::new(buffer, None)).unwrap();
+ array.as_boolean().values().clone()
+}
+
+/// Invokes `f` with values `0..len` collecting the boolean results into a new
`BooleanBuffer`
+///
+/// This is similar to [`MutableBuffer::collect_bool`] but with
+/// the option to efficiently negate the result
+fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) ->
BooleanBuffer {
+ let mut buffer = MutableBuffer::new(ceil(len, 64) * 8);
+
+ let chunks = len / 64;
+ let remainder = len % 64;
+ for chunk in 0..chunks {
+ let mut packed = 0;
+ for bit_idx in 0..64 {
+ let i = bit_idx + chunk * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+
+ if remainder != 0 {
+ let mut packed = 0;
+ for bit_idx in 0..remainder {
+ let i = bit_idx + chunks * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+ BooleanBuffer::new(buffer.into(), 0, len)
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd`
+///
+/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `l`
+/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `r`
+fn apply_op<T: ArrayOrd>(
+ l: T,
+ l_s: Option<usize>,
+ r: T,
+ r_s: Option<usize>,
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ match (l_s, r_s) {
+ (None, None) => {
+ assert_eq!(l.len(), r.len());
+ collect_bool(l.len(), neg, |idx| unsafe {
Review Comment:
Is it worth adding a `// SAFETY:` note here explaining that the index is
always in bounds because it comes from iterating over a valid array?
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
+
+ assert_eq!(values.len(), len); // Sanity check
+ Ok(BooleanArray::new(values, nulls))
+}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
+ downcast_dictionary_array! {
+ a => Some(a),
+ _ => None
+ }
+}
+
+trait Dictionary: Array {
+ /// Returns the keys of this dictionary, clamped to be in the range
`0..values.len()`
+ ///
+ /// # Panic
+ ///
+ /// Panics if `values.len() == 0`
+ fn normalized_keys(&self) -> Vec<usize>;
+
+ /// Returns the values of this dictionary
+ fn values(&self) -> &ArrayRef;
+
+ /// Applies the `keys` of this dictionary to the provided array
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError>;
+}
+
+impl<K: ArrowDictionaryKeyType> Dictionary for DictionaryArray<K> {
+ fn normalized_keys(&self) -> Vec<usize> {
+ let v_len = self.values().len();
+ assert_ne!(v_len, 0);
+ let iter = self.keys().values().iter();
+ iter.map(|x| x.as_usize().min(v_len)).collect()
+ }
+
+ fn values(&self) -> &ArrayRef {
+ self.values()
+ }
+
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ take(array, self.keys(), None)
+ }
+}
+
+/// Perform a potentially vectored `op` on the provided `ArrayOrd`
+fn apply<T: ArrayOrd>(
+ op: Op,
+ l: T,
+ l_s: bool,
+ l_v: Option<&dyn Dictionary>,
+ r: T,
+ r_s: bool,
+ r_v: Option<&dyn Dictionary>,
+) -> Option<BooleanBuffer> {
+ if l.len() == 0 || r.len() == 0 {
+ return None; // Handle empty dictionaries
+ }
+
+ if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
+ // Not scalar and at least one side has a dictionary, need to perform
vectored comparison
+ let l_v = l_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..l.len()).collect());
+
+ let r_v = r_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..r.len()).collect());
+
+ assert_eq!(l_v.len(), r_v.len()); // Sanity check
+
+ Some(match op {
+ Op::Equal => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
+ Op::NotEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_eq),
+ Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
+ Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true,
T::is_lt),
+ Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false,
T::is_lt),
+ Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_lt),
+ })
+ } else {
+ let l_s = l_s.then(|| l_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+ let r_s = r_s.then(|| r_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+
+ let buffer = match op {
+ Op::Equal => apply_op(l, l_s, r, r_s, false, T::is_eq),
+ Op::NotEqual => apply_op(l, l_s, r, r_s, true, T::is_eq),
+ Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
+ Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
+ Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
+ Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
+ };
+
+ // If a side had a dictionary, and was not scalar, we need to
materialize this
+ Some(match (l_v, r_v) {
+ (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
+ (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
+ _ => buffer,
+ })
+ }
+}
+
+/// Perform a take operation on `buffer` with the given dictionary
+fn take_bits(v: &dyn Dictionary, buffer: BooleanBuffer) -> BooleanBuffer {
+ let array = v.take(&BooleanArray::new(buffer, None)).unwrap();
+ array.as_boolean().values().clone()
+}
+
+/// Invokes `f` with values `0..len` collecting the boolean results into a new
`BooleanBuffer`
+///
+/// This is similar to [`MutableBuffer::collect_bool`] but with
+/// the option to efficiently negate the result
+fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) ->
BooleanBuffer {
+ let mut buffer = MutableBuffer::new(ceil(len, 64) * 8);
+
+ let chunks = len / 64;
+ let remainder = len % 64;
+ for chunk in 0..chunks {
+ let mut packed = 0;
+ for bit_idx in 0..64 {
+ let i = bit_idx + chunk * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+
+ if remainder != 0 {
+ let mut packed = 0;
+ for bit_idx in 0..remainder {
+ let i = bit_idx + chunks * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+ BooleanBuffer::new(buffer.into(), 0, len)
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd`
+///
+/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `l`
+/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `r`
+fn apply_op<T: ArrayOrd>(
+ l: T,
+ l_s: Option<usize>,
+ r: T,
+ r_s: Option<usize>,
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ match (l_s, r_s) {
+ (None, None) => {
+ assert_eq!(l.len(), r.len());
+ collect_bool(l.len(), neg, |idx| unsafe {
+ op(l.value_unchecked(idx), r.value_unchecked(idx))
+ })
+ }
+ (Some(l_s), Some(r_s)) => {
+ let a = l.value(l_s);
+ let b = r.value(r_s);
+ std::iter::once(op(a, b)).collect()
+ }
+ (Some(l_s), None) => {
+ let v = l.value(l_s);
+ collect_bool(r.len(), neg, |idx| op(v, unsafe {
r.value_unchecked(idx) }))
+ }
+ (None, Some(r_s)) => {
+ let v = r.value(r_s);
+ collect_bool(l.len(), neg, |idx| op(unsafe {
l.value_unchecked(idx) }, v))
+ }
+ }
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd` with the given indices
+fn apply_op_vectored<T: ArrayOrd>(
+ l: T,
+ l_v: &[usize],
+ r: T,
+ r_v: &[usize],
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ assert_eq!(l_v.len(), r_v.len());
+ collect_bool(l_v.len(), neg, |idx| unsafe {
Review Comment:
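Same comment here about documenting why using `idx` is safe. Here the safety argument also depends on `normalized_keys` clamping every key into `0..values.len()`. However, `x.as_usize().min(v_len)` can still return `v_len`, which is one past the end. Should that be `min(v_len - 1)`?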
##########
arrow-ord/src/lib.rs:
##########
@@ -43,6 +43,8 @@
//! ```
//!
+pub mod cmp;
+#[doc(hidden)]
Review Comment:
What is the reason for `#[doc(hidden)]`? Is it to discourage new uses of these functions? I also note you added `deprecated` notes to the functions in there, which should help people migrate.
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
Review Comment:
What does "comparison produces an ordering" mean? The output of this kernel is a `BooleanArray`, not an ordering.
Maybe it should say something like: "For floating values like f32 and f64, this comparison follows the definition of equality from the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard." 🤔
The same comment applies to `neq` as well.
##########
arrow-ord/src/comparison.rs:
##########
@@ -5231,21 +3871,12 @@ mod tests {
.into_iter()
.map(Some)
.collect();
- #[cfg(feature = "simd")]
- let expected = BooleanArray::from(
- vec![Some(false), Some(false), Some(false), Some(false),
Some(false)],
- );
- #[cfg(not(feature = "simd"))]
+
let expected = BooleanArray::from(
vec![Some(true), Some(false), Some(false), Some(false),
Some(false)],
);
assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected);
- #[cfg(feature = "simd")]
- let expected = BooleanArray::from(
Review Comment:
It is quite strange to get a different (presumably wrong) answer when using the `simd` feature flag 🤔
##########
arrow/benches/equal.rs:
##########
@@ -49,11 +44,6 @@ fn add_benchmark(c: &mut Criterion) {
let arr_a = create_string_array::<i32>(512, 0.0);
c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a)));
- let arr_a = create_string_array::<i32>(512, 0.0);
Review Comment:
Why did you remove this benchmark?
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
Review Comment:
Just confirming: this handles `Date`, `Time`, `Timestamp`, etc. types, right?
##########
arrow-flight/src/sql/metadata/sql_info.rs:
##########
@@ -425,13 +424,16 @@ impl SqlInfoData {
&self,
info: impl IntoIterator<Item = u32>,
) -> Result<RecordBatch> {
- let arr: UInt32Array = downcast_array(self.batch.column(0).as_ref());
+ let arr = self.batch.column(0);
let type_filter = info
.into_iter()
- .map(|tt| eq_scalar(&arr, tt))
+ .map(|tt| {
+ let s = UInt32Array::from(vec![tt]);
+ eq(arr, &Scalar::new(&s))
Review Comment:
Similarly, I think this would read better like:
```rust
let s = Scalar::new_uint32(tt);
eq(arr, &s)
```
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
Review Comment:
I think `ArrowError::NotImplemented` might be a more precise error variant here:
```suggestion
(l_t, r_t) => return
Err(ArrowError::NotImplemented(format!("Unsupported comparison operation: {l_t}
{op} {r_t}"))),
```
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
Review Comment:
At this point, if the scalar is null, the entire output will be null, so
there is no need to do any further computation -- maybe this function could
simply return early here (and reuse the input array's buffer rather than making
a new one)?
The same comment applies when the right-hand side is a null scalar.
##########
arrow-ord/Cargo.toml:
##########
@@ -44,10 +44,3 @@ half = { version = "2.1", default-features = false, features
= ["num-traits"] }
[dev-dependencies]
rand = { version = "0.8", default-features = false, features = ["std",
"std_rng"] }
-
-[package.metadata.docs.rs]
-features = ["dyn_cmp_dict"]
-
-[features]
-dyn_cmp_dict = []
-simd = ["arrow-array/simd"]
Review Comment:
For other kernels, as I recall, we found that properly written Rust code is
auto-vectorized by LLVM more effectively than our hand-rolled SIMD kernels. Do you
think that is still the case here?
cc @jhorstmann
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
+
+ assert_eq!(values.len(), len); // Sanity check
+ Ok(BooleanArray::new(values, nulls))
+}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
+ downcast_dictionary_array! {
+ a => Some(a),
+ _ => None
+ }
+}
+
+trait Dictionary: Array {
+ /// Returns the keys of this dictionary, clamped to be in the range
`0..values.len()`
+ ///
+ /// # Panic
+ ///
+ /// Panics if `values.len() == 0`
+ fn normalized_keys(&self) -> Vec<usize>;
+
+ /// Returns the values of this dictionary
+ fn values(&self) -> &ArrayRef;
+
+ /// Applies the `keys` of this dictionary to the provided array
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError>;
+}
+
+impl<K: ArrowDictionaryKeyType> Dictionary for DictionaryArray<K> {
+ fn normalized_keys(&self) -> Vec<usize> {
+ let v_len = self.values().len();
+ assert_ne!(v_len, 0);
+ let iter = self.keys().values().iter();
+ iter.map(|x| x.as_usize().min(v_len)).collect()
+ }
+
+ fn values(&self) -> &ArrayRef {
+ self.values()
+ }
+
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ take(array, self.keys(), None)
+ }
+}
+
+/// Perform a potentially vectored `op` on the provided `ArrayOrd`
+fn apply<T: ArrayOrd>(
+ op: Op,
+ l: T,
+ l_s: bool,
+ l_v: Option<&dyn Dictionary>,
+ r: T,
+ r_s: bool,
+ r_v: Option<&dyn Dictionary>,
+) -> Option<BooleanBuffer> {
+ if l.len() == 0 || r.len() == 0 {
+ return None; // Handle empty dictionaries
+ }
+
+ if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
+ // Not scalar and at least one side has a dictionary, need to perform
vectored comparison
+ let l_v = l_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..l.len()).collect());
+
+ let r_v = r_v
+ .map(|x| x.normalized_keys())
+ .unwrap_or_else(|| (0..r.len()).collect());
+
+ assert_eq!(l_v.len(), r_v.len()); // Sanity check
+
+ Some(match op {
+ Op::Equal => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
+ Op::NotEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_eq),
+ Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
+ Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true,
T::is_lt),
+ Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false,
T::is_lt),
+ Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true,
T::is_lt),
+ })
+ } else {
+ let l_s = l_s.then(|| l_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+ let r_s = r_s.then(|| r_v.map(|x|
x.normalized_keys()[0]).unwrap_or_default());
+
+ let buffer = match op {
+ Op::Equal => apply_op(l, l_s, r, r_s, false, T::is_eq),
+ Op::NotEqual => apply_op(l, l_s, r, r_s, true, T::is_eq),
+ Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
+ Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
+ Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
+ Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
+ };
+
+ // If a side had a dictionary, and was not scalar, we need to
materialize this
+ Some(match (l_v, r_v) {
+ (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
+ (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
+ _ => buffer,
+ })
+ }
+}
+
+/// Perform a take operation on `buffer` with the given dictionary
+fn take_bits(v: &dyn Dictionary, buffer: BooleanBuffer) -> BooleanBuffer {
+ let array = v.take(&BooleanArray::new(buffer, None)).unwrap();
+ array.as_boolean().values().clone()
+}
+
+/// Invokes `f` with values `0..len` collecting the boolean results into a new
`BooleanBuffer`
+///
+/// This is similar to [`MutableBuffer::collect_bool`] but with
+/// the option to efficiently negate the result
+fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) ->
BooleanBuffer {
+ let mut buffer = MutableBuffer::new(ceil(len, 64) * 8);
+
+ let chunks = len / 64;
+ let remainder = len % 64;
+ for chunk in 0..chunks {
+ let mut packed = 0;
+ for bit_idx in 0..64 {
+ let i = bit_idx + chunk * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+
+ if remainder != 0 {
+ let mut packed = 0;
+ for bit_idx in 0..remainder {
+ let i = bit_idx + chunks * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+ BooleanBuffer::new(buffer.into(), 0, len)
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd`
+///
+/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `l`
+/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `r`
+fn apply_op<T: ArrayOrd>(
+ l: T,
+ l_s: Option<usize>,
+ r: T,
+ r_s: Option<usize>,
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ match (l_s, r_s) {
+ (None, None) => {
+ assert_eq!(l.len(), r.len());
+ collect_bool(l.len(), neg, |idx| unsafe {
+ op(l.value_unchecked(idx), r.value_unchecked(idx))
+ })
+ }
+ (Some(l_s), Some(r_s)) => {
+ let a = l.value(l_s);
+ let b = r.value(r_s);
+ std::iter::once(op(a, b)).collect()
+ }
+ (Some(l_s), None) => {
+ let v = l.value(l_s);
+ collect_bool(r.len(), neg, |idx| op(v, unsafe {
r.value_unchecked(idx) }))
+ }
+ (None, Some(r_s)) => {
+ let v = r.value(r_s);
+ collect_bool(l.len(), neg, |idx| op(unsafe {
l.value_unchecked(idx) }, v))
+ }
+ }
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd` with the given indices
+fn apply_op_vectored<T: ArrayOrd>(
+ l: T,
+ l_v: &[usize],
+ r: T,
+ r_v: &[usize],
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ assert_eq!(l_v.len(), r_v.len());
+ collect_bool(l_v.len(), neg, |idx| unsafe {
+ let l_idx = *l_v.get_unchecked(idx);
+ let r_idx = *r_v.get_unchecked(idx);
+ op(l.value_unchecked(l_idx), r.value_unchecked(r_idx))
+ })
+}
+
+trait ArrayOrd {
+ type Item: Copy + Default;
+
+ fn len(&self) -> usize;
+
+ fn value(&self, idx: usize) -> Self::Item {
+ assert!(idx < self.len());
+ unsafe { self.value_unchecked(idx) }
+ }
+
+ /// # Safety
+ ///
+ /// Safe if `idx < self.len()`
+ unsafe fn value_unchecked(&self, idx: usize) -> Self::Item;
+
+ fn is_eq(l: Self::Item, r: Self::Item) -> bool;
+
+ fn is_lt(l: Self::Item, r: Self::Item) -> bool;
+}
+
+impl<'a> ArrayOrd for &'a BooleanArray {
+ type Item = bool;
+
+ fn len(&self) -> usize {
+ Array::len(self)
+ }
+
+ unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
+ BooleanArray::value_unchecked(self, idx)
+ }
+
+ fn is_eq(l: Self::Item, r: Self::Item) -> bool {
+ l == r
+ }
+
+ fn is_lt(l: Self::Item, r: Self::Item) -> bool {
+ !l & r
+ }
+}
+
+impl<T: ArrowNativeTypeOp> ArrayOrd for &[T] {
+ type Item = T;
+
+ fn len(&self) -> usize {
+ (*self).len()
+ }
+
+ unsafe fn value_unchecked(&self, idx: usize) -> Self::Item {
+ *self.get_unchecked(idx)
+ }
+
+ fn is_eq(l: Self::Item, r: Self::Item) -> bool {
+ l.is_eq(r)
Review Comment:
TIL: https://doc.rust-lang.org/std/cmp/enum.Ordering.html#method.is_eq
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
+#[derive(Debug, Copy, Clone)]
+enum Op {
+ Equal,
+ NotEqual,
+ Less,
+ LessEqual,
+ Greater,
+ GreaterEqual,
+}
+
+impl std::fmt::Display for Op {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Op::Equal => write!(f, "=="),
+ Op::NotEqual => write!(f, "!="),
+ Op::Less => write!(f, "<"),
+ Op::LessEqual => write!(f, "<="),
+ Op::Greater => write!(f, ">"),
+ Op::GreaterEqual => write!(f, ">="),
+ }
+ }
+}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
+/// Perform `op` on the provided `Datum`
+fn compare_op(
+ op: Op,
+ lhs: &dyn Datum,
+ rhs: &dyn Datum,
+) -> Result<BooleanArray, ArrowError> {
+ use arrow_schema::DataType::*;
+ let (l, l_s) = lhs.get();
+ let (r, r_s) = rhs.get();
+
+ let l_len = l.len();
+ let r_len = r.len();
+ let l_nulls = l.logical_nulls();
+ let r_nulls = r.logical_nulls();
+
+ let (len, nulls) = match (l_s, r_s) {
+ (true, true) | (false, false) => {
+ if l_len != r_len {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Cannot compare arrays of different lengths, got {l_len}
vs {r_len}"
+ )));
+ }
+ (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
+ }
+ (true, false) => match l_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (r_len, Some(NullBuffer::new_null(r_len))),
+ false => (r_len, r_nulls), // Left is scalar and not null
+ },
+ (false, true) => match r_nulls.map(|x| x.null_count() !=
0).unwrap_or_default() {
+ true => (l_len, Some(NullBuffer::new_null(l_len))),
+ false => (l_len, l_nulls), // Right is scalar and not null
+ },
+ };
+
+ let l_v = as_dictionary(l);
+ let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);
+
+ let r_v = as_dictionary(r);
+ let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);
+
+ let values = downcast_primitive_array! {
+ (l, r) => apply(op, l.values().as_ref(), l_s, l_v,
r.values().as_ref(), r_s, r_v),
+ (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v,
r.as_boolean(), r_s, r_v),
+ (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v,
r.as_string::<i32>(), r_s, r_v),
+ (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v,
r.as_string::<i64>(), r_s, r_v),
+ (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v,
r.as_binary::<i32>(), r_s, r_v),
+ (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s,
l_v, r.as_binary::<i64>(), r_s, r_v),
+ (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op,
l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
+ (l_t, r_t) => return
Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation:
{l_t} {op} {r_t}"))),
+ }.unwrap_or_else(|| {
+ let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
+ assert_eq!(count, len); // Sanity check
+ BooleanBuffer::new_unset(len)
+ });
+
+ assert_eq!(values.len(), len); // Sanity check
+ Ok(BooleanArray::new(values, nulls))
+}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
+ downcast_dictionary_array! {
+ a => Some(a),
+ _ => None
+ }
+}
+
+trait Dictionary: Array {
+ /// Returns the keys of this dictionary, clamped to be in the range
`0..values.len()`
+ ///
+ /// # Panic
+ ///
+ /// Panics if `values.len() == 0`
+ fn normalized_keys(&self) -> Vec<usize>;
+
+ /// Returns the values of this dictionary
+ fn values(&self) -> &ArrayRef;
+
+ /// Applies the `keys` of this dictionary to the provided array
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError>;
+}
+
+impl<K: ArrowDictionaryKeyType> Dictionary for DictionaryArray<K> {
+ fn normalized_keys(&self) -> Vec<usize> {
+ let v_len = self.values().len();
+ assert_ne!(v_len, 0);
+ let iter = self.keys().values().iter();
+ iter.map(|x| x.as_usize().min(v_len)).collect()
+ }
+
+ fn values(&self) -> &ArrayRef {
+ self.values()
+ }
+
+ fn take(&self, array: &dyn Array) -> Result<ArrayRef, ArrowError> {
+ take(array, self.keys(), None)
+ }
+}
+
/// Perform a potentially vectored `op` on the provided `ArrayOrd`
///
/// `l` / `r` are the values to compare. For a dictionary-encoded side this is the
/// dictionary's values array, and `l_v` / `r_v` holds the dictionary itself.
/// `l_s` / `r_s` flag a scalar side.
///
/// Returns `None` if either `l` or `r` is empty (e.g. a dictionary with no
/// values). Otherwise returns the boolean results, with dictionary keys already
/// applied for any non-scalar dictionary side.
///
/// Every `Op` is expressed using only `T::is_eq` and `T::is_lt`, by swapping the
/// operands and/or asking the helper to negate (`neg = true`) the result:
///
/// | op   | computed as |
/// |------|-------------|
/// | `==` | `l == r`    |
/// | `!=` | `!(l == r)` |
/// | `<`  | `l < r`     |
/// | `<=` | `!(r < l)`  |
/// | `>`  | `r < l`     |
/// | `>=` | `!(l < r)`  |
fn apply<T: ArrayOrd>(
    op: Op,
    l: T,
    l_s: bool,
    l_v: Option<&dyn Dictionary>,
    r: T,
    r_s: bool,
    r_v: Option<&dyn Dictionary>,
) -> Option<BooleanBuffer> {
    if l.len() == 0 || r.len() == 0 {
        return None; // Handle empty dictionaries
    }

    if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) {
        // Not scalar and at least one side has a dictionary, need to perform vectored comparison
        // A side without a dictionary uses the identity mapping `0..len`
        let l_v = l_v
            .map(|x| x.normalized_keys())
            .unwrap_or_else(|| (0..l.len()).collect());

        let r_v = r_v
            .map(|x| x.normalized_keys())
            .unwrap_or_else(|| (0..r.len()).collect());

        assert_eq!(l_v.len(), r_v.len()); // Sanity check

        Some(match op {
            Op::Equal => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq),
            Op::NotEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq),
            Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt),
            Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt),
            Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt),
            Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt),
        })
    } else {
        // A scalar side is reduced to the index of its single value: its first
        // normalized key if dictionary-encoded, otherwise 0
        let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());
        let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default());

        let buffer = match op {
            Op::Equal => apply_op(l, l_s, r, r_s, false, T::is_eq),
            Op::NotEqual => apply_op(l, l_s, r, r_s, true, T::is_eq),
            Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt),
            Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt),
            Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt),
            Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt),
        };

        // If a side had a dictionary, and was not scalar, we need to materialize this
        // (`buffer` was computed over that side's values, so it is taken through its keys)
        Some(match (l_v, r_v) {
            (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer),
            (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer),
            _ => buffer,
        })
    }
}
+
+/// Perform a take operation on `buffer` with the given dictionary
+fn take_bits(v: &dyn Dictionary, buffer: BooleanBuffer) -> BooleanBuffer {
+ let array = v.take(&BooleanArray::new(buffer, None)).unwrap();
+ array.as_boolean().values().clone()
+}
+
+/// Invokes `f` with values `0..len` collecting the boolean results into a new
`BooleanBuffer`
+///
+/// This is similar to [`MutableBuffer::collect_bool`] but with
+/// the option to efficiently negate the result
+fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) ->
BooleanBuffer {
+ let mut buffer = MutableBuffer::new(ceil(len, 64) * 8);
+
+ let chunks = len / 64;
+ let remainder = len % 64;
+ for chunk in 0..chunks {
+ let mut packed = 0;
+ for bit_idx in 0..64 {
+ let i = bit_idx + chunk * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+
+ if remainder != 0 {
+ let mut packed = 0;
+ for bit_idx in 0..remainder {
+ let i = bit_idx + chunks * 64;
+ packed |= (f(i) as u64) << bit_idx;
+ }
+ if neg {
+ packed = !packed
+ }
+
+ // SAFETY: Already allocated sufficient capacity
+ unsafe { buffer.push_unchecked(packed) }
+ }
+ BooleanBuffer::new(buffer.into(), 0, len)
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd`
+///
+/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `l`
+/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the
scalar value in `r`
+fn apply_op<T: ArrayOrd>(
+ l: T,
+ l_s: Option<usize>,
+ r: T,
+ r_s: Option<usize>,
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ match (l_s, r_s) {
+ (None, None) => {
+ assert_eq!(l.len(), r.len());
+ collect_bool(l.len(), neg, |idx| unsafe {
+ op(l.value_unchecked(idx), r.value_unchecked(idx))
+ })
+ }
+ (Some(l_s), Some(r_s)) => {
+ let a = l.value(l_s);
+ let b = r.value(r_s);
+ std::iter::once(op(a, b)).collect()
+ }
+ (Some(l_s), None) => {
+ let v = l.value(l_s);
+ collect_bool(r.len(), neg, |idx| op(v, unsafe {
r.value_unchecked(idx) }))
+ }
+ (None, Some(r_s)) => {
+ let v = r.value(r_s);
+ collect_bool(l.len(), neg, |idx| op(unsafe {
l.value_unchecked(idx) }, v))
+ }
+ }
+}
+
+/// Applies `op` to possibly scalar `ArrayOrd` with the given indices
+fn apply_op_vectored<T: ArrayOrd>(
+ l: T,
+ l_v: &[usize],
+ r: T,
+ r_v: &[usize],
+ neg: bool,
+ op: impl Fn(T::Item, T::Item) -> bool,
+) -> BooleanBuffer {
+ assert_eq!(l_v.len(), r_v.len());
+ collect_bool(l_v.len(), neg, |idx| unsafe {
+ let l_idx = *l_v.get_unchecked(idx);
+ let r_idx = *r_v.get_unchecked(idx);
+ op(l.value_unchecked(l_idx), r.value_unchecked(r_idx))
+ })
+}
+
+trait ArrayOrd {
Review Comment:
Could you possibly add some more documentation to this trait?
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
/// The comparison performed by the kernels in this module
#[derive(Debug, Copy, Clone)]
enum Op {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
}

/// Renders the operator symbol, e.g. `Op::LessEqual` displays as `<=`
impl std::fmt::Display for Op {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Equal => "==",
            Self::NotEqual => "!=",
            Self::Less => "<",
            Self::LessEqual => "<=",
            Self::Greater => ">",
            Self::GreaterEqual => ">=",
        };
        f.write_str(symbol)
    }
}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
Review Comment:
What does "comparison produces an ordering" mean? The output of this kernel
is a BooleanArray.
Maybe it should say something like "For floating values like f32 and f64,
this comparison follows the definition of equality from the totalOrder
predicate as defined in the IEEE 754 (2008 revision) floating point standard."
🤔
The same comment applies to `neq` as well
##########
arrow-ord/src/cmp.rs:
##########
@@ -0,0 +1,527 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Comparison kernels for `Array`s.
+//!
+//! These kernels can leverage SIMD if available on your system. Currently no
runtime
+//! detection is provided, you should enable the specific SIMD intrinsics using
+//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
+//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
+//!
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
+use arrow_array::{
+ downcast_dictionary_array, downcast_primitive_array, Array, ArrayRef,
+ ArrowNativeTypeOp, BooleanArray, Datum, DictionaryArray,
FixedSizeBinaryArray,
+ GenericByteArray,
+};
+use arrow_buffer::bit_util::ceil;
+use arrow_buffer::{ArrowNativeType, BooleanBuffer, MutableBuffer, NullBuffer};
+use arrow_schema::ArrowError;
+use arrow_select::take::take;
+
/// The comparison performed by the kernels in this module
#[derive(Debug, Copy, Clone)]
enum Op {
    /// `==`
    Equal,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
}

/// Renders the operator symbol, e.g. `Op::LessEqual` displays as `<=`
impl std::fmt::Display for Op {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Equal => "==",
            Self::NotEqual => "!=",
            Self::Less => "<",
            Self::LessEqual => "<=",
            Self::Greater => ">",
            Self::GreaterEqual => ">=",
        };
        f.write_str(symbol)
    }
}
+
+/// Perform `left == right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Equal, lhs, rhs)
+}
+
+/// Perform `left != right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::NotEqual, lhs, rhs)
+}
+
+/// Perform `left < right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Less, lhs, rhs)
+}
+
+/// Perform `left <= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::LessEqual, lhs, rhs)
+}
+
+/// Perform `left > right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::Greater, lhs, rhs)
+}
+
+/// Perform `left >= right` operation on two [`Datum`]
+///
+/// For floating values like f32 and f64, this comparison produces an ordering
in accordance to
+/// the totalOrder predicate as defined in the IEEE 754 (2008 revision)
floating point standard.
+/// Note that totalOrder treats positive and negative zeros as different. If
it is necessary
+/// to treat them as equal, please normalize zeros before calling this kernel.
+///
+/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`]
+pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray,
ArrowError> {
+ compare_op(Op::GreaterEqual, lhs, rhs)
+}
+
/// Perform `op` on the provided `Datum`
///
/// This is the shared implementation behind `eq`, `neq`, `lt`, `lt_eq`, `gt`
/// and `gt_eq`. It proceeds in three steps:
///
/// 1. Compute the output length and null mask from whether each side is a
///    scalar and from each side's logical nulls.
/// 2. Unwrap dictionary-encoded inputs to their values array. The dictionary
///    is passed along so `apply` can map results back through the keys.
/// 3. Dispatch on the pair of value types to `apply`.
///
/// # Errors
///
/// * Both inputs are arrays (or both are scalars) but their lengths differ.
/// * The value types of the two inputs are not one of the supported matching
///   pairs listed in the dispatch below.
fn compare_op(
    op: Op,
    lhs: &dyn Datum,
    rhs: &dyn Datum,
) -> Result<BooleanArray, ArrowError> {
    use arrow_schema::DataType::*;
    // `get` returns the backing array and a flag that is true for a scalar datum
    let (l, l_s) = lhs.get();
    let (r, r_s) = rhs.get();

    let l_len = l.len();
    let r_len = r.len();
    // Logical nulls are taken before the dictionary unwrapping below
    let l_nulls = l.logical_nulls();
    let r_nulls = r.logical_nulls();

    let (len, nulls) = match (l_s, r_s) {
        // Element-wise: lengths must agree, and a slot is null if either side is null
        (true, true) | (false, false) => {
            if l_len != r_len {
                return Err(ArrowError::InvalidArgumentError(format!(
                    "Cannot compare arrays of different lengths, got {l_len} vs {r_len}"
                )));
            }
            (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref()))
        }
        // Scalar vs array: a null scalar nulls every output slot, otherwise
        // the array's own nulls carry through
        (true, false) => match l_nulls.map(|x| x.null_count() != 0).unwrap_or_default() {
            true => (r_len, Some(NullBuffer::new_null(r_len))),
            false => (r_len, r_nulls), // Left is scalar and not null
        },
        (false, true) => match r_nulls.map(|x| x.null_count() != 0).unwrap_or_default() {
            true => (l_len, Some(NullBuffer::new_null(l_len))),
            false => (l_len, l_nulls), // Right is scalar and not null
        },
    };

    // Compare on dictionary values; `l_v` / `r_v` keep the dictionaries so
    // `apply` can map the results back through the keys
    let l_v = as_dictionary(l);
    let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l);

    let r_v = as_dictionary(r);
    let r = r_v.map(|x| x.values().as_ref()).unwrap_or(r);

    let values = downcast_primitive_array! {
        (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v),
        (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v),
        (Utf8, Utf8) => apply(op, l.as_string::<i32>(), l_s, l_v, r.as_string::<i32>(), r_s, r_v),
        (LargeUtf8, LargeUtf8) => apply(op, l.as_string::<i64>(), l_s, l_v, r.as_string::<i64>(), r_s, r_v),
        (Binary, Binary) => apply(op, l.as_binary::<i32>(), l_s, l_v, r.as_binary::<i32>(), r_s, r_v),
        (LargeBinary, LargeBinary) => apply(op, l.as_binary::<i64>(), l_s, l_v, r.as_binary::<i64>(), r_s, r_v),
        (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v),
        (l_t, r_t) => return Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation: {l_t} {op} {r_t}"))),
    }.unwrap_or_else(|| {
        // `apply` returns `None` when either side has no values to compare. In
        // that case every output slot is expected to be null.
        let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default();
        assert_eq!(count, len); // Sanity check
        BooleanBuffer::new_unset(len)
    });

    assert_eq!(values.len(), len); // Sanity check
    Ok(BooleanArray::new(values, nulls))
}
+
+fn as_dictionary(a: &dyn Array) -> Option<&dyn Dictionary> {
Review Comment:
I was confused at first by this naming, as it shadows
https://docs.rs/arrow/latest/arrow/array/trait.AsArray.html#method.as_dictionary
Maybe we could call it `as_dyn_dictionary` or something 🤔
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]