This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new aed319c10c Scalar arithmetic should return error when overflows.
(#5811)
aed319c10c is described below
commit aed319c10c87ab6f0490cc2e8657d98e454d7f80
Author: Zhiyuan Zheng <[email protected]>
AuthorDate: Wed Apr 12 02:07:12 2023 +0800
Scalar arithmetic should return error when overflows. (#5811)
* Scalar arithmetic should return error when overflows.
* Add a few tests.
* add new checked_* ops.
* fix test failure.
---
datafusion/common/src/scalar.rs | 185 +++++++++++++++++++++++++++++++++++++++-
1 file changed, 183 insertions(+), 2 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 8c26ea9b04..63b6caa623 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -579,6 +579,41 @@ macro_rules! primitive_op {
}
};
}
+macro_rules! primitive_checked_op { // Overflow-checked binary op on two Option<_> payloads of variant ScalarValue::$SCALAR; $FUNCTION is a checked_* method (e.g. checked_add) whose None result is reported as an Execution error.
+ ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $FUNCTION:ident, $OPERATION:tt)
=> {
+ match ($LEFT, $RIGHT) {
+ (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)), // right operand NULL: the left value (itself possibly NULL) is returned as-is
+ #[allow(unused_variables)] // NOTE(review): presumably silences warnings from primitive_right arms that do not use their term — confirm against primitive_right
+ (None, Some(b)) => { // left operand NULL: result is derived from b alone (checked negation for `-`, see primitive_checked_right)
+ primitive_checked_right!(*b, $OPERATION, $SCALAR)
+ }
+ (Some(a), Some(b)) => {
+ if let Some(value) = (*a).$FUNCTION(*b) { // checked_* returns None when the result does not fit the integer type
+ Ok(ScalarValue::$SCALAR(Some(value)))
+ } else {
+ Err(DataFusionError::Execution( // NOTE(review): the same literal is repeated in primitive_checked_right; a shared constant would keep the two in sync
+ "Overflow while calculating ScalarValue.".to_string(),
+ ))
+ }
+ }
+ }
+ };
+}
+
+macro_rules! primitive_checked_right { // Computes `NULL <op> $TERM` for the checked path: `-` negates $TERM with checked_neg, every other operator is delegated to primitive_right.
+ ($TERM:expr, -, $SCALAR:ident) => {
+ if let Some(value) = $TERM.checked_neg() { // None when -$TERM is unrepresentable (e.g. i8::MIN); NOTE(review): for unsigned types checked_neg is None for every nonzero value, so `NULL - b` yields this overflow error
+ Ok(ScalarValue::$SCALAR(Some(value)))
+ } else {
+ Err(DataFusionError::Execution(
+ "Overflow while calculating ScalarValue.".to_string(),
+ ))
+ }
+ };
+ ($TERM:expr, $OPERATION:tt, $SCALAR:ident) => { // non-`-` operators: delegated to primitive_right (its arms are outside this hunk)
+ primitive_right!($TERM, $OPERATION, $SCALAR)
+ };
+}
macro_rules! primitive_right {
($TERM:expr, +, $SCALAR:ident) => {
@@ -625,6 +660,41 @@ macro_rules! unsigned_subtraction_error {
}};
}
+macro_rules! impl_checked_op { // Dispatches an overflow-checked op for same-typed Int*/UInt* pairs; $FUNCTION names the checked_* method and $OPERATION the matching operator token used for NULL handling and the fallback.
+ ($LHS:expr, $RHS:expr, $FUNCTION:ident, $OPERATION:tt) => {
+ // Only covering primitive types that support checked_* operations, and fall back to raw operation (impl_op!) for other types.
+ match ($LHS, $RHS) {
+ (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
+ primitive_checked_op!(lhs, rhs, UInt64, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
+ primitive_checked_op!(lhs, rhs, Int64, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
+ primitive_checked_op!(lhs, rhs, UInt32, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
+ primitive_checked_op!(lhs, rhs, Int32, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
+ primitive_checked_op!(lhs, rhs, UInt16, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
+ primitive_checked_op!(lhs, rhs, Int16, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
+ primitive_checked_op!(lhs, rhs, UInt8, $FUNCTION, $OPERATION)
+ },
+ (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
+ primitive_checked_op!(lhs, rhs, Int8, $FUNCTION, $OPERATION)
+ },
+ _ => { // any other pairing (non-integer variants or mismatched operand types) is not overflow-checked by this macro
+ impl_op!($LHS, $RHS, $OPERATION)
+ }
+ }
+ };
+}
+
macro_rules! impl_op {
($LHS:expr, $RHS:expr, +) => {
impl_op_arithmetic!($LHS, $RHS, +)
@@ -1855,11 +1925,21 @@ impl ScalarValue {
impl_op!(self, rhs, +)
}
+ pub fn add_checked<T: Borrow<ScalarValue>>(&self, other: T) ->
Result<ScalarValue> { // Overflow-checked addition: same-typed Int*/UInt* operands whose sum does not fit yield Err(DataFusionError::Execution); other operand types go through impl_op!.
+ let rhs = other.borrow();
+ impl_checked_op!(self, rhs, checked_add, +)
+ }
+
pub fn sub<T: Borrow<ScalarValue>>(&self, other: T) -> Result<ScalarValue>
{ // Subtraction through impl_op! with no checked_* call; sub_checked is the overflow-checked variant.
let rhs = other.borrow();
impl_op!(self, rhs, -)
}
+ pub fn sub_checked<T: Borrow<ScalarValue>>(&self, other: T) ->
Result<ScalarValue> { // Overflow-checked subtraction: same-typed Int*/UInt* operands whose difference does not fit yield Err(DataFusionError::Execution); other operand types go through impl_op!.
+ let rhs = other.borrow();
+ impl_checked_op!(self, rhs, checked_sub, -)
+ }
+
pub fn is_unsigned(&self) -> bool {
matches!(
self,
@@ -1926,9 +2006,9 @@ impl ScalarValue {
}
let distance = if self > other {
- self.sub(other).ok()?
+ self.sub_checked(other).ok()?
} else {
- other.sub(self).ok()?
+ other.sub_checked(self).ok()?
};
match distance {
@@ -3625,8 +3705,10 @@ mod tests {
use std::cmp::Ordering;
use std::sync::Arc;
+ use arrow::compute;
use arrow::compute::kernels;
use arrow::datatypes::ArrowPrimitiveType;
+ use arrow_array::ArrowNumericType;
use rand::Rng;
use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
@@ -3664,6 +3746,100 @@ mod tests {
Ok(())
}
+ #[test]
+ fn scalar_sub_trait_int32_test() -> Result<()> { // Non-overflowing Int32 subtraction in both operand orders: 42 - 100 = -58 and 100 - 42 = 58.
+ let int_value = ScalarValue::Int32(Some(42));
+ let int_value_2 = ScalarValue::Int32(Some(100));
+ assert_eq!(int_value.sub(&int_value_2)?,
ScalarValue::Int32(Some(-58)));
+ assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int32(Some(58))); // rhs passed by value here and by reference above; both satisfy Borrow<ScalarValue>
+ Ok(())
+ }
+
+ #[test]
+ fn scalar_sub_trait_int32_overflow_test() -> Result<()> { // i32::MAX - i32::MIN does not fit in i32, so sub_checked must return the overflow Execution error.
+ let int_value = ScalarValue::Int32(Some(i32::MAX));
+ let int_value_2 = ScalarValue::Int32(Some(i32::MIN));
+ assert!(matches!(
+ int_value.sub_checked(&int_value_2),
+ Err(DataFusionError::Execution(msg)) if msg == "Overflow while
calculating ScalarValue."
+ ));
+ Ok(())
+ }
+
+ #[test]
+ fn scalar_sub_trait_int64_test() -> Result<()> { // Non-overflowing Int64 subtraction in both operand orders: 42 - 100 = -58 and 100 - 42 = 58.
+ let int_value = ScalarValue::Int64(Some(42));
+ let int_value_2 = ScalarValue::Int64(Some(100));
+ assert_eq!(int_value.sub(&int_value_2)?,
ScalarValue::Int64(Some(-58)));
+ assert_eq!(int_value_2.sub(int_value)?, ScalarValue::Int64(Some(58))); // rhs passed by value here and by reference above; both satisfy Borrow<ScalarValue>
+ Ok(())
+ }
+
+ #[test]
+ fn scalar_sub_trait_int64_overflow_test() -> Result<()> { // i64::MAX - i64::MIN does not fit in i64, so sub_checked must return the overflow Execution error.
+ let int_value = ScalarValue::Int64(Some(i64::MAX));
+ let int_value_2 = ScalarValue::Int64(Some(i64::MIN));
+ assert!(matches!(
+ int_value.sub_checked(&int_value_2),
+ Err(DataFusionError::Execution(msg)) if msg == "Overflow while
calculating ScalarValue."
+ ));
+ Ok(())
+ }
+
+ #[test]
+ fn scalar_add_overflow_test() -> Result<()> { // MAX + MAX for every signed and unsigned integer width from 8 to 64 bits, each compared against arrow's add_checked via check_scalar_add_overflow.
+ check_scalar_add_overflow::<Int8Type>(
+ ScalarValue::Int8(Some(i8::MAX)),
+ ScalarValue::Int8(Some(i8::MAX)),
+ );
+ check_scalar_add_overflow::<UInt8Type>(
+ ScalarValue::UInt8(Some(u8::MAX)),
+ ScalarValue::UInt8(Some(u8::MAX)),
+ );
+ check_scalar_add_overflow::<Int16Type>(
+ ScalarValue::Int16(Some(i16::MAX)),
+ ScalarValue::Int16(Some(i16::MAX)),
+ );
+ check_scalar_add_overflow::<UInt16Type>(
+ ScalarValue::UInt16(Some(u16::MAX)),
+ ScalarValue::UInt16(Some(u16::MAX)),
+ );
+ check_scalar_add_overflow::<Int32Type>(
+ ScalarValue::Int32(Some(i32::MAX)),
+ ScalarValue::Int32(Some(i32::MAX)),
+ );
+ check_scalar_add_overflow::<UInt32Type>(
+ ScalarValue::UInt32(Some(u32::MAX)),
+ ScalarValue::UInt32(Some(u32::MAX)),
+ );
+ check_scalar_add_overflow::<Int64Type>(
+ ScalarValue::Int64(Some(i64::MAX)),
+ ScalarValue::Int64(Some(i64::MAX)),
+ );
+ check_scalar_add_overflow::<UInt64Type>(
+ ScalarValue::UInt64(Some(u64::MAX)),
+ ScalarValue::UInt64(Some(u64::MAX)),
+ );
+
+ Ok(())
+ }
+
+ // Verifies that ScalarValue has the same behavior as the compute kernel when it overflows.
+ fn check_scalar_add_overflow<T>(left: ScalarValue, right: ScalarValue)
+ where
+ T: ArrowNumericType,
+ {
+ let scalar_result = left.add_checked(&right);
+
+ let left_array = left.to_array(); // presumably a one-element array holding the scalar — confirm against ScalarValue::to_array
+ let right_array = right.to_array();
+ let arrow_left_array = left_array.as_primitive::<T>(); // NOTE(review): as_primitive presumably comes from arrow's AsArray trait imported elsewhere in this test module — confirm
+ let arrow_right_array = right_array.as_primitive::<T>();
+ let arrow_result = compute::add_checked(arrow_left_array,
arrow_right_array);
+
+ assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); // NOTE(review): only agreement is checked, so this also passes if neither side overflowed; callers pass MAX + MAX, so an explicit assert!(scalar_result.is_err()) would pin the intent
+ }
+
#[test]
fn test_interval_add_timestamp() -> Result<()> {
let interval = ScalarValue::IntervalMonthDayNano(Some(123));
@@ -5231,6 +5407,11 @@ mod tests {
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 5),
),
+ // Overflows
+ (
+ ScalarValue::Int8(Some(i8::MAX)),
+ ScalarValue::Int8(Some(i8::MIN)),
+ ),
];
for (lhs, rhs) in cases {
let distance = lhs.distance(&rhs);