This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2413155a3e feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)
2413155a3e is described below

commit 2413155a3ed808285e31421a8b6aac23b8abdb91
Author: Andy Grove <[email protected]>
AuthorDate: Thu Jul 11 08:56:47 2024 -0600

    feat: Add `fail_on_overflow` option to `BinaryExpr` (#11400)
    
    * update tests
    
    * update tests
    
    * add rustdoc
    
    * update PartialEq impl
    
    * fix
    
    * address feedback about improving api
---
 datafusion/core/src/physical_planner.rs            |   4 +-
 datafusion/physical-expr/src/expressions/binary.rs | 126 +++++++++++++++++++--
 2 files changed, 121 insertions(+), 9 deletions(-)

diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 6aad4d5755..d2bc334ec3 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2312,7 +2312,7 @@ mod tests {
         // verify that the plan correctly casts u8 to i64
         // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
         // the cast here is implicit so has CastOptions with safe=true
-        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }";
+        let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
         assert!(format!("{exec_plan:?}").contains(expected));
         Ok(())
     }
@@ -2551,7 +2551,7 @@ mod tests {
         let execution_plan = plan(&logical_plan).await?;
         // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
 
-        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }";
+        let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
 
         let actual = format!("{execution_plan:?}");
         assert!(actual.contains(expected), "{}", actual);
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c153ead963..c34dcdfb75 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -53,6 +53,8 @@ pub struct BinaryExpr {
     left: Arc<dyn PhysicalExpr>,
     op: Operator,
     right: Arc<dyn PhysicalExpr>,
+    /// Specifies whether an error is returned on overflow or not
+    fail_on_overflow: bool,
 }
 
 impl BinaryExpr {
@@ -62,7 +64,22 @@ impl BinaryExpr {
         op: Operator,
         right: Arc<dyn PhysicalExpr>,
     ) -> Self {
-        Self { left, op, right }
+        Self {
+            left,
+            op,
+            right,
+            fail_on_overflow: false,
+        }
+    }
+
+    /// Create new binary expression with explicit fail_on_overflow value
+    pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self {
+        Self {
+            left: self.left,
+            op: self.op,
+            right: self.right,
+            fail_on_overflow,
+        }
     }
 
     /// Get the left side of the binary expression
@@ -273,8 +290,11 @@ impl PhysicalExpr for BinaryExpr {
         }
 
         match self.op {
+            Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add),
             Operator::Plus => return apply(&lhs, &rhs, add_wrapping),
+            Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub),
             Operator::Minus => return apply(&lhs, &rhs, sub_wrapping),
+            Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul),
             Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping),
             Operator::Divide => return apply(&lhs, &rhs, div),
             Operator::Modulo => return apply(&lhs, &rhs, rem),
@@ -327,11 +347,10 @@ impl PhysicalExpr for BinaryExpr {
         self: Arc<Self>,
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(BinaryExpr::new(
-            Arc::clone(&children[0]),
-            self.op,
-            Arc::clone(&children[1]),
-        )))
+        Ok(Arc::new(
+            BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1]))
+                .with_fail_on_overflow(self.fail_on_overflow),
+        ))
     }
 
     fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
@@ -496,7 +515,12 @@ impl PartialEq<dyn Any> for BinaryExpr {
     fn eq(&self, other: &dyn Any) -> bool {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
-            .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right))
+            .map(|x| {
+                self.left.eq(&x.left)
+                    && self.op == x.op
+                    && self.right.eq(&x.right)
+                    && self.fail_on_overflow.eq(&x.fail_on_overflow)
+            })
             .unwrap_or(false)
     }
 }
@@ -661,6 +685,7 @@ mod tests {
 
     use datafusion_common::plan_datafusion_err;
     use datafusion_expr::type_coercion::binary::get_input_types;
+    use datafusion_physical_expr_common::expressions::column::Column;
 
     /// Performs a binary operation, applying any type coercion necessary
     fn binary_op(
@@ -4008,4 +4033,91 @@ mod tests {
         .unwrap();
         assert_eq!(&casted, &dictionary);
     }
+
+    #[test]
+    fn test_add_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+        let r = Arc::new(Int32Array::from(vec![2, 1]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Plus,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: 2147483647 + 1"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_subtract_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MIN]));
+        let r = Arc::new(Int32Array::from(vec![2, 1]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Minus,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: -2147483648 - 1"));
+        Ok(())
+    }
+
+    #[test]
+    fn test_mul_with_overflow() -> Result<()> {
+        // create test data
+        let l = Arc::new(Int32Array::from(vec![1, i32::MAX]));
+        let r = Arc::new(Int32Array::from(vec![2, 2]));
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("l", DataType::Int32, false),
+            Field::new("r", DataType::Int32, false),
+        ]));
+        let batch = RecordBatch::try_new(schema, vec![l, r])?;
+
+        // create expression
+        let expr = BinaryExpr::new(
+            Arc::new(Column::new("l", 0)),
+            Operator::Multiply,
+            Arc::new(Column::new("r", 1)),
+        )
+        .with_fail_on_overflow(true);
+
+        // evaluate expression
+        let result = expr.evaluate(&batch);
+        assert!(result
+            .err()
+            .unwrap()
+            .to_string()
+            .contains("Overflow happened on: 2147483647 * 2"));
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to