This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 87527c4308 Fix some simplification rules for floating-point arithmetic operations (#7515)
87527c4308 is described below

commit 87527c430852af82ef182e1607804f6ddd817bbb
Author: Jonah Gao <[email protected]>
AuthorDate: Tue Sep 12 02:03:45 2023 +0800

    Fix some simplification rules for floating-point arithmetic operations (#7515)
---
 .../src/simplify_expressions/expr_simplifier.rs    |  40 +++-
 datafusion/sqllogictest/test_files/math.slt        | 223 +++++++++++++++++----
 2 files changed, 217 insertions(+), 46 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 1081f9d688..c92660c7bb 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -670,18 +670,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
                 right: _,
             }) if is_null(&left) => *left,
 
-            // A * 0 --> 0 (if A is not null)
+            // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 
-> NAN)
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: Multiply,
                 right,
-            }) if !info.nullable(&left)? && is_zero(&right) => *right,
-            // 0 * A --> 0 (if A is not null)
+            }) if !info.nullable(&left)?
+                && !info.get_data_type(&left)?.is_floating()
+                && is_zero(&right) =>
+            {
+                *right
+            }
+            // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN 
-> NAN)
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: Multiply,
                 right,
-            }) if !info.nullable(&right)? && is_zero(&left) => *left,
+            }) if !info.nullable(&right)?
+                && !info.get_data_type(&right)?.is_floating()
+                && is_zero(&left) =>
+            {
+                *left
+            }
 
             //
             // Rules for Divide
@@ -734,19 +744,33 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
                 op: Modulo,
                 right: _,
             }) if is_null(&left) => *left,
-            // A % 1 --> 0
+            // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 
1 --> NAN)
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: Modulo,
                 right,
-            }) if !info.nullable(&left)? && is_one(&right) => lit(0),
-            // A % 0 --> DivideByZero Error
+            }) if !info.nullable(&left)?
+                && !info.get_data_type(&left)?.is_floating()
+                && is_one(&right) =>
+            {
+                lit(0)
+            }
+            // A % 0 --> DivideByZero Error (if A is not floating and not null)
+            // A % 0 --> NAN (if A is floating and not null)
             Expr::BinaryExpr(BinaryExpr {
                 left,
                 op: Modulo,
                 right,
             }) if !info.nullable(&left)? && is_zero(&right) => {
-                return 
Err(DataFusionError::ArrowError(ArrowError::DivideByZero));
+                match info.get_data_type(&left)? {
+                    DataType::Float32 => lit(f32::NAN),
+                    DataType::Float64 => lit(f64::NAN),
+                    _ => {
+                        return Err(DataFusionError::ArrowError(
+                            ArrowError::DivideByZero,
+                        ));
+                    }
+                }
             }
 
             //
diff --git a/datafusion/sqllogictest/test_files/math.slt 
b/datafusion/sqllogictest/test_files/math.slt
index dace5fa906..a56d4d2ecf 100644
--- a/datafusion/sqllogictest/test_files/math.slt
+++ b/datafusion/sqllogictest/test_files/math.slt
@@ -114,7 +114,7 @@ false true true NULL
 
 
 statement ok
-CREATE TABLE test_divide_zero_integer_nullable(
+CREATE TABLE test_nullable_integer(
     c1 TINYINT, 
     c2 SMALLINT, 
     c3 INT, 
@@ -128,45 +128,85 @@ CREATE TABLE test_divide_zero_integer_nullable(
     (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL);
 
 query IIIIIIII
-SELECT c1/0, c2/0, c3/0, c4/0, c5/0, c6/0, c7/0, c8/0 FROM 
test_divide_zero_integer_nullable
+SELECT c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 FROM 
test_nullable_integer
 ----
 NULL NULL NULL NULL NULL NULL NULL NULL
 
 query IIIIIIII
-INSERT INTO test_divide_zero_integer_nullable VALUES(1, 1, 1, 1, 1, 1, 1, 1)
+SELECT c1/0, c2/0, c3/0, c4/0, c5/0, c6/0, c7/0, c8/0 FROM 
test_nullable_integer
+----
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+query IIIIIIII
+SELECT c1%0, c2%0, c3%0, c4%0, c5%0, c6%0, c7%0, c8%0 FROM 
test_nullable_integer
+----
+NULL NULL NULL NULL NULL NULL NULL NULL
+
+query IIIIIIII
+INSERT INTO test_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1)
 ----
 1
 
+query IIIIIIII rowsort
+select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from 
test_nullable_integer
+----
+0 0 0 0 0 0 0 0
+NULL NULL NULL NULL NULL NULL NULL NULL 
+
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c1/0 FROM test_divide_zero_integer_nullable
+SELECT c1/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c2/0 FROM test_divide_zero_integer_nullable
+SELECT c2/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c3/0 FROM test_divide_zero_integer_nullable
+SELECT c3/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c4/0 FROM test_divide_zero_integer_nullable
+SELECT c4/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c5/0 FROM test_divide_zero_integer_nullable
+SELECT c5/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c6/0 FROM test_divide_zero_integer_nullable
+SELECT c6/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c7/0 FROM test_divide_zero_integer_nullable
+SELECT c7/0 FROM test_nullable_integer
 
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c8/0 FROM test_divide_zero_integer_nullable
+SELECT c8/0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c1%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c2%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c3%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c4%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c5%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c6%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c7%0 FROM test_nullable_integer
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c8%0 FROM test_nullable_integer
 
 statement ok
-drop table test_divide_zero_integer_nullable
+drop table test_nullable_integer
 
 
 statement ok
-CREATE TABLE test_divide_zero_integer_non_nullable(
+CREATE TABLE test_non_nullable_integer(
     c1 TINYINT NOT NULL, 
     c2 SMALLINT NOT NULL, 
     c3 INT NOT NULL, 
@@ -178,40 +218,70 @@ CREATE TABLE test_divide_zero_integer_non_nullable(
     );
 
 query IIIIIIII
-INSERT INTO test_divide_zero_integer_non_nullable VALUES(1, 1, 1, 1, 1, 1, 1, 
1)
+INSERT INTO test_non_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1)
 ----
 1
 
+query IIIIIIII rowsort
+select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 from 
test_non_nullable_integer
+----
+0 0 0 0 0 0 0 0
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c1/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c2/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c3/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c4/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c5/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c6/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c7/0 FROM test_non_nullable_integer
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
+SELECT c8/0 FROM test_non_nullable_integer
+
+
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c1/0 FROM test_divide_zero_integer_non_nullable
+SELECT c1%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c2/0 FROM test_divide_zero_integer_non_nullable
+SELECT c2%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c3/0 FROM test_divide_zero_integer_non_nullable
+SELECT c3%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c4/0 FROM test_divide_zero_integer_non_nullable
+SELECT c4%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c5/0 FROM test_divide_zero_integer_non_nullable
+SELECT c5%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c6/0 FROM test_divide_zero_integer_non_nullable
+SELECT c6%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c7/0 FROM test_divide_zero_integer_non_nullable
+SELECT c7%0 FROM test_non_nullable_integer
 
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error
-SELECT c8/0 FROM test_divide_zero_integer_non_nullable
+SELECT c8%0 FROM test_non_nullable_integer
 
 statement ok
-drop table test_divide_zero_integer_non_nullable
+drop table test_non_nullable_integer
 
 
 statement ok
-CREATE TABLE test_divide_zero_float_nullable(
+CREATE TABLE test_nullable_float(
     c1 float,
     c2 double, 
     ) AS VALUES
@@ -222,7 +292,16 @@ CREATE TABLE test_divide_zero_float_nullable(
     ('NaN'::double, 'NaN'::double);
 
 query RR rowsort
-SELECT c1/0, c2/0 FROM test_divide_zero_float_nullable
+SELECT c1*0, c2*0 FROM test_nullable_float
+----
+0 0
+0 0
+0 0
+NULL NULL
+NaN NaN
+
+query RR rowsort
+SELECT c1/0, c2/0 FROM test_nullable_float
 ----
 -Infinity -Infinity
 Infinity Infinity
@@ -230,18 +309,36 @@ NULL NULL
 NaN NaN
 NaN NaN
 
+query RR rowsort
+SELECT c1%0, c2%0 FROM test_nullable_float
+----
+NULL NULL
+NaN NaN
+NaN NaN
+NaN NaN
+NaN NaN
+
+query RR rowsort
+SELECT c1%1, c2%1 FROM test_nullable_float
+----
+0 0
+0 0
+0 0
+NULL NULL
+NaN NaN
+
 statement ok
-drop table test_divide_zero_float_nullable
+drop table test_nullable_float
 
 
 statement ok
-CREATE TABLE test_divide_zero_float_non_nullable(
+CREATE TABLE test_non_nullable_float(
     c1 float NOT NULL,
     c2 double NOT NULL, 
     ); 
 
 query RR
-INSERT INTO test_divide_zero_float_non_nullable VALUES
+INSERT INTO test_non_nullable_float VALUES
     (-1.0, -1.0),
     (1.0, 1.0),
     (0., 0.),
@@ -250,42 +347,92 @@ INSERT INTO test_divide_zero_float_non_nullable VALUES
 4
 
 query RR rowsort
-SELECT c1/0, c2/0 FROM test_divide_zero_float_non_nullable
+SELECT c1*0, c2*0 FROM test_non_nullable_float
+----
+0 0
+0 0
+0 0
+NaN NaN
+
+query RR rowsort
+SELECT c1/0, c2/0 FROM test_non_nullable_float
 ----
 -Infinity -Infinity
 Infinity Infinity
 NaN NaN
 NaN NaN
 
+query RR rowsort
+SELECT c1%0, c2%0 FROM test_non_nullable_float
+----
+NaN NaN
+NaN NaN
+NaN NaN
+NaN NaN
+
+query RR rowsort
+SELECT c1%1, c2%1 FROM test_non_nullable_float
+----
+0 0
+0 0
+0 0
+NaN NaN
+
 statement ok
-drop table test_divide_zero_float_non_nullable
+drop table test_non_nullable_float
 
 
 statement ok
-CREATE TABLE test_divide_zero_decimal_nullable(c1 DECIMAL(9, 2)) AS VALUES 
(1), (NULL);
+CREATE TABLE test_nullable_decimal(c1 DECIMAL(9, 2)) AS VALUES (1), (NULL);
+
+query R rowsort
+SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NULL;
+----
+NULL
+
+query R rowsort
+SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NULL;
+----
+NULL
 
 query R rowsort
-SELECT c1/0 FROM test_divide_zero_decimal_nullable WHERE c1 IS NULL;
+SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NULL;
 ----
 NULL
 
+query R rowsort
+SELECT c1*0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
+----
+0
+
+query error DataFusion error: Arrow error: Divide by zero error
+SELECT c1/0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
+
 query error DataFusion error: Arrow error: Divide by zero error
-SELECT c1/0 FROM test_divide_zero_decimal_nullable WHERE c1 IS NOT NULL;
+SELECT c1%0 FROM test_nullable_decimal WHERE c1 IS NOT NULL;
 
 statement ok
-drop table test_divide_zero_decimal_nullable
+drop table test_nullable_decimal  
 
 
 statement ok
-CREATE TABLE test_divide_zero_decimal_non_nullable(c1 DECIMAL(9,2) NOT NULL); 
+CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); 
 
 query R
-INSERT INTO test_divide_zero_decimal_non_nullable VALUES(1)
+INSERT INTO test_non_nullable_decimal VALUES(1)
 ----
 1
 
+query R rowsort
+SELECT c1*0 FROM test_non_nullable_decimal
+----
+0
+
+query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error 
+SELECT c1/0 FROM test_non_nullable_decimal 
+
 query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nArrow error: Divide by zero error 
-SELECT c1/0 FROM test_divide_zero_decimal_non_nullable 
+SELECT c1%0 FROM test_non_nullable_decimal 
 
 statement ok
-drop table test_divide_zero_decimal_non_nullable 
\ No newline at end of file
+drop table test_non_nullable_decimal 
\ No newline at end of file

Reply via email to