This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ed7a6f65 Simplification Rules for Modulo Operator (#3669)
1ed7a6f65 is described below

commit 1ed7a6f65e7db232b61d45e819b98af039f4a489
Author: askoa <[email protected]>
AuthorDate: Mon Oct 3 09:46:46 2022 -0400

    Simplification Rules for Modulo Operator (#3669)
    
    * Simplify Rules for Modulo Operator
    
    * add divide by zero error
    
    * fix PR comments
    
    * remove error on mod by zero based on PR comment
    
    * fix clippy issues
    
    * add mod by zero back
    
    Co-authored-by: askoa <askoa@local>
---
 datafusion/optimizer/src/simplify_expressions.rs | 75 +++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index 969fa0169..4f92a65e4 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -21,6 +21,7 @@ use crate::expr_simplifier::ExprSimplifiable;
 use crate::{expr_simplifier::SimplifyInfo, OptimizerConfig, OptimizerRule};
 use arrow::array::new_null_array;
 use arrow::datatypes::{DataType, Field, Schema};
+use arrow::error::ArrowError;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::{
@@ -549,7 +550,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
     /// rewrite the expression simplifying any constant expressions
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         use Expr::*;
-        use Operator::{And, Divide, Eq, Multiply, NotEq, Or};
+        use Operator::{And, Divide, Eq, Modulo, Multiply, NotEq, Or};
 
         let info = self.info;
         let new_expr = match expr {
@@ -796,6 +797,37 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
                 right,
             } if !info.nullable(&left)? && left == right => lit(1),
 
+            //
+            // Rules for Modulo
+            //
+
+            // A % null --> null
+            BinaryExpr {
+                left: _,
+                op: Modulo,
+                right,
+            } if is_null(&right) => *right,
+            // null % A --> null
+            BinaryExpr {
+                left,
+                op: Modulo,
+                right: _,
+            } if is_null(&left) => *left,
+            // A % 1 --> 0
+            BinaryExpr {
+                left,
+                op: Modulo,
+                right,
+            } if !info.nullable(&left)? && is_one(&right) => lit(0),
+            // A % 0 --> DivideByZero Error
+            BinaryExpr {
+                left,
+                op: Modulo,
+                right,
+            } if !info.nullable(&left)? && is_zero(&right) => {
+                return Err(DataFusionError::ArrowError(ArrowError::DivideByZero))
+            }
+
             //
             // Rules for Not
             //
@@ -1077,6 +1109,47 @@ mod tests {
         assert_eq!(simplify(expr), expected);
     }
 
+    #[test]
+    fn test_simplify_modulo_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A % null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::Modulo, null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null % A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::Modulo, col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_modulo_by_one() {
+        let expr = binary_expr(col("c2"), Operator::Modulo, lit(1));
+        // if c2 is null, c2 % 1 = null, so can't simplify
+        let expected = expr.clone();
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_modulo_by_one_non_null() {
+        let expr = binary_expr(col("c2_non_null"), Operator::Modulo, lit(1));
+        let expected = lit(0);
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "called `Result::unwrap()` on an `Err` value: ArrowError(DivideByZero)"
+    )]
+    fn test_simplify_modulo_by_zero_non_null() {
+        let expr = binary_expr(col("c2_non_null"), Operator::Modulo, lit(0));
+        simplify(expr);
+    }
+
     #[test]
     fn test_simplify_simple_and() {
         // (c > 5) AND (c > 5)

Reply via email to