This is an automated email from the ASF dual-hosted git repository.
huajianlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 730cd1a0c1 [Feature](Nereids) Simplify range of predicate (#14113)
730cd1a0c1 is described below
commit 730cd1a0c1417de16fa822f5ed649d3bc82b4b34
Author: shee <[email protected]>
AuthorDate: Mon Nov 21 20:24:03 2022 +0800
[Feature](Nereids) Simplify range of predicate (#14113)
Simplify range of predicate
for example:
1. `a > 1 or a > 2` => `a > 1`
2. `a in (1,2,3) or a in (3,4,5)` => `a in (1,2,3,4,5)`
---
.../expression/rewrite/ExpressionOptimization.java | 5 +-
.../expression/rewrite/ExpressionRuleExecutor.java | 5 +
.../expression/rewrite/rules/SimplifyRange.java | 461 +++++++++++++++++++++
.../nereids/trees/expressions/literal/Literal.java | 3 +-
.../apache/doris/nereids/util/ExpressionUtils.java | 4 +
.../nereids/datasets/ssb/SSBJoinReorderTest.java | 4 +-
.../expression/rewrite/SimplifyRangeTest.java | 140 +++++++
7 files changed, 618 insertions(+), 4 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
index 973af7ab72..9e3eace63f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionOptimization.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression.rewrite;
import
org.apache.doris.nereids.rules.expression.rewrite.rules.DistinctPredicatesRule;
import
org.apache.doris.nereids.rules.expression.rewrite.rules.ExtractCommonFactorRule;
import
org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyComparisonPredicate;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyRange;
import com.google.common.collect.ImmutableList;
@@ -32,7 +33,9 @@ public class ExpressionOptimization extends ExpressionRewrite
{
public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES =
ImmutableList.of(
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
- SimplifyComparisonPredicate.INSTANCE);
+ SimplifyComparisonPredicate.INSTANCE,
+ // merges range / IN predicates over the same expression (see SimplifyRange)
+ SimplifyRange.INSTANCE);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
public ExpressionOptimization() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
index 44829a2a31..d8da2617cc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
@@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.expression.rewrite;
+import
org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.qe.ConnectContext;
@@ -67,4 +68,8 @@ public class ExpressionRuleExecutor {
return rule.rewrite(expr, ctx);
}
+ public static Expression normalize(Expression expression) {
+ return NormalizeBinaryPredicatesRule.INSTANCE.rewrite(expression,
null);
+ }
+
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
new file mode 100644
index 0000000000..010031486c
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyRange.java
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rewrite.rules;
+
+import
org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule;
+import
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
+import
org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.LessThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.BoundType;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Range;
+import com.google.common.collect.Sets;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.function.BinaryOperator;
+import java.util.stream.Collectors;
+
+/**
+ * This class implements the function to simplify expression range.
+ * for example:
+ * a > 1 and a > 2 => a > 2
+ * a > 1 or a > 2 => a > 1
+ * a in (1,2,3) and a > 1 => a in (2,3)
+ * a in (1,2,3) and a in (3,4,5) => a = 3
+ * a in (1,2,3) and a in (4,5,6) => false
+ * The logic is as follows:
+ * 1. for `And` expression.
+ *    1. extract conjunctions then build a `ValueDesc` for each conjunction
+ *    2. group by `reference`; `ValueDesc`s in the same group are intersected
+ *    for example:
+ *    a > 1 and a > 2
+ *    1. a > 1 => RangeValue((1...+∞)), a > 2 => RangeValue((2...+∞))
+ *    2. (1...+∞) intersect (2...+∞) => (2...+∞)
+ * 2. for `Or` expression (similar to `And`, but values in the same group are unioned).
+ * Only comparisons / IN lists against numeric literals are simplified; anything else is wrapped
+ * in an `UnknownValue` and emitted unchanged.
+ * NOTE(review): an unbounded range is emitted as TRUE and an empty one as FALSE. That is not
+ * equivalent when the reference evaluates to NULL (e.g. `a > 1 or a < 10` is NULL, not TRUE, for
+ * a NULL `a`, so a WHERE clause would start returning those rows); confirm where this rule runs.
+ * todo: support a > 10 and (a < 10 or a > 20 ) => a > 20
+ */
+public class SimplifyRange extends AbstractExpressionRewriteRule {
+
+    public static final SimplifyRange INSTANCE = new SimplifyRange();
+
+    /**
+     * Simplify a compound predicate; any other expression is returned unchanged.
+     */
+    @Override
+    public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
+        if (expr instanceof CompoundPredicate) {
+            ValueDesc valueDesc = expr.accept(new RangeInference(), null);
+            Expression simplifiedExpr = valueDesc.toExpression();
+            return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr;
+        }
+        return expr;
+    }
+
+    /**
+     * Converts a predicate tree into a {@link ValueDesc}, merging the values of predicates that
+     * share the same reference expression.
+     */
+    private static class RangeInference extends ExpressionVisitor<ValueDesc, Void> {
+
+        @Override
+        public ValueDesc visit(Expression expr, Void context) {
+            return new UnknownValue(expr);
+        }
+
+        /** Build a range/discrete value for `ref op numericLiteral`, otherwise an UnknownValue. */
+        private ValueDesc buildRange(ComparisonPredicate predicate) {
+            Expression rewrite = ExpressionRuleExecutor.normalize(predicate);
+            Expression right = rewrite.child(1);
+            // only handle `NumericType`
+            if (right.isLiteral() && right.getDataType().isNumericType()) {
+                return ValueDesc.range((ComparisonPredicate) rewrite);
+            }
+            return new UnknownValue(predicate);
+        }
+
+        @Override
+        public ValueDesc visitGreaterThan(GreaterThan greaterThan, Void context) {
+            return buildRange(greaterThan);
+        }
+
+        @Override
+        public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Void context) {
+            return buildRange(greaterThanEqual);
+        }
+
+        @Override
+        public ValueDesc visitLessThan(LessThan lessThan, Void context) {
+            return buildRange(lessThan);
+        }
+
+        @Override
+        public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, Void context) {
+            return buildRange(lessThanEqual);
+        }
+
+        @Override
+        public ValueDesc visitEqualTo(EqualTo equalTo, Void context) {
+            return buildRange(equalTo);
+        }
+
+        @Override
+        public ValueDesc visitInPredicate(InPredicate inPredicate, Void context) {
+            // only handle `NumericType`
+            if (ExpressionUtils.isAllLiteral(inPredicate.getOptions())
+                    && ExpressionUtils.matchNumericType(inPredicate.getOptions())) {
+                try {
+                    return ValueDesc.discrete(inPredicate);
+                } catch (Exception e) {
+                    // the options are stored in a sorted set; if the literals cannot be compared
+                    // with each other, keep the predicate as-is instead of failing the rewrite
+                    return new UnknownValue(inPredicate);
+                }
+            }
+            return new UnknownValue(inPredicate);
+        }
+
+        @Override
+        public ValueDesc visitAnd(And and, Void context) {
+            return simplify(and, ExpressionUtils.extractConjunction(and),
+                    ValueDesc::intersect, ExpressionUtils::and);
+        }
+
+        @Override
+        public ValueDesc visitOr(Or or, Void context) {
+            return simplify(or, ExpressionUtils.extractDisjunction(or),
+                    ValueDesc::union, ExpressionUtils::or);
+        }
+
+        /**
+         * Merge the values of `predicates` that share a reference with `op`; values of different
+         * references are combined into an UnknownValue joined by `exprOp`.
+         */
+        private ValueDesc simplify(Expression originExpr, List<Expression> predicates,
+                BinaryOperator<ValueDesc> op, BinaryOperator<Expression> exprOp) {
+            // LinkedHashMap keeps groups in first-occurrence order
+            Map<Expression, List<ValueDesc>> groupByReference = predicates.stream()
+                    .map(predicate -> predicate.accept(this, null))
+                    .collect(Collectors.groupingBy(p -> p.reference, LinkedHashMap::new,
+                            Collectors.toList()));
+
+            List<ValueDesc> valuePerRefs = Lists.newArrayList();
+            for (Entry<Expression, List<ValueDesc>> referenceValues
+                    : groupByReference.entrySet()) {
+                // merge per reference; groupingBy never produces an empty group, so get() is safe
+                ValueDesc simplifiedValue = referenceValues.getValue().stream()
+                        .reduce(op)
+                        .get();
+                valuePerRefs.add(simplifiedValue);
+            }
+
+            if (valuePerRefs.size() == 1) {
+                return valuePerRefs.get(0);
+            }
+
+            // use UnknownValue to wrap different references
+            return new UnknownValue(valuePerRefs, originExpr, exprOp);
+        }
+    }
+
+    /**
+     * The set of values a `reference` expression may take according to a predicate.
+     * `expr` is the predicate (or combination of predicates) this value was built from.
+     */
+    private abstract static class ValueDesc {
+        Expression expr;
+        Expression reference;
+
+        public ValueDesc(Expression reference, Expression expr) {
+            this.expr = expr;
+            this.reference = reference;
+        }
+
+        public abstract ValueDesc union(ValueDesc other);
+
+        /**
+         * Union a range with a discrete set: when every discrete value already lies in the range
+         * the range alone is returned, otherwise both are kept, joined by OR.
+         * `reverseOrder` puts the discrete value first in the generated expression.
+         */
+        public static ValueDesc union(RangeValue range, DiscreteValue discrete, boolean reverseOrder) {
+            boolean allInRange = discrete.values.stream().allMatch(x -> range.range.contains(x));
+            if (allInRange) {
+                return range;
+            }
+            Expression originExpr = ExpressionUtils.or(range.expr, discrete.expr);
+            List<ValueDesc> sourceValues = reverseOrder
+                    ? ImmutableList.of(discrete, range)
+                    : ImmutableList.of(range, discrete);
+            return new UnknownValue(sourceValues, originExpr, ExpressionUtils::or);
+        }
+
+        public abstract ValueDesc intersect(ValueDesc other);
+
+        /**
+         * Intersect a range with a discrete set: keep only the discrete values inside the range,
+         * or return an EmptyValue when none remain.
+         */
+        public static ValueDesc intersect(RangeValue range, DiscreteValue discrete) {
+            Expression originExpr = ExpressionUtils.and(range.expr, discrete.expr);
+            DiscreteValue result = new DiscreteValue(discrete.reference, originExpr);
+            for (Literal value : discrete.values) {
+                if (range.range.contains(value)) {
+                    result.values.add(value);
+                }
+            }
+            if (!result.values.isEmpty()) {
+                return result;
+            }
+            return new EmptyValue(range.reference, originExpr);
+        }
+
+        public abstract Expression toExpression();
+
+        /**
+         * Keep `this` and `other` side by side, combined with `mergeExprOp`, without trying to
+         * merge their values.
+         */
+        protected UnknownValue keepBoth(ValueDesc other, BinaryOperator<Expression> mergeExprOp) {
+            return new UnknownValue(ImmutableList.of(this, other),
+                    mergeExprOp.apply(expr, other.expr), mergeExprOp);
+        }
+
+        /** Build a value from a normalized `reference op literal` comparison. */
+        public static ValueDesc range(ComparisonPredicate predicate) {
+            Literal value = (Literal) predicate.right();
+            if (predicate instanceof EqualTo) {
+                return new DiscreteValue(predicate.left(), predicate, value);
+            }
+            RangeValue rangeValue = new RangeValue(predicate.left(), predicate);
+            if (predicate instanceof GreaterThanEqual) {
+                rangeValue.range = Range.atLeast(value);
+            } else if (predicate instanceof GreaterThan) {
+                rangeValue.range = Range.greaterThan(value);
+            } else if (predicate instanceof LessThanEqual) {
+                rangeValue.range = Range.atMost(value);
+            } else if (predicate instanceof LessThan) {
+                rangeValue.range = Range.lessThan(value);
+            }
+
+            return rangeValue;
+        }
+
+        /** Build a discrete value from an IN predicate whose options are all literals. */
+        public static ValueDesc discrete(InPredicate in) {
+            Set<Literal> literals = in.getOptions().stream()
+                    .map(Literal.class::cast)
+                    .collect(Collectors.toSet());
+            return new DiscreteValue(in.getCompareExpr(), in, literals);
+        }
+    }
+
+    /**
+     * A value set with no members; rendered as FALSE.
+     */
+    private static class EmptyValue extends ValueDesc {
+
+        public EmptyValue(Expression reference, Expression expr) {
+            super(reference, expr);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            return other;
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            return this;
+        }
+
+        @Override
+        public Expression toExpression() {
+            return BooleanLiteral.FALSE;
+        }
+    }
+
+    /**
+     * use @see com.google.common.collect.Range to wrap `ComparisonPredicate`
+     * for example:
+     * a > 1 => (1...+∞)
+     */
+    private static class RangeValue extends ValueDesc {
+        Range<Literal> range;
+
+        public RangeValue(Expression reference, Expression expr) {
+            super(reference, expr);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.union(this);
+            }
+            if (other instanceof UnknownValue) {
+                return keepBoth(other, ExpressionUtils::or);
+            }
+            try {
+                if (other instanceof RangeValue) {
+                    RangeValue o = (RangeValue) other;
+                    if (range.isConnected(o.range)) {
+                        // for connected ranges the span is exactly their union
+                        RangeValue rangeValue =
+                                new RangeValue(reference, ExpressionUtils.or(expr, other.expr));
+                        rangeValue.range = range.span(o.range);
+                        return rangeValue;
+                    }
+                    return keepBoth(other, ExpressionUtils::or);
+                }
+                return union(this, (DiscreteValue) other, false);
+            } catch (Exception e) {
+                // presumably the literals could not be compared; keep both predicates unmerged
+                return keepBoth(other, ExpressionUtils::or);
+            }
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.intersect(this);
+            }
+            if (other instanceof UnknownValue) {
+                return keepBoth(other, ExpressionUtils::and);
+            }
+            try {
+                if (other instanceof RangeValue) {
+                    RangeValue o = (RangeValue) other;
+                    if (range.isConnected(o.range)) {
+                        // connected ranges can still have an empty intersection, e.g.
+                        // [3, +∞) and (-∞, 3) give [3, 3); that case is not folded to FALSE
+                        RangeValue rangeValue =
+                                new RangeValue(reference, ExpressionUtils.and(expr, other.expr));
+                        rangeValue.range = range.intersection(o.range);
+                        return rangeValue;
+                    }
+                    return new EmptyValue(reference, ExpressionUtils.and(expr, other.expr));
+                }
+                return intersect(this, (DiscreteValue) other);
+            } catch (Exception e) {
+                // presumably the literals could not be compared; keep both predicates unmerged
+                return keepBoth(other, ExpressionUtils::and);
+            }
+        }
+
+        /** Render the range as `reference >/>= lower [and reference </<= upper]`, or TRUE. */
+        @Override
+        public Expression toExpression() {
+            List<Expression> result = Lists.newArrayList();
+            if (range.hasLowerBound()) {
+                if (range.lowerBoundType() == BoundType.CLOSED) {
+                    result.add(new GreaterThanEqual(reference, range.lowerEndpoint()));
+                } else {
+                    result.add(new GreaterThan(reference, range.lowerEndpoint()));
+                }
+            }
+            if (range.hasUpperBound()) {
+                if (range.upperBoundType() == BoundType.CLOSED) {
+                    result.add(new LessThanEqual(reference, range.upperEndpoint()));
+                } else {
+                    result.add(new LessThan(reference, range.upperEndpoint()));
+                }
+            }
+            return result.isEmpty() ? BooleanLiteral.TRUE : ExpressionUtils.and(result);
+        }
+
+        @Override
+        public String toString() {
+            return range == null ? "UnknownRange" : range.toString();
+        }
+    }
+
+    /**
+     * use `Set` to wrap `InPredicate`
+     * for example:
+     * a in (1,2,3) => [1,2,3]
+     */
+    private static class DiscreteValue extends ValueDesc {
+        // kept sorted so the generated IN list has a stable order
+        Set<Literal> values;
+
+        public DiscreteValue(Expression reference, Expression expr, Literal... values) {
+            this(reference, expr, Arrays.asList(values));
+        }
+
+        public DiscreteValue(Expression reference, Expression expr, Collection<Literal> values) {
+            super(reference, expr);
+            this.values = Sets.newTreeSet(values);
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.union(this);
+            }
+            if (other instanceof UnknownValue) {
+                return keepBoth(other, ExpressionUtils::or);
+            }
+            try {
+                if (other instanceof DiscreteValue) {
+                    DiscreteValue discreteValue =
+                            new DiscreteValue(reference, ExpressionUtils.or(expr, other.expr));
+                    discreteValue.values.addAll(((DiscreteValue) other).values);
+                    discreteValue.values.addAll(this.values);
+                    return discreteValue;
+                }
+                return union((RangeValue) other, this, true);
+            } catch (Exception e) {
+                // presumably the literals could not be compared; keep both predicates unmerged
+                return keepBoth(other, ExpressionUtils::or);
+            }
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            if (other instanceof EmptyValue) {
+                return other.intersect(this);
+            }
+            if (other instanceof UnknownValue) {
+                return keepBoth(other, ExpressionUtils::and);
+            }
+            try {
+                if (other instanceof DiscreteValue) {
+                    DiscreteValue discreteValue =
+                            new DiscreteValue(reference, ExpressionUtils.and(expr, other.expr));
+                    discreteValue.values.addAll(((DiscreteValue) other).values);
+                    discreteValue.values.retainAll(this.values);
+                    if (discreteValue.values.isEmpty()) {
+                        return new EmptyValue(reference, ExpressionUtils.and(expr, other.expr));
+                    }
+                    return discreteValue;
+                }
+                return intersect((RangeValue) other, this);
+            } catch (Exception e) {
+                // presumably the literals could not be compared; keep both predicates unmerged
+                return keepBoth(other, ExpressionUtils::and);
+            }
+        }
+
+        /** Render a single value as `reference = v`, otherwise as `reference IN (...)`. */
+        @Override
+        public Expression toExpression() {
+            if (values.size() == 1) {
+                return new EqualTo(reference, values.iterator().next());
+            }
+            return new InPredicate(reference, Lists.newArrayList(values));
+        }
+
+        @Override
+        public String toString() {
+            return values.toString();
+        }
+    }
+
+    /**
+     * A value that is not simplified further: either a leaf predicate this rule does not handle,
+     * or a combination of `sourceValues` joined by `mergeExprOp`.
+     * A combined UnknownValue takes the reference of its first source value, so it can be grouped
+     * with other values of that reference; those values are then only kept side by side with it.
+     */
+    private static class UnknownValue extends ValueDesc {
+        private final List<ValueDesc> sourceValues;
+        private final BinaryOperator<Expression> mergeExprOp;
+
+        private UnknownValue(Expression expr) {
+            super(expr, expr);
+            sourceValues = ImmutableList.of();
+            mergeExprOp = null;
+        }
+
+        public UnknownValue(List<ValueDesc> sourceValues, Expression originExpr,
+                BinaryOperator<Expression> mergeExprOp) {
+            super(sourceValues.get(0).reference, originExpr);
+            this.sourceValues = ImmutableList.copyOf(sourceValues);
+            this.mergeExprOp = mergeExprOp;
+        }
+
+        @Override
+        public ValueDesc union(ValueDesc other) {
+            return keepBoth(other, ExpressionUtils::or);
+        }
+
+        @Override
+        public ValueDesc intersect(ValueDesc other) {
+            // must join with AND: joining with OR would turn `x and y` into `x or y`
+            return keepBoth(other, ExpressionUtils::and);
+        }
+
+        @Override
+        public Expression toExpression() {
+            if (sourceValues.isEmpty()) {
+                return expr;
+            }
+            return sourceValues.stream()
+                    .map(ValueDesc::toExpression)
+                    .reduce(mergeExprOp)
+                    .get();
+        }
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
index 6c8dc2677e..307f7d8574 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
@@ -38,7 +38,7 @@ import java.util.Objects;
* All data type literal expression in Nereids.
* TODO: Increase the implementation of sub expression. such as Integer.
*/
-public abstract class Literal extends Expression implements LeafExpression {
+public abstract class Literal extends Expression implements LeafExpression,
Comparable<Literal> {
private final DataType dataType;
@@ -129,6 +129,7 @@ public abstract class Literal extends Expression implements
LeafExpression {
/**
* literal expr compare.
*/
+ @Override
public int compareTo(Literal other) {
if (isNullLiteral() && other.isNullLiteral()) {
return 0;
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 9bb36489c9..d3154211a0 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -293,6 +293,10 @@ public class ExpressionUtils {
return children.stream().allMatch(c -> c instanceof Literal);
}
+    /**
+     * Return true when every expression in {@code children} has a numeric data type
+     * (vacuously true for an empty list).
+     */
+    public static boolean matchNumericType(List<Expression> children) {
+        for (Expression child : children) {
+            if (!child.getDataType().isNumericType()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
public static boolean hasNullLiteral(List<Expression> children) {
return children.stream().anyMatch(c -> c instanceof NullLiteral);
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
index 54f0010cd2..2861e9b1f7 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
@@ -55,7 +55,7 @@ public class SSBJoinReorderTest extends SSBTestBase
implements PatternMatchSuppo
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
- "((d_year = 1997) OR (d_year = 1998))",
+ "d_year IN (1997, 1998)",
"(c_region = 'AMERICA')",
"(s_region = 'AMERICA')",
"((p_mfgr = 'MFGR#1') OR (p_mfgr = 'MFGR#2'))"
@@ -74,7 +74,7 @@ public class SSBJoinReorderTest extends SSBTestBase
implements PatternMatchSuppo
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
- "((d_year = 1997) OR (d_year = 1998))",
+ "d_year IN (1997, 1998)",
"(s_nation = 'UNITED STATES')",
"(p_category = 'MFGR#14')"
)
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
new file mode 100644
index 0000000000..617ad971f7
--- /dev/null
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/SimplifyRangeTest.java
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rewrite;
+
+import org.apache.doris.nereids.analyzer.UnboundSlot;
+import org.apache.doris.nereids.parser.NereidsParser;
+import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyRange;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.BooleanType;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.StringType;
+import org.apache.doris.nereids.types.TinyIntType;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Map;
+
+public class SimplifyRangeTest {
+
+ private static final NereidsParser PARSER = new NereidsParser();
+ private ExpressionRuleExecutor executor;
+
+ @Test
+ public void testSimplify() {
+ executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
+ assertRewrite("TA", "TA");
+ assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >=
1 and TA <=3 ) or (TA > 5 and TA < 7)");
+ assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE");
+ assertRewrite("TA > 3 and TA < 1", "FALSE");
+ assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3");
+ assertRewrite("TA = 1 and TA > 10", "FALSE");
+ assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1");
+ assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
+ assertRewrite("TA > 5 or TA > 1 or TA < 10", "TRUE");
+ assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
+ assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
+ assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1");
+ assertRewrite("TA > 1 or TA < 10", "TRUE");
+ assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10");
+ assertRewrite("TA > 5 and TA > 10", "TA > 10");
+ assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
+ assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
+ assertRewrite("(TA > 1 and TA > 10) or TA > 20", "TA > 10");
+ assertRewrite("(TA > 1 or TA > 10) and TA > 20", "TA > 20");
+ assertRewrite("(TA + TB > 1 or TA + TB > 10) and TA + TB > 20", "TA +
TB > 20");
+ assertRewrite("TA > 10 or TA > 10", "TA > 10");
+ assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB < 20)", "TA >
10 and (TB > 10 and TB < 20) ");
+ assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA >
10 and TB > 20");
+ assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB
> 20)", "TB > 30 and TA > 40");
+ assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA >
10 and TB > 10 or TB > 20");
+ assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB >
20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
+ assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
+ assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
+ assertRewrite("TA in (1,2,3) and TA > 1", "TA in (2,3)");
+ assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
+ assertRewrite("TA in (1)", "TA in (1)");
+ assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
+ assertRewrite("TA in (1,2,3) and TA < 1", "FALSE");
+ assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1");
+ assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)");
+ assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)");
+ assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
+ assertRewrite("TA in (1,2,3) and TA in (3,4,5)", "TA = 3");
+ assertRewrite("TA + TB in (1,2,3) and TA + TB in (3,4,5)", "TA + TB =
3");
+ assertRewrite("TA in (1,2,3) and DA > 1.5", "TA in (1,2,3) and DA >
1.5");
+ assertRewrite("TA = 1 and TA = 3", "FALSE");
+ assertRewrite("TA in (1) and TA in (3)", "FALSE");
+ assertRewrite("TA in (1) and TA in (1)", "TA = 1");
+ assertRewrite("(TA > 3 and TA < 1) and TB < 5", "FALSE");
+ assertRewrite("(TA > 3 and TA < 1) or TB < 5", "TB < 5");
+ }
+
+ private void assertRewrite(String expression, String expected) {
+ Map<String, Slot> mem = Maps.newHashMap();
+ Expression needRewriteExpression =
replaceUnboundSlot(PARSER.parseExpression(expression), mem);
+ Expression expectedExpression =
replaceUnboundSlot(PARSER.parseExpression(expected), mem);
+ Expression rewrittenExpression =
executor.rewrite(needRewriteExpression);
+ Assertions.assertEquals(expectedExpression, rewrittenExpression);
+ }
+
+ private Expression replaceUnboundSlot(Expression expression, Map<String,
Slot> mem) {
+ List<Expression> children = Lists.newArrayList();
+ boolean hasNewChildren = false;
+ for (Expression child : expression.children()) {
+ Expression newChild = replaceUnboundSlot(child, mem);
+ if (newChild != child) {
+ hasNewChildren = true;
+ }
+ children.add(newChild);
+ }
+ if (expression instanceof UnboundSlot) {
+ String name = ((UnboundSlot) expression).getName();
+ mem.putIfAbsent(name, new SlotReference(name,
getType(name.charAt(0))));
+ return mem.get(name);
+ }
+ return hasNewChildren ? expression.withChildren(children) : expression;
+ }
+
+ private DataType getType(char t) {
+ switch (t) {
+ case 'T':
+ return TinyIntType.INSTANCE;
+ case 'I':
+ return IntegerType.INSTANCE;
+ case 'D':
+ return DoubleType.INSTANCE;
+ case 'S':
+ return StringType.INSTANCE;
+ case 'B':
+ return BooleanType.INSTANCE;
+ default:
+ return BigIntType.INSTANCE;
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]