This is an automated email from the ASF dual-hosted git repository.

starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 289bdd44083 [test](nereids)add fe ut for SimplifyArithmeticComparisonRule (#27644)
289bdd44083 is described below

commit 289bdd44083a21177961c3a604f2086f4e984116
Author: starocean999 <[email protected]>
AuthorDate: Mon Jan 29 15:26:37 2024 +0800

    [test](nereids)add fe ut for SimplifyArithmeticComparisonRule (#27644)
---
 .../rules/SimplifyArithmeticComparisonRule.java    | 20 ++++++++-----
 .../SimplifyArithmeticComparisonRuleTest.java      | 34 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 8 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
index eda95ba32b1..7606d082479 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
@@ -82,20 +82,23 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR
 
     @Override
     public Expression visitComparisonPredicate(ComparisonPredicate comparison, 
ExpressionRewriteContext context) {
-        ComparisonPredicate newComparison = comparison;
         if (couldRearrange(comparison)) {
-            newComparison = normalize(comparison);
+            ComparisonPredicate newComparison = normalize(comparison);
             if (newComparison == null) {
                 return comparison;
             }
             try {
-                List<Expression> children = 
tryRearrangeChildren(newComparison.left(), newComparison.right());
-                newComparison = (ComparisonPredicate) 
newComparison.withChildren(children);
+                List<Expression> children =
+                        tryRearrangeChildren(newComparison.left(), 
newComparison.right(), context);
+                newComparison = (ComparisonPredicate) visitComparisonPredicate(
+                        (ComparisonPredicate) 
newComparison.withChildren(children), context);
             } catch (Exception e) {
                 return comparison;
             }
+            return TypeCoercionUtils.processComparisonPredicate(newComparison);
+        } else {
+            return comparison;
         }
-        return TypeCoercionUtils.processComparisonPredicate(newComparison);
     }
 
     private boolean couldRearrange(ComparisonPredicate cmp) {
@@ -104,11 +107,12 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR
                 && 
cmp.left().children().stream().anyMatch(Expression::isConstant);
     }
 
-    private List<Expression> tryRearrangeChildren(Expression left, Expression 
right) throws Exception {
-        if (!left.child(1).isLiteral()) {
+    private List<Expression> tryRearrangeChildren(Expression left, Expression 
right,
+            ExpressionRewriteContext context) throws Exception {
+        if (!left.child(1).isConstant()) {
             throw new RuntimeException(String.format("Expected literal when 
arranging children for Expr %s", left));
         }
-        Literal leftLiteral = (Literal) left.child(1);
+        Literal leftLiteral = (Literal) 
FoldConstantRule.INSTANCE.rewrite(left.child(1), context);
         Expression leftExpr = left.child(0);
 
         Class<? extends Expression> oppositeOperator = 
rearrangementMap.get(left.getClass());
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
index 5a438ded653..fc31daaa941 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
@@ -41,6 +41,40 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper {
         assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", 
nameToSlot);
         assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", 
nameToSlot);
         assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", 
nameToSlot);
+
+        // test integer type
+        assertRewriteAfterSimplify("1 + a > 2", "a > cast((2 - 1) as INT)", 
nameToSlot);
+        assertRewriteAfterSimplify("-1 + a > 2", "a > cast((2 - (-1)) as 
INT)", nameToSlot);
+        assertRewriteAfterSimplify("1 - a > 2", "a < cast((1 - 2) as INT)", 
nameToSlot);
+        assertRewriteAfterSimplify("-1 - a > 2", "a < cast(((-1) - 2) as 
INT)", nameToSlot);
+        assertRewriteAfterSimplify("2 * a > 1", "((2 * a) > 1)", nameToSlot);
+        assertRewriteAfterSimplify("-2 * a > 1", "((-2 * a) > 1)", nameToSlot);
+        assertRewriteAfterSimplify("2 / a > 1", "((2 / a) > 1)", nameToSlot);
+        assertRewriteAfterSimplify("-2 / a > 1", "((-2 / a) > 1)", nameToSlot);
+        assertRewriteAfterSimplify("a * 2 > 1", "((a * 2) > 1)", nameToSlot);
+        assertRewriteAfterSimplify("a * (-2) > 1", "((a * (-2)) > 1)", 
nameToSlot);
+        assertRewriteAfterSimplify("a / 2 > 1", "(a > cast((1 * 2) as INT))", 
nameToSlot);
+
+        // test decimal type
+        assertRewriteAfterSimplify("1.1 + a > 2.22", "(cast(a as DECIMALV3(12, 
2)) > cast((2.22 - 1.1) as DECIMALV3(12, 2)))", nameToSlot);
+        assertRewriteAfterSimplify("-1.1 + a > 2.22", "(cast(a as 
DECIMALV3(12, 2)) > cast((2.22 - (-1.1)) as DECIMALV3(12, 2)))", nameToSlot);
+        assertRewriteAfterSimplify("1.1 - a > 2.22", "(cast(a as DECIMALV3(11, 
1)) < cast((1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
+        assertRewriteAfterSimplify("-1.1 - a > 2.22", "(cast(a as 
DECIMALV3(11, 1)) < cast((-1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
+        assertRewriteAfterSimplify("2.22 * a > 1.1", "((2.22 * a) > 1.1)", 
nameToSlot);
+        assertRewriteAfterSimplify("-2.22 * a > 1.1", "-2.22 * a > 1.1", 
nameToSlot);
+        assertRewriteAfterSimplify("2.22 / a > 1.1", "((2.22 / a) > 1.1)", 
nameToSlot);
+        assertRewriteAfterSimplify("-2.22 / a > 1.1", "((-2.22 / a) > 1.1)", 
nameToSlot);
+        assertRewriteAfterSimplify("a * 2.22 > 1.1", "a * 2.22 > 1.1", 
nameToSlot);
+        assertRewriteAfterSimplify("a * (-2.22) > 1.1", "a * (-2.22) > 1.1", 
nameToSlot);
+        assertRewriteAfterSimplify("a / 2.22 > 1.1", "(cast(a as DECIMALV3(13, 
3)) > cast((1.1 * 2.22) as DECIMALV3(13, 3)))", nameToSlot);
+        assertRewriteAfterSimplify("a / (-2.22) > 1.1", "(cast((1.1 * -2.22) 
as DECIMALV3(13, 3)) > cast(a as DECIMALV3(13, 3)))", nameToSlot);
+
+        // test (1 + a) can be processed
+        assertRewriteAfterSimplify("2 - (1 + a) > 3", "(a < ((2 - 3) - 1))", 
nameToSlot);
+        assertRewriteAfterSimplify("(1 - a) / 2 > 3", "(a < (1 - 6))", 
nameToSlot);
+        assertRewriteAfterSimplify("1 - a / 2 > 3", "(a < ((1 - 3) * 2))", 
nameToSlot);
+        assertRewriteAfterSimplify("(1 - (a + 4)) / 2 > 3", "(cast(a as 
BIGINT) < ((1 - 6) - 4))", nameToSlot);
+        assertRewriteAfterSimplify("2 * (1 + a) > 1", "(2 * (1 + a)) > 1", 
nameToSlot);
     }
 
     private void assertRewriteAfterSimplify(String expr, String expected, 
Map<String, Slot> slotNameToSlot) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to