This is an automated email from the ASF dual-hosted git repository.
starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 289bdd44083 [test](nereids)add fe ut for SimplifyArithmeticComparisonRule (#27644)
289bdd44083 is described below
commit 289bdd44083a21177961c3a604f2086f4e984116
Author: starocean999 <[email protected]>
AuthorDate: Mon Jan 29 15:26:37 2024 +0800
[test](nereids)add fe ut for SimplifyArithmeticComparisonRule (#27644)
---
.../rules/SimplifyArithmeticComparisonRule.java | 20 ++++++++-----
.../SimplifyArithmeticComparisonRuleTest.java | 34 ++++++++++++++++++++++
2 files changed, 46 insertions(+), 8 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
index eda95ba32b1..7606d082479 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java
@@ -82,20 +82,23 @@ public class SimplifyArithmeticComparisonRule extends
AbstractExpressionRewriteR
@Override
public Expression visitComparisonPredicate(ComparisonPredicate comparison,
ExpressionRewriteContext context) {
- ComparisonPredicate newComparison = comparison;
if (couldRearrange(comparison)) {
- newComparison = normalize(comparison);
+ ComparisonPredicate newComparison = normalize(comparison);
if (newComparison == null) {
return comparison;
}
try {
- List<Expression> children =
tryRearrangeChildren(newComparison.left(), newComparison.right());
- newComparison = (ComparisonPredicate)
newComparison.withChildren(children);
+ List<Expression> children =
+ tryRearrangeChildren(newComparison.left(),
newComparison.right(), context);
+ newComparison = (ComparisonPredicate) visitComparisonPredicate(
+ (ComparisonPredicate)
newComparison.withChildren(children), context);
} catch (Exception e) {
return comparison;
}
+ return TypeCoercionUtils.processComparisonPredicate(newComparison);
+ } else {
+ return comparison;
}
- return TypeCoercionUtils.processComparisonPredicate(newComparison);
}
private boolean couldRearrange(ComparisonPredicate cmp) {
@@ -104,11 +107,12 @@ public class SimplifyArithmeticComparisonRule extends
AbstractExpressionRewriteR
&&
cmp.left().children().stream().anyMatch(Expression::isConstant);
}
- private List<Expression> tryRearrangeChildren(Expression left, Expression
right) throws Exception {
- if (!left.child(1).isLiteral()) {
+ private List<Expression> tryRearrangeChildren(Expression left, Expression
right,
+ ExpressionRewriteContext context) throws Exception {
+ if (!left.child(1).isConstant()) {
throw new RuntimeException(String.format("Expected literal when
arranging children for Expr %s", left));
}
- Literal leftLiteral = (Literal) left.child(1);
+ Literal leftLiteral = (Literal)
FoldConstantRule.INSTANCE.rewrite(left.child(1), context);
Expression leftExpr = left.child(0);
Class<? extends Expression> oppositeOperator =
rearrangementMap.get(left.getClass());
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
index 5a438ded653..fc31daaa941 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java
@@ -41,6 +41,40 @@ class SimplifyArithmeticComparisonRuleTest extends
ExpressionRewriteTestHelper {
assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)",
nameToSlot);
assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)",
nameToSlot);
assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a",
nameToSlot);
+
+ // test integer type
+ assertRewriteAfterSimplify("1 + a > 2", "a > cast((2 - 1) as INT)",
nameToSlot);
+ assertRewriteAfterSimplify("-1 + a > 2", "a > cast((2 - (-1)) as
INT)", nameToSlot);
+ assertRewriteAfterSimplify("1 - a > 2", "a < cast((1 - 2) as INT)",
nameToSlot);
+ assertRewriteAfterSimplify("-1 - a > 2", "a < cast(((-1) - 2) as
INT)", nameToSlot);
+ assertRewriteAfterSimplify("2 * a > 1", "((2 * a) > 1)", nameToSlot);
+ assertRewriteAfterSimplify("-2 * a > 1", "((-2 * a) > 1)", nameToSlot);
+ assertRewriteAfterSimplify("2 / a > 1", "((2 / a) > 1)", nameToSlot);
+ assertRewriteAfterSimplify("-2 / a > 1", "((-2 / a) > 1)", nameToSlot);
+ assertRewriteAfterSimplify("a * 2 > 1", "((a * 2) > 1)", nameToSlot);
+ assertRewriteAfterSimplify("a * (-2) > 1", "((a * (-2)) > 1)",
nameToSlot);
+ assertRewriteAfterSimplify("a / 2 > 1", "(a > cast((1 * 2) as INT))",
nameToSlot);
+
+ // test decimal type
+ assertRewriteAfterSimplify("1.1 + a > 2.22", "(cast(a as DECIMALV3(12,
2)) > cast((2.22 - 1.1) as DECIMALV3(12, 2)))", nameToSlot);
+ assertRewriteAfterSimplify("-1.1 + a > 2.22", "(cast(a as
DECIMALV3(12, 2)) > cast((2.22 - (-1.1)) as DECIMALV3(12, 2)))", nameToSlot);
+ assertRewriteAfterSimplify("1.1 - a > 2.22", "(cast(a as DECIMALV3(11,
1)) < cast((1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
+ assertRewriteAfterSimplify("-1.1 - a > 2.22", "(cast(a as
DECIMALV3(11, 1)) < cast((-1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
+ assertRewriteAfterSimplify("2.22 * a > 1.1", "((2.22 * a) > 1.1)",
nameToSlot);
+ assertRewriteAfterSimplify("-2.22 * a > 1.1", "-2.22 * a > 1.1",
nameToSlot);
+ assertRewriteAfterSimplify("2.22 / a > 1.1", "((2.22 / a) > 1.1)",
nameToSlot);
+ assertRewriteAfterSimplify("-2.22 / a > 1.1", "((-2.22 / a) > 1.1)",
nameToSlot);
+ assertRewriteAfterSimplify("a * 2.22 > 1.1", "a * 2.22 > 1.1",
nameToSlot);
+ assertRewriteAfterSimplify("a * (-2.22) > 1.1", "a * (-2.22) > 1.1",
nameToSlot);
+ assertRewriteAfterSimplify("a / 2.22 > 1.1", "(cast(a as DECIMALV3(13,
3)) > cast((1.1 * 2.22) as DECIMALV3(13, 3)))", nameToSlot);
+ assertRewriteAfterSimplify("a / (-2.22) > 1.1", "(cast((1.1 * -2.22)
as DECIMALV3(13, 3)) > cast(a as DECIMALV3(13, 3)))", nameToSlot);
+
+ // test (1 + a) can be processed
+ assertRewriteAfterSimplify("2 - (1 + a) > 3", "(a < ((2 - 3) - 1))",
nameToSlot);
+ assertRewriteAfterSimplify("(1 - a) / 2 > 3", "(a < (1 - 6))",
nameToSlot);
+ assertRewriteAfterSimplify("1 - a / 2 > 3", "(a < ((1 - 3) * 2))",
nameToSlot);
+ assertRewriteAfterSimplify("(1 - (a + 4)) / 2 > 3", "(cast(a as
BIGINT) < ((1 - 6) - 4))", nameToSlot);
+ assertRewriteAfterSimplify("2 * (1 + a) > 1", "(2 * (1 + a)) > 1",
nameToSlot);
}
private void assertRewriteAfterSimplify(String expr, String expected,
Map<String, Slot> slotNameToSlot) {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]