This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2413155a3e feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)
2413155a3e is described below
commit 2413155a3ed808285e31421a8b6aac23b8abdb91
Author: Andy Grove <[email protected]>
AuthorDate: Thu Jul 11 08:56:47 2024 -0600
feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)
* update tests
* update tests
* add rustdoc
* update PartialEq impl
* fix
* address feedback about improving api
---
datafusion/core/src/physical_planner.rs | 4 +-
datafusion/physical-expr/src/expressions/binary.rs | 126 +++++++++++++++++++--
2 files changed, 121 insertions(+), 9 deletions(-)
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 6aad4d5755..d2bc334ec3 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2312,7 +2312,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
- let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
+ let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
@@ -2551,7 +2551,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
- let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
+ let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c153ead963..c34dcdfb75 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -53,6 +53,8 @@ pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
right: Arc<dyn PhysicalExpr>,
+ /// Specifies whether an error is returned on overflow or not
+ fail_on_overflow: bool,
}
impl BinaryExpr {
@@ -62,7 +64,22 @@ impl BinaryExpr {
op: Operator,
right: Arc<dyn PhysicalExpr>,
) -> Self {
- Self { left, op, right }
+ Self {
+ left,
+ op,
+ right,
+ fail_on_overflow: false,
+ }
+ }
+
+ /// Create new binary expression with explicit fail_on_overflow value
+ pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
+ Self {
+ left: self.left,
+ op: self.op,
+ right: self.right,
+ fail_on_overflow,
+ }
}
/// Get the left side of the binary expression
@@ -273,8 +290,11 @@ impl PhysicalExpr for BinaryExpr {
}
match self.op {
+ Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
+ Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
+ Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
Operator::Divide => return apply(&lhs, &rhs, div),
Operator::Modulo => return apply(&lhs, &rhs, rem),
@@ -327,11 +347,10 @@ impl PhysicalExpr for BinaryExpr {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
- Ok(Arc::new(BinaryExpr::new(
- Arc::clone(&children[0]),
- self.op,
- Arc::clone(&children[1]),
- )))
+ Ok(Arc::new(
+ BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
+ .with_fail_on_overflow(self.fail_on_overflow),
+ ))
}
fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
@@ -496,7 +515,12 @@ impl PartialEq<dyn Any> for BinaryExpr {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
- .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right))
+ .map(|x| {
+ self.left.eq(&x.left)
+ && self.op == x.op
+ && self.right.eq(&x.right)
+ && self.fail_on_overflow.eq(&x.fail_on_overflow)
+ })
.unwrap_or(false)
}
}
@@ -661,6 +685,7 @@ mod tests {
use datafusion_common::plan_datafusion_err;
use datafusion_expr::type_coercion::binary::get_input_types;
+ use datafusion_physical_expr_common::expressions::column::Column;
/// Performs a binary operation, applying any type coercion necessary
fn binary_op(
@@ -4008,4 +4033,91 @@ mod tests {
.unwrap();
assert_eq!(&casted, &dictionary);
}
+
+ #[test]
+ fn test_add_with_overflow() -> Result<()> {
+ // create test data
+ let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+ let r = Arc::new(Int32Array::from(vec![2, 1]));
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("l", DataType::Int32, false),
+ Field::new("r", DataType::Int32, false),
+ ]));
+ let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+ // create expression
+ let expr = BinaryExpr::new(
+ Arc::new(Column::new("l", 0)),
+ Operator::Plus,
+ Arc::new(Column::new("r", 1)),
+ )
+ .with_fail_on_overflow(true);
+
+ // evaluate expression
+ let result = expr.evaluate(&batch);
+ assert!(result
+ .err()
+ .unwrap()
+ .to_string()
+ .contains("Overflow happened on: 2147483647 + 1"));
+ Ok(())
+ }
+
+ #[test]
+ fn test_subtract_with_overflow() -> Result<()> {
+ // create test data
+ let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
+ let r = Arc::new(Int32Array::from(vec![2, 1]));
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("l", DataType::Int32, false),
+ Field::new("r", DataType::Int32, false),
+ ]));
+ let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+ // create expression
+ let expr = BinaryExpr::new(
+ Arc::new(Column::new("l", 0)),
+ Operator::Minus,
+ Arc::new(Column::new("r", 1)),
+ )
+ .with_fail_on_overflow(true);
+
+ // evaluate expression
+ let result = expr.evaluate(&batch);
+ assert!(result
+ .err()
+ .unwrap()
+ .to_string()
+ .contains("Overflow happened on: -2147483648 - 1"));
+ Ok(())
+ }
+
+ #[test]
+ fn test_mul_with_overflow() -> Result<()> {
+ // create test data
+ let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+ let r = Arc::new(Int32Array::from(vec![2, 2]));
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("l", DataType::Int32, false),
+ Field::new("r", DataType::Int32, false),
+ ]));
+ let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+ // create expression
+ let expr = BinaryExpr::new(
+ Arc::new(Column::new("l", 0)),
+ Operator::Multiply,
+ Arc::new(Column::new("r", 1)),
+ )
+ .with_fail_on_overflow(true);
+
+ // evaluate expression
+ let result = expr.evaluate(&batch);
+ assert!(result
+ .err()
+ .unwrap()
+ .to_string()
+ .contains("Overflow happened on: 2147483647 * 2"));
+ Ok(())
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]