This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.0 by this push:
new 60a852ee574 [fix](Nereids) simplify range result wrong when reference
is nullable (#41356) (#42990)
60a852ee574 is described below
commit 60a852ee574403f869a99abfc83b49fb352c7e63
Author: morrySnow <[email protected]>
AuthorDate: Fri Nov 1 14:55:04 2024 +0800
[fix](Nereids) simplify range result wrong when reference is nullable
(#41356) (#42990)
Cherry-picked from master #41356.
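If the reference is nullable and the simplified result is a boolean literal,
the real result should be:
IF(${reference} IS NULL, NULL, ${not_null_result})
In practice, a range that can never match becomes `${reference} IS NULL AND NULL`,
and a range that matches every non-null value becomes
`${reference} IS NOT NULL OR NULL`.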
---
.../expression/rules/FoldConstantRuleOnFE.java | 8 +-
.../rules/expression/rules/SimplifyRange.java | 244 ++++++++++++---------
.../apache/doris/nereids/util/ExpressionUtils.java | 13 ++
.../rules/expression/SimplifyRangeTest.java | 68 ++++--
4 files changed, 206 insertions(+), 127 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
index 71377c021b2..b84012ffd64 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
@@ -93,7 +93,13 @@ public class FoldConstantRuleOnFE extends
AbstractExpressionRewriteRule {
} else if (expr instanceof AggregateExpression &&
((AggregateExpression) expr).getFunction().isDistinct()) {
return expr;
}
- return expr.accept(this, ctx);
+ // ATTN: we must return original expr, because OrToIn is implemented
with MutableState,
+ // newExpr will lose these states leading to dead loop by OrToIn ->
SimplifyRange -> FoldConstantByFE
+ Expression newExpr = expr.accept(this, ctx);
+ if (newExpr.equals(expr)) {
+ return expr;
+ }
+ return newExpr;
}
/**
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
index 2c673aabd23..111b164a459 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
@@ -35,7 +35,9 @@ import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.BoundType;
@@ -82,85 +84,85 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
if (expr instanceof CompoundPredicate) {
ValueDesc valueDesc = expr.accept(new RangeInference(), null);
- Expression simplifiedExpr = valueDesc.toExpression();
- return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr;
+ Expression exprForNonNull = valueDesc.toExpressionForNonNull();
+ if (exprForNonNull == null) {
+ // this mean cannot simplify
+ return valueDesc.exprForNonNull;
+ }
+ return exprForNonNull;
}
return expr;
}
- private static class RangeInference extends ExpressionVisitor<ValueDesc,
Void> {
+ private static class RangeInference extends ExpressionVisitor<ValueDesc,
ExpressionRewriteContext> {
@Override
- public ValueDesc visit(Expression expr, Void context) {
- return new UnknownValue(expr);
+ public ValueDesc visit(Expression expr, ExpressionRewriteContext
context) {
+ return new UnknownValue(context, expr);
}
- private ValueDesc buildRange(ComparisonPredicate predicate) {
+ private ValueDesc buildRange(ExpressionRewriteContext context,
ComparisonPredicate predicate) {
Expression rewrite = ExpressionRuleExecutor.normalize(predicate);
Expression right = rewrite.child(1);
if (right.isNullLiteral()) {
- // it's safe to return empty value if >, >=, <, <= and = with
null
- if ((predicate instanceof GreaterThan || predicate instanceof
GreaterThanEqual
- || predicate instanceof LessThan || predicate
instanceof LessThanEqual
- || predicate instanceof EqualTo)) {
- return new EmptyValue(rewrite.child(0), rewrite);
- } else {
- return new UnknownValue(predicate);
- }
+ return new UnknownValue(context, predicate);
}
// only handle `NumericType`
if (right.isLiteral() && right.getDataType().isNumericType()) {
- return ValueDesc.range((ComparisonPredicate) rewrite);
+ return ValueDesc.range(context, (ComparisonPredicate) rewrite);
}
- return new UnknownValue(predicate);
+ return new UnknownValue(context, predicate);
}
@Override
- public ValueDesc visitGreaterThan(GreaterThan greaterThan, Void
context) {
- return buildRange(greaterThan);
+ public ValueDesc visitGreaterThan(GreaterThan greaterThan,
ExpressionRewriteContext context) {
+ return buildRange(context, greaterThan);
}
@Override
- public ValueDesc visitGreaterThanEqual(GreaterThanEqual
greaterThanEqual, Void context) {
- return buildRange(greaterThanEqual);
+ public ValueDesc visitGreaterThanEqual(GreaterThanEqual
greaterThanEqual, ExpressionRewriteContext context) {
+ return buildRange(context, greaterThanEqual);
}
@Override
- public ValueDesc visitLessThan(LessThan lessThan, Void context) {
- return buildRange(lessThan);
+ public ValueDesc visitLessThan(LessThan lessThan,
ExpressionRewriteContext context) {
+ return buildRange(context, lessThan);
}
@Override
- public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, Void
context) {
- return buildRange(lessThanEqual);
+ public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual,
ExpressionRewriteContext context) {
+ return buildRange(context, lessThanEqual);
}
@Override
- public ValueDesc visitEqualTo(EqualTo equalTo, Void context) {
- return buildRange(equalTo);
+ public ValueDesc visitEqualTo(EqualTo equalTo,
ExpressionRewriteContext context) {
+ return buildRange(context, equalTo);
}
@Override
- public ValueDesc visitInPredicate(InPredicate inPredicate, Void
context) {
+ public ValueDesc visitInPredicate(InPredicate inPredicate,
ExpressionRewriteContext context) {
// only handle `NumericType`
- if (ExpressionUtils.isAllLiteral(inPredicate.getOptions())
+ if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
&&
ExpressionUtils.matchNumericType(inPredicate.getOptions())) {
- return ValueDesc.discrete(inPredicate);
+ return ValueDesc.discrete(context, inPredicate);
}
- return new UnknownValue(inPredicate);
+ return new UnknownValue(context, inPredicate);
}
@Override
- public ValueDesc visitAnd(And and, Void context) {
- return simplify(and, ExpressionUtils.extractConjunction(and),
ValueDesc::intersect, ExpressionUtils::and);
+ public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
+ return simplify(context, and,
ExpressionUtils.extractConjunction(and),
+ ValueDesc::intersect, ExpressionUtils::and);
}
@Override
- public ValueDesc visitOr(Or or, Void context) {
- return simplify(or, ExpressionUtils.extractDisjunction(or),
ValueDesc::union, ExpressionUtils::or);
+ public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
+ return simplify(context, or,
ExpressionUtils.extractDisjunction(or),
+ ValueDesc::union, ExpressionUtils::or);
}
- private ValueDesc simplify(Expression originExpr, List<Expression>
predicates,
+ private ValueDesc simplify(ExpressionRewriteContext context,
+ Expression originExpr, List<Expression> predicates,
BinaryOperator<ValueDesc> op, BinaryOperator<Expression>
exprOp) {
Map<Expression, List<ValueDesc>> groupByReference =
predicates.stream()
@@ -184,52 +186,58 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
// use UnknownValue to wrap different references
- return new UnknownValue(valuePerRefs, originExpr, exprOp);
+ return new UnknownValue(context, valuePerRefs, originExpr, exprOp);
}
}
private abstract static class ValueDesc {
- Expression expr;
+ ExpressionRewriteContext context;
+ Expression exprForNonNull;
Expression reference;
- public ValueDesc(Expression reference, Expression expr) {
- this.expr = expr;
+ public ValueDesc(ExpressionRewriteContext context, Expression
reference, Expression exprForNonNull) {
+ this.context = context;
+ this.exprForNonNull = exprForNonNull;
this.reference = reference;
}
public abstract ValueDesc union(ValueDesc other);
- public static ValueDesc union(RangeValue range, DiscreteValue
discrete, boolean reverseOrder) {
+ public static ValueDesc union(ExpressionRewriteContext context,
+ RangeValue range, DiscreteValue discrete, boolean
reverseOrder) {
long count = discrete.values.stream().filter(x ->
range.range.test(x)).count();
if (count == discrete.values.size()) {
return range;
}
- Expression originExpr = ExpressionUtils.or(range.expr,
discrete.expr);
+ Expression exprForNonNull = FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(range.exprForNonNull,
discrete.exprForNonNull), context);
List<ValueDesc> sourceValues = reverseOrder
? ImmutableList.of(discrete, range)
: ImmutableList.of(range, discrete);
- return new UnknownValue(sourceValues, originExpr,
ExpressionUtils::or);
+ return new UnknownValue(context, sourceValues, exprForNonNull,
ExpressionUtils::or);
}
public abstract ValueDesc intersect(ValueDesc other);
- public static ValueDesc intersect(RangeValue range, DiscreteValue
discrete) {
- DiscreteValue result = new DiscreteValue(discrete.reference,
discrete.expr);
+ public static ValueDesc intersect(ExpressionRewriteContext context,
RangeValue range, DiscreteValue discrete) {
+ DiscreteValue result = new DiscreteValue(context,
discrete.reference, discrete.exprForNonNull);
discrete.values.stream().filter(x ->
range.range.contains(x)).forEach(result.values::add);
if (!result.values.isEmpty()) {
return result;
}
- return new EmptyValue(range.reference,
ExpressionUtils.and(range.expr, discrete.expr));
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(range.exprForNonNull,
discrete.exprForNonNull), context);
+ return new EmptyValue(context, range.reference,
originExprForNonNull);
}
- public abstract Expression toExpression();
+ public abstract Expression toExpressionForNonNull();
- public static ValueDesc range(ComparisonPredicate predicate) {
+ public static ValueDesc range(ExpressionRewriteContext context,
ComparisonPredicate predicate) {
Literal value = (Literal) predicate.right();
if (predicate instanceof EqualTo) {
- return new DiscreteValue(predicate.left(), predicate, value);
+ return new DiscreteValue(context, predicate.left(), predicate,
value);
}
- RangeValue rangeValue = new RangeValue(predicate.left(),
predicate);
+ RangeValue rangeValue = new RangeValue(context, predicate.left(),
predicate);
if (predicate instanceof GreaterThanEqual) {
rangeValue.range = Range.atLeast(value);
} else if (predicate instanceof GreaterThan) {
@@ -243,16 +251,16 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
return rangeValue;
}
- public static ValueDesc discrete(InPredicate in) {
+ public static ValueDesc discrete(ExpressionRewriteContext context,
InPredicate in) {
Set<Literal> literals =
in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet());
- return new DiscreteValue(in.getCompareExpr(), in, literals);
+ return new DiscreteValue(context, in.getCompareExpr(), in,
literals);
}
}
private static class EmptyValue extends ValueDesc {
- public EmptyValue(Expression reference, Expression expr) {
- super(reference, expr);
+ public EmptyValue(ExpressionRewriteContext context, Expression
reference, Expression exprForNonNull) {
+ super(context, reference, exprForNonNull);
}
@Override
@@ -266,8 +274,12 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
@Override
- public Expression toExpression() {
- return BooleanLiteral.FALSE;
+ public Expression toExpressionForNonNull() {
+ if (reference.nullable()) {
+ return new And(new IsNull(reference), new
NullLiteral(BooleanType.INSTANCE));
+ } else {
+ return BooleanLiteral.FALSE;
+ }
}
}
@@ -279,8 +291,8 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
private static class RangeValue extends ValueDesc {
Range<Literal> range;
- public RangeValue(Expression reference, Expression expr) {
- super(reference, expr);
+ public RangeValue(ExpressionRewriteContext context, Expression
reference, Expression exprForNonNull) {
+ super(context, reference, exprForNonNull);
}
@Override
@@ -290,19 +302,23 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
try {
if (other instanceof RangeValue) {
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(exprForNonNull,
other.exprForNonNull), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
- RangeValue rangeValue = new RangeValue(reference,
ExpressionUtils.or(expr, other.expr));
+ RangeValue rangeValue = new RangeValue(context,
reference, originExprForNonNull);
rangeValue.range = range.span(o.range);
return rangeValue;
}
- Expression originExpr = ExpressionUtils.or(expr,
other.expr);
- return new UnknownValue(ImmutableList.of(this, other),
originExpr, ExpressionUtils::or);
+ return new UnknownValue(context, ImmutableList.of(this,
other),
+ originExprForNonNull, ExpressionUtils::or);
}
- return union(this, (DiscreteValue) other, false);
+ return union(context, this, (DiscreteValue) other, false);
} catch (Exception e) {
- Expression originExpr = ExpressionUtils.or(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other),
originExpr, ExpressionUtils::or);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(exprForNonNull,
other.exprForNonNull), context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
+ originExprForNonNull, ExpressionUtils::or);
}
}
@@ -313,23 +329,27 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
try {
if (other instanceof RangeValue) {
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(exprForNonNull,
other.exprForNonNull), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
- RangeValue rangeValue = new RangeValue(reference,
ExpressionUtils.and(expr, other.expr));
+ RangeValue rangeValue = new RangeValue(context,
reference, originExprForNonNull);
rangeValue.range = range.intersection(o.range);
return rangeValue;
}
- return new EmptyValue(reference, ExpressionUtils.and(expr,
other.expr));
+ return new EmptyValue(context, reference,
originExprForNonNull);
}
- return intersect(this, (DiscreteValue) other);
+ return intersect(context, this, (DiscreteValue) other);
} catch (Exception e) {
- Expression originExpr = ExpressionUtils.and(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other),
originExpr, ExpressionUtils::and);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(exprForNonNull,
other.exprForNonNull), context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
+ originExprForNonNull, ExpressionUtils::and);
}
}
@Override
- public Expression toExpression() {
+ public Expression toExpressionForNonNull() {
List<Expression> result = Lists.newArrayList();
if (range.hasLowerBound()) {
if (range.lowerBoundType() == BoundType.CLOSED) {
@@ -347,11 +367,12 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
if (!result.isEmpty()) {
return ExpressionUtils.and(result);
- } else if (reference.nullable()) {
- // when reference is nullable, we should filter null slot.
- return new Not(new IsNull(reference));
} else {
- return BooleanLiteral.TRUE;
+ if (reference.nullable()) {
+ return new Or(new Not(new IsNull(reference)), new
NullLiteral(BooleanType.INSTANCE));
+ } else {
+ return BooleanLiteral.TRUE;
+ }
}
}
@@ -369,12 +390,14 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
private static class DiscreteValue extends ValueDesc {
Set<Literal> values;
- public DiscreteValue(Expression reference, Expression expr, Literal...
values) {
- this(reference, expr, Arrays.asList(values));
+ public DiscreteValue(ExpressionRewriteContext context,
+ Expression reference, Expression exprForNonNull, Literal...
values) {
+ this(context, reference, exprForNonNull, Arrays.asList(values));
}
- public DiscreteValue(Expression reference, Expression expr,
Collection<Literal> values) {
- super(reference, expr);
+ public DiscreteValue(ExpressionRewriteContext context,
+ Expression reference, Expression exprForNonNull,
Collection<Literal> values) {
+ super(context, reference, exprForNonNull);
this.values = Sets.newTreeSet(values);
}
@@ -385,15 +408,19 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
try {
if (other instanceof DiscreteValue) {
- DiscreteValue discreteValue = new DiscreteValue(reference,
ExpressionUtils.or(expr, other.expr));
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(exprForNonNull,
other.exprForNonNull), context);
+ DiscreteValue discreteValue = new DiscreteValue(context,
reference, originExprForNonNull);
discreteValue.values.addAll(((DiscreteValue)
other).values);
discreteValue.values.addAll(this.values);
return discreteValue;
}
- return union((RangeValue) other, this, true);
+ return union(context, (RangeValue) other, this, true);
} catch (Exception e) {
- Expression originExpr = ExpressionUtils.or(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other),
originExpr, ExpressionUtils::or);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(exprForNonNull,
other.exprForNonNull), context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
+ originExprForNonNull, ExpressionUtils::or);
}
}
@@ -404,24 +431,28 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
}
try {
if (other instanceof DiscreteValue) {
- DiscreteValue discreteValue = new DiscreteValue(reference,
ExpressionUtils.and(expr, other.expr));
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(exprForNonNull,
other.exprForNonNull), context);
+ DiscreteValue discreteValue = new DiscreteValue(context,
reference, originExprForNonNull);
discreteValue.values.addAll(((DiscreteValue)
other).values);
discreteValue.values.retainAll(this.values);
if (discreteValue.values.isEmpty()) {
- return new EmptyValue(reference,
ExpressionUtils.and(expr, other.expr));
+ return new EmptyValue(context, reference,
originExprForNonNull);
} else {
return discreteValue;
}
}
- return intersect((RangeValue) other, this);
+ return intersect(context, (RangeValue) other, this);
} catch (Exception e) {
- Expression originExpr = ExpressionUtils.and(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other),
originExpr, ExpressionUtils::and);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(exprForNonNull,
other.exprForNonNull), context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
+ originExprForNonNull, ExpressionUtils::and);
}
}
@Override
- public Expression toExpression() {
+ public Expression toExpressionForNonNull() {
// NOTICE: it's related with `InPredicateToEqualToRule`
// They are same processes, so must change synchronously.
if (values.size() == 1) {
@@ -447,40 +478,49 @@ public class SimplifyRange extends
AbstractExpressionRewriteRule {
private final List<ValueDesc> sourceValues;
private final BinaryOperator<Expression> mergeExprOp;
- private UnknownValue(Expression expr) {
- super(expr, expr);
+ private UnknownValue(ExpressionRewriteContext context, Expression
expr) {
+ super(context, expr, expr);
sourceValues = ImmutableList.of();
mergeExprOp = null;
}
- public UnknownValue(List<ValueDesc> sourceValues, Expression
originExpr,
- BinaryOperator<Expression> mergeExprOp) {
- super(sourceValues.get(0).reference, originExpr);
+ public UnknownValue(ExpressionRewriteContext context,
+ List<ValueDesc> sourceValues, Expression exprForNonNull,
BinaryOperator<Expression> mergeExprOp) {
+ super(context, sourceValues.get(0).reference, exprForNonNull);
this.sourceValues = ImmutableList.copyOf(sourceValues);
this.mergeExprOp = mergeExprOp;
}
@Override
public ValueDesc union(ValueDesc other) {
- Expression originExpr = ExpressionUtils.or(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other), originExpr,
ExpressionUtils::or);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.or(exprForNonNull, other.exprForNonNull),
context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::or);
}
@Override
public ValueDesc intersect(ValueDesc other) {
- Expression originExpr = ExpressionUtils.and(expr, other.expr);
- return new UnknownValue(ImmutableList.of(this, other), originExpr,
ExpressionUtils::and);
+ Expression originExprForNonNull =
FoldConstantRuleOnFE.INSTANCE.rewrite(
+ ExpressionUtils.and(exprForNonNull, other.exprForNonNull),
context);
+ return new UnknownValue(context, ImmutableList.of(this, other),
originExprForNonNull, ExpressionUtils::and);
}
@Override
- public Expression toExpression() {
+ public Expression toExpressionForNonNull() {
if (sourceValues.isEmpty()) {
- return expr;
+ return exprForNonNull;
+ }
+ Expression result = sourceValues.get(0).toExpressionForNonNull();
+ for (int i = 1; i < sourceValues.size(); i++) {
+ result = mergeExprOp.apply(result,
sourceValues.get(i).toExpressionForNonNull());
+ }
+ result = FoldConstantRuleOnFE.INSTANCE.rewrite(result, context);
+ // ATTN: we must return original expr, because OrToIn is
implemented with MutableState,
+ // newExpr will lose these states leading to dead loop by OrToIn
-> SimplifyRange -> FoldConstantByFE
+ if (result.equals(exprForNonNull)) {
+ return exprForNonNull;
}
- return sourceValues.stream()
- .map(ValueDesc::toExpression)
- .reduce(mergeExprOp)
- .get();
+ return result;
}
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 8bec208896a..739a0e9a3cc 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -373,6 +373,19 @@ public class ExpressionUtils {
return children.stream().allMatch(c -> c instanceof Literal);
}
+ /**
+ * return true if all children are literal but not null literal.
+ */
+ public static boolean isAllNonNullLiteral(List<Expression> children) {
+ for (Expression child : children) {
+ if ((!(child instanceof Literal)) || (child instanceof
NullLiteral)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /** matchNumericType */
public static boolean matchNumericType(List<Expression> children) {
return children.stream().allMatch(c ->
c.getDataType().isNumericType());
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
index 74843a26b21..a668cc79925 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
@@ -30,7 +30,9 @@ import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
-import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DateV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
@@ -61,34 +63,39 @@ public class SimplifyRangeTest {
public void testSimplify() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE));
assertRewrite("TA", "TA");
- assertRewrite("TA > 3 or TA > null", "TA > 3");
- assertRewrite("TA > 3 or TA < null", "TA > 3");
- assertRewrite("TA > 3 or TA = null", "TA > 3");
- assertRewrite("TA > 3 or TA <> null", "TA > 3 or TA <> null");
+ assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL");
+ assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL");
+ assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL");
+ assertRewrite("TA > 3 or TA <> null", "TA > 3 or null");
assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null");
- assertRewrite("TA > 3 and TA > null", "false");
- assertRewrite("TA > 3 and TA < null", "false");
- assertRewrite("TA > 3 and TA = null", "false");
- assertRewrite("TA > 3 and TA <> null", "TA > 3 and TA <> null");
+ assertRewriteNotNull("TA > 3 and TA > null", "TA > 3 and NULL");
+ assertRewriteNotNull("TA > 3 and TA < null", "TA > 3 and NULL");
+ assertRewriteNotNull("TA > 3 and TA = null", "TA > 3 and NULL");
+ assertRewrite("TA > 3 and TA > null", "TA > 3 and null");
+ assertRewrite("TA > 3 and TA < null", "TA > 3 and null");
+ assertRewrite("TA > 3 and TA = null", "TA > 3 and null");
+ assertRewrite("TA > 3 and TA <> null", "TA > 3 and null");
assertRewrite("TA > 3 and TA <=> null", "TA > 3 and TA <=> null");
assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >=
1 and TA <=3 ) or (TA > 5 and TA < 7)");
- assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE");
- assertRewrite("TA > 3 and TA < 1", "FALSE");
+ assertRewriteNotNull("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)",
"FALSE");
+ assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "TA is
null and null");
+ assertRewriteNotNull("TA > 3 and TA < 1", "FALSE");
+ assertRewrite("TA > 3 and TA < 1", "TA is null and null");
assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3");
- assertRewrite("TA = 1 and TA > 10", "FALSE");
+ assertRewriteNotNull("TA = 1 and TA > 10", "FALSE");
+ assertRewrite("TA = 1 and TA > 10", "TA is null and null");
assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1");
assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
- assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA IS NOT NULL");
+ assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null");
assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE");
assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1");
- assertRewrite("TA > 1 or TA < 10", "TA IS NOT NULL");
+ assertRewrite("TA > 1 or TA < 10", "TA is not null or null");
assertRewriteNotNull("TA > 1 or TA < 10", "TRUE");
assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10");
assertRewrite("TA > 5 and TA > 10", "TA > 10");
- assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
- assertRewrite("TA > 5 + 1 and TA > 10", "TA > 5 + 1 and TA > 10");
+ assertRewrite("TA > 5 + 1 and TA > 10", "cast(TA as smallint) > 6 and
TA > 10");
assertRewrite("(TA > 1 and TA > 10) or TA > 20", "TA > 10");
assertRewrite("(TA > 1 or TA > 10) and TA > 20", "TA > 20");
assertRewrite("(TA + TB > 1 or TA + TB > 10) and TA + TB > 20", "TA +
TB > 20");
@@ -98,25 +105,32 @@ public class SimplifyRangeTest {
assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB
> 20)", "TB > 30 and TA > 40");
assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA >
10 and TB > 10 or TB > 20");
assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB >
20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
- assertRewrite("TA in (1,2,3) and TA > 10", "FALSE");
+ assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE");
+ assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null");
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
assertRewrite("TA in (1,2,3) and TA > 1", "((TA = 2) OR (TA = 3))");
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
assertRewrite("TA in (1)", "TA in (1)");
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
- assertRewrite("TA in (1,2,3) and TA < 1", "FALSE");
+ assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE");
+ assertRewrite("TA in (1,2,3) and TA < 1", "TA is null and null");
assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1");
assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)");
assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)");
- assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
+ assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "TA is null and
null");
+ assertRewriteNotNull("TA in (1,2,3) and TA in (4,5,6)", "FALSE");
assertRewrite("TA in (1,2,3) and TA in (3,4,5)", "TA = 3");
assertRewrite("TA + TB in (1,2,3) and TA + TB in (3,4,5)", "TA + TB =
3");
assertRewrite("TA in (1,2,3) and DA > 1.5", "TA in (1,2,3) and DA >
1.5");
- assertRewrite("TA = 1 and TA = 3", "FALSE");
- assertRewrite("TA in (1) and TA in (3)", "FALSE");
+ assertRewriteNotNull("TA = 1 and TA = 3", "FALSE");
+ assertRewrite("TA = 1 and TA = 3", "TA is null and null");
+ assertRewriteNotNull("TA in (1) and TA in (3)", "FALSE");
+ assertRewrite("TA in (1) and TA in (3)", "TA is null and null");
assertRewrite("TA in (1) and TA in (1)", "TA = 1");
- assertRewrite("(TA > 3 and TA < 1) and TB < 5", "FALSE");
- assertRewrite("(TA > 3 and TA < 1) or TB < 5", "TB < 5");
+ assertRewriteNotNull("(TA > 3 and TA < 1) and TB < 5", "FALSE");
+ assertRewrite("(TA > 3 and TA < 1) and TB < 5", "TA is null and null
and TB < 5");
+ assertRewrite("TA > 3 and TB < 5 and TA < 1", "TA is null and null and
TB < 5");
+ assertRewrite("(TA > 3 and TA < 1) or TB < 5", "(TA is null and null)
or TB < 5");
assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA
= 1 AND SC ='1') OR SC = '1212') AND IA =1");
}
@@ -133,7 +147,9 @@ public class SimplifyRangeTest {
private void assertRewriteNotNull(String expression, String expected) {
Map<String, Slot> mem = Maps.newHashMap();
Expression needRewriteExpression =
replaceNotNullUnboundSlot(PARSER.parseExpression(expression), mem);
+ needRewriteExpression = typeCoercion(needRewriteExpression);
Expression expectedExpression =
replaceNotNullUnboundSlot(PARSER.parseExpression(expected), mem);
+ expectedExpression = typeCoercion(expectedExpression);
Expression rewrittenExpression =
executor.rewrite(needRewriteExpression, context);
Assertions.assertEquals(expectedExpression, rewrittenExpression);
}
@@ -185,11 +201,15 @@ public class SimplifyRangeTest {
case 'I':
return IntegerType.INSTANCE;
case 'D':
- return DoubleType.INSTANCE;
+ return DecimalV3Type.createDecimalV3Type(2, 1);
case 'S':
return StringType.INSTANCE;
case 'B':
return BooleanType.INSTANCE;
+ case 'C':
+ return DateTimeV2Type.SYSTEM_DEFAULT;
+ case 'A':
+ return DateV2Type.INSTANCE;
default:
return BigIntType.INSTANCE;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]