This is an automated email from the ASF dual-hosted git repository.

huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 730cd1a0c1 [Feature](Nereids) Simplify range of predicate (#14113)
730cd1a0c1 is described below

commit 730cd1a0c1417de16fa822f5ed649d3bc82b4b34
Author: shee <[email protected]>
AuthorDate: Mon Nov 21 20:24:03 2022 +0800

    [Feature](Nereids) Simplify range of predicate (#14113)
    
    Simplify range of predicate
    
    for example:
    1. `a > 1 or a > 2` => `a > 1`
    2. `a in (1,2,3) or a in (3,4,5)` => `a in (1,2,3,4,5)`
---
 .../expression/rewrite/ExpressionOptimization.java |   5 +-
 .../expression/rewrite/ExpressionRuleExecutor.java |   5 +
 .../expression/rewrite/rules/SimplifyRange.java    | 461 +++++++++++++++++++++
 .../nereids/trees/expressions/literal/Literal.java |   3 +-
 .../apache/doris/nereids/util/ExpressionUtils.java |   4 +
 .../nereids/datasets/ssb/SSBJoinReorderTest.java   |   4 +-
 .../expression/rewrite/SimplifyRangeTest.java      | 140 +++++++
 7 files changed, 618 insertions(+), 4 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
index 973af7ab72..9e3eace63f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression.rewrite;
 import 
org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
 import 
org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
 import 
org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyComparisonPredicate;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyRange;
 
 import com.google.common.collect.ImmutableList;
 
@@ -32,7 +33,9 @@ public class ExpressionOptimization extends ExpressionRewrite 
{
     public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES = ImmutableList.of(
             ExtractCommonFactorRule.INSTANCE,
             DistinctPredicatesRule.INSTANCE,
-            SimplifyComparisonPredicate.INSTANCE);
+            SimplifyComparisonPredicate.INSTANCE,
+            SimplifyRange.INSTANCE);
     private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
 
     public ExpressionOptimization() {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
index 44829a2a31..d8da2617cc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.expression.rewrite;
 
+import 
org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.qe.ConnectContext;
 
@@ -67,4 +68,8 @@ public class ExpressionRuleExecutor {
         return rule.rewrite(expr, ctx);
     }
 
+    /**
+     * Applies {@code NormalizeBinaryPredicatesRule.INSTANCE} to {@code expression} without a rewrite context.
+     * NOTE(review): presumably this canonicalizes comparisons so that a literal operand ends up on the
+     * right-hand side (SimplifyRange reads {@code child(1)} of the result and expects the literal there)
+     * — confirm in NormalizeBinaryPredicatesRule. Passing {@code null} as the context assumes the rule
+     * never dereferences it.
+     *
+     * @param expression the expression to normalize
+     * @return the normalized expression
+     */
+    public static Expression normalize(Expression expression) {
+        return NormalizeBinaryPredicatesRule.INSTANCE.rewrite(expression, null);
+    }
+
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
new file mode 100644
index 0000000000..010031486c
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rewrite.rules;
+
+import 
org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
+import 
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.LessThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.BoundType;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Range;
+import com.google.common.collect.Sets;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.function.BinaryOperator;
+import java.util.stream.Collectors;
+
+/**
+ * This class implements the function to simplify expression range.
+ * for example:
+ * a > 1 and a > 2 => a > 2
+ * a > 1 or a > 2 => a > 1
+ * a in (1,2,3) and a > 1 => a in (2,3)
+ * a in (1,2,3) and a in (3,4,5) => a = 3
+ * a in (1,2,3) and a in (4,5,6) => false
+ * The logic is as follows:
+ * 1. for `And` expression.
+ *    1. extract conjunctions then build `ValueDesc` for each conjunction
+ *    2. group them by `reference`; `ValueDesc`s in the same group are intersected
+ *    for example:
+ *    a > 1 and a > 2
+ *    1. a > 1 => RangeValue((1...+∞)), a > 2 => RangeValue((2...+∞))
+ *    2. (1...+∞) intersect (2...+∞) => (2...+∞)
+ * 2. for `Or` expression (similar to `And`, but `ValueDesc`s in the same group are unioned).
+ * Only comparisons and IN lists against numeric literals are simplified; any other predicate is
+ * kept as an `UnknownValue` and emitted unchanged.
+ * todo: support a > 10 and (a < 10 or a > 20 ) => a > 20
+ */
+public class SimplifyRange extends AbstractExpressionRewriteRule {
+
+    public static final SimplifyRange INSTANCE = new SimplifyRange();
+
+    /**
+     * Simplifies {@code expr} when it is a compound predicate ({@code And} / {@code Or});
+     * every other expression is returned unchanged.
+     */
+    @Override
+    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
+        if (expr instanceof CompoundPredicate) {
+            ValueDesc valueDesc = expr.accept(new RangeInference(), null);
+            Expression simplifiedExpr = valueDesc.toExpression();
+            return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr;
+        }
+        return expr;
+    }
+
+    /**
+     * Builds a `ValueDesc` for a predicate tree, merging the descriptions of sibling predicates
+     * that share the same reference.
+     */
+    private static class RangeInference extends ExpressionVisitor<ValueDesc, Void> {
+
+        @Override
+        public ValueDesc visit(Expression expr, Void context) {
+            return new UnknownValue(expr);
+        }
+
+        /**
+         * Normalizes the comparison; when its right-hand side is then a numeric literal, describes it
+         * as a range (or as a single discrete value for `=`), otherwise keeps it as unknown.
+         */
+        private ValueDesc buildRange(ComparisonPredicate predicate) {
+            Expression rewrite = ExpressionRuleExecutor.normalize(predicate);
+            Expression right = rewrite.child(1);
+            // only handle `NumericType`
+            if (right.isLiteral() && right.getDataType().isNumericType()) {
+                return ValueDesc.range((ComparisonPredicate) rewrite);
+            }
+            return new UnknownValue(predicate);
+        }
+
+        @Override
+        public ValueDesc visitGreaterThan(GreaterThan greaterThan, Void context) {
+            return buildRange(greaterThan);
+        }
+
+        @Override
+        public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Void context) {
+            return buildRange(greaterThanEqual);
+        }
+
+        @Override
+        public ValueDesc visitLessThan(LessThan lessThan, Void context) {
+            return buildRange(lessThan);
+        }
+
+        @Override
+        public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, Void context) {
+            return buildRange(lessThanEqual);
+        }
+
+        @Override
+        public ValueDesc visitEqualTo(EqualTo equalTo, Void context) {
+            return buildRange(equalTo);
+        }
+
+        @Override
+        public ValueDesc visitInPredicate(InPredicate inPredicate, Void context) {
+            // only handle `NumericType`
+            if (ExpressionUtils.isAllLiteral(inPredicate.getOptions())
+                    && ExpressionUtils.matchNumericType(inPredicate.getOptions())) {
+                return ValueDesc.discrete(inPredicate);
+            }
+            return new UnknownValue(inPredicate);
+        }
+
+        @Override
+        public ValueDesc visitAnd(And and, Void context) {
+            return simplify(and, ExpressionUtils.extractConjunction(and), ValueDesc::intersect, ExpressionUtils::and);
+        }
+
+        @Override
+        public ValueDesc visitOr(Or or, Void context) {
+            return simplify(or, ExpressionUtils.extractDisjunction(or), ValueDesc::union, ExpressionUtils::or);
+        }
+
+        /**
+         * Describes each predicate, groups the descriptions by reference (keeping first-seen order),
+         * folds every group with {@code op}, and wraps the per-reference results in an `UnknownValue`
+         * joined with {@code exprOp} when more than one reference is involved.
+         */
+        private ValueDesc simplify(Expression originExpr, List<Expression> predicates,
+                BinaryOperator<ValueDesc> op, BinaryOperator<Expression> exprOp) {
+
+            Map<Expression, List<ValueDesc>> groupByReference = predicates.stream()
+                    .map(predicate -> predicate.accept(this, null))
+                    .collect(Collectors.groupingBy(p -> p.reference, LinkedHashMap::new, Collectors.toList()));
+
+            List<ValueDesc> valuePerRefs = Lists.newArrayList();
+            for (Entry<Expression, List<ValueDesc>> referenceValues : groupByReference.entrySet()) {
+                // groupingBy never produces an empty group, so reduce always yields a value
+                ValueDesc simplifiedValue = referenceValues.getValue().stream()
+                        .reduce(op)
+                        .get();
+                valuePerRefs.add(simplifiedValue);
+            }
+
+            if (valuePerRefs.size() == 1) {
+                return valuePerRefs.get(0);
+            }
+
+            // use UnknownValue to wrap different references
+            return new UnknownValue(valuePerRefs, originExpr, exprOp);
+        }
+    }
+
+    /**
+     * Describes the values that `reference` may take under a predicate.
+     */
+    private abstract static class ValueDesc {
+        // the predicate (or several merged predicates) this description was built from
+        Expression expr;
+        // the expression whose values are described, e.g. `a` in `a > 1`; for an unknown leaf
+        // predicate it is the predicate itself
+        Expression reference;
+
+        public ValueDesc(Expression reference, Expression expr) {
+            this.expr = expr;
+            this.reference = reference;
+        }
+
+        public abstract ValueDesc union(ValueDesc other);
+
+        /**
+         * Unions a range with a discrete set on the same reference. The result collapses to the range
+         * when the range already contains every discrete value; otherwise both are kept, OR-ed
+         * together. {@code reverseOrder} puts the discrete value first when it came first in the input.
+         */
+        public static ValueDesc union(RangeValue range, DiscreteValue discrete, boolean reverseOrder) {
+            if (discrete.values.stream().allMatch(range.range::contains)) {
+                return range;
+            }
+            Expression originExpr = reverseOrder
+                    ? ExpressionUtils.or(discrete.expr, range.expr)
+                    : ExpressionUtils.or(range.expr, discrete.expr);
+            List<ValueDesc> sourceValues = reverseOrder
+                    ? ImmutableList.of(discrete, range)
+                    : ImmutableList.of(range, discrete);
+            return new UnknownValue(sourceValues, originExpr, ExpressionUtils::or);
+        }
+
+        public abstract ValueDesc intersect(ValueDesc other);
+
+        /**
+         * Intersects a range with a discrete set on the same reference: keeps only the discrete
+         * values inside the range, or yields an `EmptyValue` when none remain.
+         */
+        public static ValueDesc intersect(RangeValue range, DiscreteValue discrete) {
+            Expression originExpr = ExpressionUtils.and(range.expr, discrete.expr);
+            List<Literal> remaining = discrete.values.stream()
+                    .filter(range.range::contains)
+                    .collect(Collectors.toList());
+            if (remaining.isEmpty()) {
+                return new EmptyValue(range.reference, originExpr);
+            }
+            return new DiscreteValue(discrete.reference, originExpr, remaining);
+        }
+
+        public abstract Expression toExpression();
+
+        /** Keeps {@code this} and {@code other} unsimplified, to be emitted as `this OR other`. */
+        protected ValueDesc unionAsUnknown(ValueDesc other) {
+            Expression originExpr = ExpressionUtils.or(expr, other.expr);
+            return new UnknownValue(ImmutableList.of(this, other), originExpr, ExpressionUtils::or);
+        }
+
+        /** Keeps {@code this} and {@code other} unsimplified, to be emitted as `this AND other`. */
+        protected ValueDesc intersectAsUnknown(ValueDesc other) {
+            Expression originExpr = ExpressionUtils.and(expr, other.expr);
+            return new UnknownValue(ImmutableList.of(this, other), originExpr, ExpressionUtils::and);
+        }
+
+        /**
+         * Describes a normalized comparison `reference op literal`; `=` becomes a single-valued
+         * `DiscreteValue`, the other supported operators become a `RangeValue`.
+         */
+        public static ValueDesc range(ComparisonPredicate predicate) {
+            Literal value = (Literal) predicate.right();
+            if (predicate instanceof EqualTo) {
+                return new DiscreteValue(predicate.left(), predicate, value);
+            }
+            RangeValue rangeValue = new RangeValue(predicate.left(), predicate);
+            if (predicate instanceof GreaterThanEqual) {
+                rangeValue.range = Range.atLeast(value);
+            } else if (predicate instanceof GreaterThan) {
+                rangeValue.range = Range.greaterThan(value);
+            } else if (predicate instanceof LessThanEqual) {
+                rangeValue.range = Range.atMost(value);
+            } else if (predicate instanceof LessThan) {
+                rangeValue.range = Range.lessThan(value);
+            }
+
+            return rangeValue;
+        }
+
+        /** Describes `compareExpr IN (literals...)` as the set of its literal options. */
+        public static ValueDesc discrete(InPredicate in) {
+            Set<Literal> literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet());
+            return new DiscreteValue(in.getCompareExpr(), in, literals);
+        }
+    }
+
+    /**
+     * A predicate that no value of `reference` can satisfy.
+     */
+    private static class EmptyValue extends ValueDesc {
+
+        public EmptyValue(Expression reference, Expression expr) {
+            super(reference, expr);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            return other;
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            return this;
+        }
+
+        /**
+         * NOTE(review): FALSE differs from the original predicate when `reference` is NULL (the
+         * original then evaluates to NULL); the two are equivalent in a filter but not under NOT.
+         */
+        @Override
+        public Expression toExpression() {
+            return BooleanLiteral.FALSE;
+        }
+    }
+
+    /**
+     * use @see com.google.common.collect.Range to wrap `ComparisonPredicate`
+     * for example:
+     * a > 1 => (1...+∞)
+     */
+    private static class RangeValue extends ValueDesc {
+        Range<Literal> range;
+
+        public RangeValue(Expression reference, Expression expr) {
+            super(reference, expr);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.union(this);
+            }
+            try {
+                if (other instanceof RangeValue) {
+                    RangeValue o = (RangeValue) other;
+                    if (range.isConnected(o.range)) {
+                        RangeValue rangeValue = new RangeValue(reference, ExpressionUtils.or(expr, other.expr));
+                        rangeValue.range = range.span(o.range);
+                        return rangeValue;
+                    }
+                    return unionAsUnknown(other);
+                }
+                if (other instanceof DiscreteValue) {
+                    return union(this, (DiscreteValue) other, false);
+                }
+            } catch (RuntimeException e) {
+                // comparing the literals failed; keep both predicates unsimplified
+                return unionAsUnknown(other);
+            }
+            return unionAsUnknown(other);
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.intersect(this);
+            }
+            try {
+                if (other instanceof RangeValue) {
+                    RangeValue o = (RangeValue) other;
+                    if (range.isConnected(o.range)) {
+                        // NOTE: connected ranges can still intersect to an empty range such as [3..3);
+                        // that result is currently emitted as-is (`a >= 3 and a < 3`), not as FALSE.
+                        RangeValue rangeValue = new RangeValue(reference, ExpressionUtils.and(expr, other.expr));
+                        rangeValue.range = range.intersection(o.range);
+                        return rangeValue;
+                    }
+                    return new EmptyValue(reference, ExpressionUtils.and(expr, other.expr));
+                }
+                if (other instanceof DiscreteValue) {
+                    return intersect(this, (DiscreteValue) other);
+                }
+            } catch (RuntimeException e) {
+                // comparing the literals failed; keep both predicates unsimplified
+                return intersectAsUnknown(other);
+            }
+            return intersectAsUnknown(other);
+        }
+
+        /**
+         * Emits the bounds of the range as comparisons on `reference`, AND-ed together.
+         * NOTE(review): an unbounded range becomes TRUE, which drops the implicit
+         * `reference IS NOT NULL` of the original predicates (`a > 1 or a < 10` is NULL, not TRUE,
+         * when `a` is NULL) — confirm this is acceptable wherever this rule runs.
+         */
+        @Override
+        public Expression toExpression() {
+            List<Expression> result = Lists.newArrayList();
+            if (range.hasLowerBound()) {
+                if (range.lowerBoundType() == BoundType.CLOSED) {
+                    result.add(new GreaterThanEqual(reference, range.lowerEndpoint()));
+                } else {
+                    result.add(new GreaterThan(reference, range.lowerEndpoint()));
+                }
+            }
+            if (range.hasUpperBound()) {
+                if (range.upperBoundType() == BoundType.CLOSED) {
+                    result.add(new LessThanEqual(reference, range.upperEndpoint()));
+                } else {
+                    result.add(new LessThan(reference, range.upperEndpoint()));
+                }
+            }
+            return result.isEmpty() ? BooleanLiteral.TRUE : ExpressionUtils.and(result);
+        }
+
+        @Override
+        public String toString() {
+            return range == null ? "UnknwonRange" : range.toString();
+        }
+    }
+
+    /**
+     * use `Set` to wrap `InPredicate`
+     * for example:
+     * a in (1,2,3) => [1,2,3]
+     */
+    private static class DiscreteValue extends ValueDesc {
+        Set<Literal> values;
+
+        public DiscreteValue(Expression reference, Expression expr, Literal... values) {
+            this(reference, expr, Arrays.asList(values));
+        }
+
+        public DiscreteValue(Expression reference, Expression expr, Collection<Literal> values) {
+            super(reference, expr);
+            // sorted so the emitted IN list is deterministic
+            this.values = Sets.newTreeSet(values);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.union(this);
+            }
+            try {
+                if (other instanceof DiscreteValue) {
+                    DiscreteValue discreteValue = new DiscreteValue(reference, ExpressionUtils.or(expr, other.expr));
+                    discreteValue.values.addAll(((DiscreteValue) other).values);
+                    discreteValue.values.addAll(this.values);
+                    return discreteValue;
+                }
+                if (other instanceof RangeValue) {
+                    return union((RangeValue) other, this, true);
+                }
+            } catch (RuntimeException e) {
+                // comparing the literals failed; keep both predicates unsimplified
+                return unionAsUnknown(other);
+            }
+            return unionAsUnknown(other);
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.intersect(this);
+            }
+            try {
+                if (other instanceof DiscreteValue) {
+                    DiscreteValue discreteValue = new DiscreteValue(reference, ExpressionUtils.and(expr, other.expr));
+                    discreteValue.values.addAll(((DiscreteValue) other).values);
+                    discreteValue.values.retainAll(this.values);
+                    if (discreteValue.values.isEmpty()) {
+                        return new EmptyValue(reference, ExpressionUtils.and(expr, other.expr));
+                    }
+                    return discreteValue;
+                }
+                if (other instanceof RangeValue) {
+                    return intersect((RangeValue) other, this);
+                }
+            } catch (RuntimeException e) {
+                // comparing the literals failed; keep both predicates unsimplified
+                return intersectAsUnknown(other);
+            }
+            return intersectAsUnknown(other);
+        }
+
+        /** Emits `reference = v` for a single value, otherwise `reference IN (v1, v2, ...)`. */
+        @Override
+        public Expression toExpression() {
+            if (values.size() == 1) {
+                return new EqualTo(reference, values.iterator().next());
+            } else {
+                return new InPredicate(reference, Lists.newArrayList(values));
+            }
+        }
+
+        @Override
+        public String toString() {
+            return values.toString();
+        }
+    }
+
+    /**
+     * Represents processing result.
+     * Either a leaf predicate that could not be described (no `sourceValues`, emitted as `expr`),
+     * or several descriptions kept side by side and re-joined with `mergeExprOp`.
+     */
+    private static class UnknownValue extends ValueDesc {
+        private final List<ValueDesc> sourceValues;
+        private final BinaryOperator<Expression> mergeExprOp;
+
+        private UnknownValue(Expression expr) {
+            super(expr, expr);
+            sourceValues = ImmutableList.of();
+            mergeExprOp = null;
+        }
+
+        public UnknownValue(List<ValueDesc> sourceValues, Expression originExpr,
+                BinaryOperator<Expression> mergeExprOp) {
+            // grouped under the first source's reference
+            super(sourceValues.get(0).reference, originExpr);
+            this.sourceValues = ImmutableList.copyOf(sourceValues);
+            this.mergeExprOp = mergeExprOp;
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.union(this);
+            }
+            return unionAsUnknown(other);
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.intersect(this);
+            }
+            // the parts must be re-joined with AND; joining them with OR changes the predicate
+            return intersectAsUnknown(other);
+        }
+
+        @Override
+        public Expression toExpression() {
+            if (sourceValues.isEmpty()) {
+                return expr;
+            }
+            return sourceValues.stream()
+                    .map(ValueDesc::toExpression)
+                    .reduce(mergeExprOp)
+                    .get();
+        }
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
index 6c8dc2677e..307f7d8574 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
@@ -38,7 +38,7 @@ import java.util.Objects;
  * All data type literal expression in Nereids.
  * TODO: Increase the implementation of sub expression. such as Integer.
  */
-public abstract class Literal extends Expression implements LeafExpression {
+public abstract class Literal extends Expression implements LeafExpression, 
Comparable<Literal> {
 
     private final DataType dataType;
 
@@ -129,6 +129,7 @@ public abstract class Literal extends Expression implements 
LeafExpression {
     /**
      * literal expr compare.
      */
+    @Override
     public int compareTo(Literal other) {
         if (isNullLiteral() && other.isNullLiteral()) {
             return 0;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 9bb36489c9..d3154211a0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -293,6 +293,10 @@ public class ExpressionUtils {
         return children.stream().allMatch(c -> c instanceof Literal);
     }
 
+    /**
+     * Checks whether every expression in {@code children} has a numeric data type.
+     * An empty list is vacuously accepted.
+     */
+    public static boolean matchNumericType(List<Expression> children) {
+        return children.stream()
+                .map(Expression::getDataType)
+                .allMatch(type -> type.isNumericType());
+    }
+
     public static boolean hasNullLiteral(List<Expression> children) {
         return children.stream().anyMatch(c -> c instanceof NullLiteral);
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
index 54f0010cd2..2861e9b1f7 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
@@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase 
implements PatternMatchSuppo
                         "(lo_partkey = p_partkey)"
                 ),
                 ImmutableList.of(
-                        "((d_year = 1997) OR (d_year = 1998))",
+                        "d_year IN (1997, 1998)",
                         "(c_region = 'AMERICA')",
                         "(s_region = 'AMERICA')",
                         "((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))"
@@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase 
implements PatternMatchSuppo
                         "(lo_partkey = p_partkey)"
                 ),
                 ImmutableList.of(
-                        "((d_year = 1997) OR (d_year = 1998))",
+                        "d_year IN (1997, 1998)",
                         "(s_nation = 'UNITED STATES')",
                         "(p_category = 'MFGR#14')"
                 )
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
new file mode 100644
index 0000000000..617ad971f7
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rewrite;
+
+import org.apache.doris.nereids.analyzer.UnboundSlot;
+import org.apache.doris.nereids.parser.NereidsParser;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyRange;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.BooleanType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.StringType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Map;
+
+public class SimplifyRangeTest {
+
+    private static final NereidsParser PARSER = new NereidsParser();
+
+    // applies only the rule under test
+    private final ExpressionRuleExecutor executor =
+            new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
+
+    @Test
+    public void testSimplify() {
+        assertRewrite("TA", "TA");
+        assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)");
+        assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE");
+        assertRewrite("TA > 3 and TA < 1", "FALSE");
+        assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3");
+        assertRewrite("TA = 1 and TA > 10", "FALSE");
+        assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1");
+        assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
+        assertRewrite("TA > 5 or TA > 1 or TA < 10", "TRUE");
+        assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
+        assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
+        assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1");
+        assertRewrite("TA > 1 or TA < 10", "TRUE");
+        assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10");
+        assertRewrite("TA > 5 and TA > 10", "TA > 10");
+        assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
+        assertRewrite("(TA > 1 and TA > 10) or TA > 20", "TA > 10");
+        assertRewrite("(TA > 1 or TA > 10) and TA > 20", "TA > 20");
+        assertRewrite("(TA + TB > 1 or TA + TB > 10) and TA + TB > 20", "TA + TB > 20");
+        assertRewrite("TA > 10 or TA > 10", "TA > 10");
+        assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB < 20)", "TA > 10 and (TB > 10 and TB < 20) ");
+        assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA > 10 and TB > 20");
+        assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA > 40");
+        assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 10 and TB > 10 or TB > 20");
+        assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))",
+                "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
+        assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
+        assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
+        assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)");
+        assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
+        assertRewrite("TA in (1)", "TA in (1)");
+        assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
+        assertRewrite("TA in (1,2,3) and TA < 1", "FALSE");
+        assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1");
+        assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)");
+        assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)");
+        assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
+        assertRewrite("TA in (1,2,3) and TA in (3,4,5)", "TA = 3");
+        assertRewrite("TA + TB in (1,2,3) and TA + TB in (3,4,5)", "TA + TB = 3");
+        assertRewrite("TA in (1,2,3) and DA > 1.5", "TA in (1,2,3) and DA > 1.5");
+        assertRewrite("TA = 1 and TA = 3", "FALSE");
+        assertRewrite("TA in (1) and TA in (3)", "FALSE");
+        assertRewrite("TA in (1) and TA in (1)", "TA = 1");
+        assertRewrite("(TA > 3 and TA < 1) and TB < 5", "FALSE");
+        assertRewrite("(TA > 3 and TA < 1) or TB < 5", "TB < 5");
+    }
+
+    /**
+     * Parses both expressions, binding slot names through one shared map so the same name resolves
+     * to the same SlotReference on both sides, and checks that rewriting {@code expression} with
+     * {@link SimplifyRange} yields {@code expected}.
+     */
+    private void assertRewrite(String expression, String expected) {
+        Map<String, Slot> mem = Maps.newHashMap();
+        Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem);
+        Expression expectedExpression = replaceUnboundSlot(PARSER.parseExpression(expected), mem);
+        Expression rewrittenExpression = executor.rewrite(needRewriteExpression);
+        Assertions.assertEquals(expectedExpression, rewrittenExpression);
+    }
+
+    /**
+     * Recursively replaces every UnboundSlot with a SlotReference typed by the first letter of its
+     * name (see {@link #getType}); slots already created in {@code mem} are reused.
+     */
+    private Expression replaceUnboundSlot(Expression expression, Map<String, Slot> mem) {
+        List<Expression> children = Lists.newArrayList();
+        boolean hasNewChildren = false;
+        for (Expression child : expression.children()) {
+            Expression newChild = replaceUnboundSlot(child, mem);
+            if (newChild != child) {
+                hasNewChildren = true;
+            }
+            children.add(newChild);
+        }
+        if (expression instanceof UnboundSlot) {
+            String name = ((UnboundSlot) expression).getName();
+            mem.putIfAbsent(name, new SlotReference(name, getType(name.charAt(0))));
+            return mem.get(name);
+        }
+        return hasNewChildren ? expression.withChildren(children) : expression;
+    }
+
+    /**
+     * Maps a slot-name prefix to a data type: T tinyint, I int, D double, S string, B boolean,
+     * anything else bigint.
+     */
+    private DataType getType(char t) {
+        switch (t) {
+            case 'T':
+                return TinyIntType.INSTANCE;
+            case 'I':
+                return IntegerType.INSTANCE;
+            case 'D':
+                return DoubleType.INSTANCE;
+            case 'S':
+                return StringType.INSTANCE;
+            case 'B':
+                return BooleanType.INSTANCE;
+            default:
+                return BigIntType.INSTANCE;
+        }
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to