This is an automated email from the ASF dual-hosted git repository.
houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4ddd2f5 Implement PartialOrd for ScalarValue (#838)
4ddd2f5 is described below
commit 4ddd2f5e7582ffe662aea27bbb74c58cd0715152
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Sun Aug 8 15:42:31 2021 -0700
Implement PartialOrd for ScalarValue (#838)
* Implement PartialOrd for ScalarValue.
* Avoid catch all match.
---
datafusion/src/scalar.rs | 146 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 146 insertions(+)
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 3896055..3fbcadd 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -28,6 +28,7 @@ use arrow::{
},
};
use ordered_float::OrderedFloat;
+use std::cmp::Ordering;
use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
@@ -156,6 +157,81 @@ impl PartialEq for ScalarValue {
}
}
+// manual implementation of `PartialOrd` that uses OrderedFloat to
+// get defined behavior for floating point
+impl PartialOrd for ScalarValue {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ use ScalarValue::*;
+ // This purposely doesn't have a catch-all "(_, _)" so that
+ // any newly added enum variant will require editing this list
+ // or else face a compile error
+ match (self, other) {
+ (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2),
+ (Boolean(_), _) => None,
+ (Float32(v1), Float32(v2)) => {
+ let v1 = v1.map(OrderedFloat);
+ let v2 = v2.map(OrderedFloat);
+ v1.partial_cmp(&v2)
+ }
+ (Float32(_), _) => None,
+ (Float64(v1), Float64(v2)) => {
+ let v1 = v1.map(OrderedFloat);
+ let v2 = v2.map(OrderedFloat);
+ v1.partial_cmp(&v2)
+ }
+ (Float64(_), _) => None,
+ (Int8(v1), Int8(v2)) => v1.partial_cmp(v2),
+ (Int8(_), _) => None,
+ (Int16(v1), Int16(v2)) => v1.partial_cmp(v2),
+ (Int16(_), _) => None,
+ (Int32(v1), Int32(v2)) => v1.partial_cmp(v2),
+ (Int32(_), _) => None,
+ (Int64(v1), Int64(v2)) => v1.partial_cmp(v2),
+ (Int64(_), _) => None,
+ (UInt8(v1), UInt8(v2)) => v1.partial_cmp(v2),
+ (UInt8(_), _) => None,
+ (UInt16(v1), UInt16(v2)) => v1.partial_cmp(v2),
+ (UInt16(_), _) => None,
+ (UInt32(v1), UInt32(v2)) => v1.partial_cmp(v2),
+ (UInt32(_), _) => None,
+ (UInt64(v1), UInt64(v2)) => v1.partial_cmp(v2),
+ (UInt64(_), _) => None,
+ (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2),
+ (Utf8(_), _) => None,
+ (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2),
+ (LargeUtf8(_), _) => None,
+ (Binary(v1), Binary(v2)) => v1.partial_cmp(v2),
+ (Binary(_), _) => None,
+ (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
+ (LargeBinary(_), _) => None,
+ (List(v1, t1), List(v2, t2)) => {
+ if t1.eq(t2) {
+ v1.partial_cmp(v2)
+ } else {
+ None
+ }
+ }
+ (List(_, _), _) => None,
+ (Date32(v1), Date32(v2)) => v1.partial_cmp(v2),
+ (Date32(_), _) => None,
+ (Date64(v1), Date64(v2)) => v1.partial_cmp(v2),
+ (Date64(_), _) => None,
+ (TimestampSecond(v1), TimestampSecond(v2)) => v1.partial_cmp(v2),
+ (TimestampSecond(_), _) => None,
+ (TimestampMillisecond(v1), TimestampMillisecond(v2)) => v1.partial_cmp(v2),
+ (TimestampMillisecond(_), _) => None,
+ (TimestampMicrosecond(v1), TimestampMicrosecond(v2)) => v1.partial_cmp(v2),
+ (TimestampMicrosecond(_), _) => None,
+ (TimestampNanosecond(v1), TimestampNanosecond(v2)) => v1.partial_cmp(v2),
+ (TimestampNanosecond(_), _) => None,
+ (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2),
+ (IntervalYearMonth(_), _) => None,
+ (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2),
+ (IntervalDayTime(_), _) => None,
+ }
+ }
+}
+
impl Eq for ScalarValue {}
// manual implementation of `Hash` that uses OrderedFloat to
@@ -1577,4 +1653,74 @@ mod tests {
// per distinct value.
assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
}
+
+ #[test]
+ fn scalar_partial_ordering() {
+ use ScalarValue::*;
+
+ assert_eq!(
+ Int64(Some(33)).partial_cmp(&Int64(Some(0))),
+ Some(Ordering::Greater)
+ );
+ assert_eq!(
+ Int64(Some(0)).partial_cmp(&Int64(Some(33))),
+ Some(Ordering::Less)
+ );
+ assert_eq!(
+ Int64(Some(33)).partial_cmp(&Int64(Some(33))),
+ Some(Ordering::Equal)
+ );
+ // For different data type, `partial_cmp` returns None.
+ assert_eq!(Int64(Some(33)).partial_cmp(&Int32(Some(33))), None);
+ assert_eq!(Int32(Some(33)).partial_cmp(&Int64(Some(33))), None);
+
+ assert_eq!(
+ List(
+ Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )
+ .partial_cmp(&List(
+ Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )),
+ Some(Ordering::Equal)
+ );
+
+ assert_eq!(
+ List(
+ Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )
+ .partial_cmp(&List(
+ Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )),
+ Some(Ordering::Greater)
+ );
+
+ assert_eq!(
+ List(
+ Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )
+ .partial_cmp(&List(
+ Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )),
+ Some(Ordering::Less)
+ );
+
+ // For different data type, `partial_cmp` returns None.
+ assert_eq!(
+ List(
+ Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])),
+ Box::new(DataType::Int64)
+ )
+ .partial_cmp(&List(
+ Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])),
+ Box::new(DataType::Int32)
+ )),
+ None
+ );
+ }
}