This is an automated email from the ASF dual-hosted git repository.
starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new fe7938dc29d [feat](nereids) simplify range support not equal / is null / is not null (#57537)
fe7938dc29d is described below
commit fe7938dc29d020cdda77ff0a110c5d08e058226a
Author: yujun <[email protected]>
AuthorDate: Thu Nov 20 16:07:35 2025 +0800
[feat](nereids) simplify range support not equal / is null / is not null (#57537)
---
.../nereids/rules/expression/rules/AddMinMax.java | 70 +-
.../rules/expression/rules/ConditionRewrite.java | 10 +-
.../rules/expression/rules/RangeInference.java | 1446 ++++++++++++++++----
.../rules/expression/rules/SimplifyRange.java | 81 +-
.../nereids/rules/rewrite/PullUpPredicates.java | 3 +-
.../doris/nereids/trees/expressions/Not.java | 6 +-
.../apache/doris/nereids/util/ExpressionUtils.java | 30 -
.../rules/expression/ExpressionRewriteTest.java | 22 +-
.../rules/expression/SimplifyRangeTest.java | 252 +++-
.../adjust_nullable/test_subquery_nullable.out | 2 +-
.../push_down_filter_other_condition.out | 2 +-
.../predicate_infer/infer_predicate.out | 8 +-
.../adjust_virtual_slot_nullable.out | 8 +-
.../mv_p0/where/k123_nereids/k123_nereids.groovy | 2 +-
.../adjust_virtual_slot_nullable.groovy | 2 +-
15 files changed, 1525 insertions(+), 419 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
index e3ea1a43b3a..f90ea3caf1a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/AddMinMax.java
@@ -21,11 +21,17 @@ import
org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.rules.expression.rules.AddMinMax.MinMaxValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.DiscreteValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDescVisitor;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -39,14 +45,15 @@ import
org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
-import org.apache.commons.lang3.NotImplementedException;
import java.util.List;
import java.util.Map;
@@ -63,13 +70,14 @@ import java.util.stream.Collectors;
* a between 10 and 20 and b between 10 and 20 or a between 100 and 200 and b
between 100 and 200
* => (a <= 20 and b <= 20 or a >= 100 and b >= 100) and a >= 10 and a <=
200 and b >= 10 and b <= 200
*/
-public class AddMinMax implements ExpressionPatternRuleFactory {
+public class AddMinMax implements ExpressionPatternRuleFactory,
ValueDescVisitor<Map<Expression, MinMaxValue>, Void> {
public static final AddMinMax INSTANCE = new AddMinMax();
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(CompoundPredicate.class)
+ .whenCtx(ctx ->
PlanUtils.isConditionExpressionPlan(ctx.rewriteContext.plan.orElse(null)))
.thenApply(ctx -> rewrite(ctx.expr,
ctx.rewriteContext))
.toRule(ExpressionRuleType.ADD_MIN_MAX)
);
@@ -78,7 +86,7 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
/** rewrite */
public Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext
context) {
ValueDesc valueDesc = (new RangeInference()).getValue(expr, context);
- Map<Expression, MinMaxValue> exprMinMaxValues =
getExprMinMaxValues(valueDesc);
+ Map<Expression, MinMaxValue> exprMinMaxValues = valueDesc.accept(this,
null);
removeUnnecessaryMinMaxValues(expr, exprMinMaxValues);
if (!exprMinMaxValues.isEmpty()) {
return addExprMinMaxValues(expr, context, exprMinMaxValues);
@@ -92,7 +100,8 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
MATCH_NONE,
}
- private static class MinMaxValue {
+ /** record each expression's min and max value */
+ public static class MinMaxValue {
// min max range, if range = null means empty
Range<ComparableLiteral> range;
@@ -280,21 +289,8 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
return (expr instanceof SlotReference) && ((SlotReference)
expr).getOriginalColumn().isPresent();
}
- private Map<Expression, MinMaxValue> getExprMinMaxValues(ValueDesc value) {
- if (value instanceof EmptyValue) {
- return getExprMinMaxValues((EmptyValue) value);
- } else if (value instanceof DiscreteValue) {
- return getExprMinMaxValues((DiscreteValue) value);
- } else if (value instanceof RangeValue) {
- return getExprMinMaxValues((RangeValue) value);
- } else if (value instanceof UnknownValue) {
- return getExprMinMaxValues((UnknownValue) value);
- } else {
- throw new NotImplementedException("not implements");
- }
- }
-
- private Map<Expression, MinMaxValue> getExprMinMaxValues(EmptyValue value)
{
+ @Override
+ public Map<Expression, MinMaxValue> visitEmptyValue(EmptyValue value, Void
context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
@@ -303,7 +299,8 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
return exprMinMaxValues;
}
- private Map<Expression, MinMaxValue> getExprMinMaxValues(DiscreteValue
value) {
+ @Override
+ public Map<Expression, MinMaxValue> visitDiscreteValue(DiscreteValue
value, Void context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
@@ -312,7 +309,23 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
return exprMinMaxValues;
}
- private Map<Expression, MinMaxValue> getExprMinMaxValues(RangeValue value)
{
+ @Override
+ public Map<Expression, MinMaxValue> visitNotDiscreteValue(NotDiscreteValue
value, Void context) {
+ return ImmutableMap.of();
+ }
+
+ @Override
+ public Map<Expression, MinMaxValue> visitIsNullValue(IsNullValue value,
Void context) {
+ return ImmutableMap.of();
+ }
+
+ @Override
+ public Map<Expression, MinMaxValue> visitIsNotNullValue(IsNotNullValue
value, Void context) {
+ return ImmutableMap.of();
+ }
+
+ @Override
+ public Map<Expression, MinMaxValue> visitRangeValue(RangeValue value, Void
context) {
Expression reference = value.getReference();
Map<Expression, MinMaxValue> exprMinMaxValues = Maps.newHashMap();
if (isExprNeedAddMinMax(reference)) {
@@ -321,16 +334,14 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
return exprMinMaxValues;
}
- private Map<Expression, MinMaxValue> getExprMinMaxValues(UnknownValue
valueDesc) {
+ @Override
+ public Map<Expression, MinMaxValue> visitCompoundValue(CompoundValue
valueDesc, Void context) {
List<ValueDesc> sourceValues = valueDesc.getSourceValues();
- if (sourceValues.isEmpty()) {
- return Maps.newHashMap();
- }
- Map<Expression, MinMaxValue> result =
Maps.newHashMap(getExprMinMaxValues(sourceValues.get(0)));
+ Map<Expression, MinMaxValue> result =
Maps.newHashMap(sourceValues.get(0).accept(this, context));
int nextExprOrderIndex = result.values().stream().mapToInt(k ->
k.exprOrderIndex).max().orElse(0);
for (int i = 1; i < sourceValues.size(); i++) {
// process in sourceValues[i]
- Map<Expression, MinMaxValue> minMaxValues =
getExprMinMaxValues(sourceValues.get(i));
+ Map<Expression, MinMaxValue> minMaxValues =
sourceValues.get(i).accept(this, context);
// merge values of sourceValues[i] into result.
// also keep the value's relative order in sourceValues[i].
// for example, if a and b in sourceValues[i], but not in result,
then during merging,
@@ -398,4 +409,9 @@ public class AddMinMax implements
ExpressionPatternRuleFactory {
}
return result;
}
+
+ @Override
+ public Map<Expression, MinMaxValue> visitUnknownValue(UnknownValue
valueDesc, Void context) {
+ return ImmutableMap.of();
+ }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java
index 6e429624dfe..c28a395a97b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConditionRewrite.java
@@ -30,9 +30,8 @@ import
org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
@@ -67,18 +66,17 @@ public abstract class ConditionRewrite extends
DefaultExpressionRewriter<Boolean
return expression.accept(this, rootIsCondition(context));
}
+ // for the expression root, only filter and join expression can treat as
condition
protected boolean rootIsCondition(ExpressionRewriteContext context) {
Plan plan = context.plan.orElse(null);
- if (plan instanceof LogicalFilter || plan instanceof LogicalHaving) {
- return true;
- } else if (plan instanceof LogicalJoin) {
+ if (plan instanceof LogicalJoin) {
// null aware join can not treat null as false
ExpressionSource source = context.source.orElse(null);
return ((LogicalJoin<?, ?>) plan).getJoinType() !=
JoinType.NULL_AWARE_LEFT_ANTI_JOIN
&& (source == ExpressionSource.JOIN_HASH_CONDITION
|| source ==
ExpressionSource.JOIN_OTHER_CONDITION);
} else {
- return false;
+ return PlanUtils.isConditionExpressionPlan(plan);
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
index 3a87149fa14..483c12d34a1 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
@@ -19,7 +19,6 @@ package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
-import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
@@ -28,31 +27,32 @@ import
org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
-import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Predicate;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
-import com.google.common.collect.Multimap;
-import com.google.common.collect.Multimaps;
+import com.google.common.collect.Maps;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeSet;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
import java.util.Map.Entry;
+import java.util.Optional;
import java.util.Set;
-import java.util.function.BinaryOperator;
-import java.util.stream.Collectors;
/**
* collect range of expression
@@ -71,255 +71,646 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
return new UnknownValue(context, expr);
}
- private ValueDesc buildRange(ExpressionRewriteContext context,
ComparisonPredicate predicate) {
- Expression right = predicate.child(1);
- if (right.isNullLiteral()) {
- return new UnknownValue(context, predicate);
- }
- // only handle `NumericType` and `DateLikeType` and `StringLikeType`
- DataType rightDataType = right.getDataType();
- if (right instanceof ComparableLiteral
- && (rightDataType.isNumericType() ||
rightDataType.isDateLikeType()
- || rightDataType.isStringLikeType())) {
- return ValueDesc.range(context, predicate);
- }
- return new UnknownValue(context, predicate);
- }
-
@Override
public ValueDesc visitGreaterThan(GreaterThan greaterThan,
ExpressionRewriteContext context) {
- return buildRange(context, greaterThan);
+ Optional<ComparableLiteral> rightLiteral =
tryGetComparableLiteral(greaterThan.right());
+ if (rightLiteral.isPresent()) {
+ return new RangeValue(context, greaterThan.left(),
Range.greaterThan(rightLiteral.get()));
+ } else {
+ return new UnknownValue(context, greaterThan);
+ }
}
@Override
public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual,
ExpressionRewriteContext context) {
- return buildRange(context, greaterThanEqual);
+ Optional<ComparableLiteral> rightLiteral =
tryGetComparableLiteral(greaterThanEqual.right());
+ if (rightLiteral.isPresent()) {
+ return new RangeValue(context, greaterThanEqual.left(),
Range.atLeast(rightLiteral.get()));
+ } else {
+ return new UnknownValue(context, greaterThanEqual);
+ }
}
@Override
public ValueDesc visitLessThan(LessThan lessThan, ExpressionRewriteContext
context) {
- return buildRange(context, lessThan);
+ Optional<ComparableLiteral> rightLiteral =
tryGetComparableLiteral(lessThan.right());
+ if (rightLiteral.isPresent()) {
+ return new RangeValue(context, lessThan.left(),
Range.lessThan(rightLiteral.get()));
+ } else {
+ return new UnknownValue(context, lessThan);
+ }
}
@Override
public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual,
ExpressionRewriteContext context) {
- return buildRange(context, lessThanEqual);
+ Optional<ComparableLiteral> rightLiteral =
tryGetComparableLiteral(lessThanEqual.right());
+ if (rightLiteral.isPresent()) {
+ return new RangeValue(context, lessThanEqual.left(),
Range.atMost(rightLiteral.get()));
+ } else {
+ return new UnknownValue(context, lessThanEqual);
+ }
}
@Override
public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext
context) {
- return buildRange(context, equalTo);
+ Optional<ComparableLiteral> rightLiteral =
tryGetComparableLiteral(equalTo.right());
+ if (rightLiteral.isPresent()) {
+ return new DiscreteValue(context, equalTo.left(),
ImmutableSet.of(rightLiteral.get()));
+ } else {
+ return new UnknownValue(context, equalTo);
+ }
}
@Override
public ValueDesc visitInPredicate(InPredicate inPredicate,
ExpressionRewriteContext context) {
// only handle `NumericType` and `DateLikeType`
if (inPredicate.getOptions().size() <=
InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
- &&
ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions())
- && (ExpressionUtils.matchNumericType(inPredicate.getOptions())
- ||
ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) {
- return ValueDesc.discrete(context, inPredicate);
+ &&
ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions())) {
+ Set<ComparableLiteral> values =
Sets.newLinkedHashSetWithExpectedSize(inPredicate.getOptions().size());
+ boolean succ = true;
+ for (Expression value : inPredicate.getOptions()) {
+ Optional<ComparableLiteral> literal =
tryGetComparableLiteral(value);
+ if (!literal.isPresent()) {
+ succ = false;
+ break;
+ }
+ values.add(literal.get());
+ }
+ if (succ) {
+ return new DiscreteValue(context,
inPredicate.getCompareExpr(), values);
+ }
}
+
return new UnknownValue(context, inPredicate);
}
+ private Optional<ComparableLiteral> tryGetComparableLiteral(Expression
right) {
+ // only handle `NumericType` and `DateLikeType` and `StringLikeType`
+ DataType rightDataType = right.getDataType();
+ if (right instanceof ComparableLiteral
+ && !right.isNullLiteral()
+ && (rightDataType.isNumericType() ||
rightDataType.isDateLikeType()
+ || rightDataType.isStringLikeType())) {
+ return Optional.of((ComparableLiteral) right);
+ } else {
+ return Optional.empty();
+ }
+ }
+
+ @Override
+ public ValueDesc visitNot(Not not, ExpressionRewriteContext context) {
+ ValueDesc childValue = not.child().accept(this, context);
+ if (childValue instanceof DiscreteValue) {
+ return new NotDiscreteValue(context, childValue.getReference(),
((DiscreteValue) childValue).values);
+ } else if (childValue instanceof IsNullValue) {
+ return new IsNotNullValue(context, childValue.getReference(), not);
+ } else {
+ return new UnknownValue(context, not);
+ }
+ }
+
+ @Override
+ public ValueDesc visitIsNull(IsNull isNull, ExpressionRewriteContext
context) {
+ return new IsNullValue(context, isNull.child());
+ }
+
@Override
public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
- return simplify(context, ExpressionUtils.extractConjunction(and),
- ValueDesc::intersect, true);
+ return processCompound(context,
ExpressionUtils.extractConjunction(and), true);
}
@Override
public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
- return simplify(context, ExpressionUtils.extractDisjunction(or),
- ValueDesc::union, false);
+ return processCompound(context,
ExpressionUtils.extractDisjunction(or), false);
}
- private ValueDesc simplify(ExpressionRewriteContext context,
List<Expression> predicates,
- BinaryOperator<ValueDesc> op, boolean isAnd) {
-
- boolean convertIsNullToEmptyValue = isAnd &&
predicates.stream().anyMatch(expr -> expr instanceof NullLiteral);
- Multimap<Expression, ValueDesc> groupByReference
- = Multimaps.newListMultimap(new LinkedHashMap<>(),
ArrayList::new);
+ private ValueDesc processCompound(ExpressionRewriteContext context,
List<Expression> predicates, boolean isAnd) {
+ boolean hasNullExpression = false;
+ boolean hasIsNullExpression = false;
+ boolean hasNotIsNullExpression = false;
+ Predicate<Expression> isNotNull = expression -> expression instanceof
Not
+ && expression.child(0) instanceof IsNull
+ && !((Not) expression).isGeneratedIsNotNull();
for (Expression predicate : predicates) {
- // EmptyValue(a) = IsNull(a) and null, it doesn't equals to
IsNull(a).
- // Only the and expression contains at least a null literal in its
conjunctions,
- // then EmptyValue(a) can equivalent to IsNull(a).
- // so for expression and(IsNull(a), IsNull(b), ..., null), a, b
can convert to EmptyValue.
- // What's more, if a is not nullable, then EmptyValue(a) always
equals to IsNull(a),
- // but we don't consider this case here, we should fold IsNull(a)
to FALSE using other rule.
+ hasNullExpression = hasNullExpression || predicate.isNullLiteral();
+ hasIsNullExpression = hasIsNullExpression || predicate instanceof
IsNull;
+ hasNotIsNullExpression = hasNotIsNullExpression ||
isNotNull.test(predicate);
+ }
+ boolean convertIsNullToEmptyValue = isAnd && hasNullExpression &&
hasIsNullExpression;
+ boolean convertNotIsNullToRangeAll = !isAnd && hasNullExpression &&
hasNotIsNullExpression;
+ Map<Expression, ValueDescCollector> groupByReference =
Maps.newLinkedHashMap();
+ for (Expression predicate : predicates) {
+ // given an expression A, no matter A is nullable or not,
+ // 'A is null and null' can represent as EmptyValue(A),
+ // 'A is not null or null' can represent as RangeAll(A).
ValueDesc valueDesc = null;
- if (convertIsNullToEmptyValue && predicate instanceof IsNull) {
+ if (predicate instanceof IsNull && convertIsNullToEmptyValue) {
valueDesc = new EmptyValue(context, ((IsNull)
predicate).child());
+ } else if (isNotNull.test(predicate) &&
convertNotIsNullToRangeAll) {
+ valueDesc = new RangeValue(context,
predicate.child(0).child(0), Range.all());
+ } else if (predicate.isNullLiteral() && (convertIsNullToEmptyValue
|| convertNotIsNullToRangeAll)) {
+ continue;
} else {
valueDesc = predicate.accept(this, context);
}
- List<ValueDesc> valueDescs = (List<ValueDesc>)
groupByReference.get(valueDesc.reference);
- valueDescs.add(valueDesc);
+ Expression reference = valueDesc.reference;
+ groupByReference.computeIfAbsent(reference, key -> new
ValueDescCollector()).add(valueDesc);
}
List<ValueDesc> valuePerRefs = Lists.newArrayList();
- for (Entry<Expression, Collection<ValueDesc>> referenceValues :
groupByReference.asMap().entrySet()) {
+ for (Entry<Expression, ValueDescCollector> referenceValues :
groupByReference.entrySet()) {
Expression reference = referenceValues.getKey();
- List<ValueDesc> valuePerReference = (List)
referenceValues.getValue();
- if (!isAnd) {
- valuePerReference = ValueDesc.unionDiscreteAndRange(context,
reference, valuePerReference);
- }
-
- // merge per reference
- ValueDesc simplifiedValue = valuePerReference.get(0);
- for (int i = 1; i < valuePerReference.size(); i++) {
- simplifiedValue = op.apply(simplifiedValue,
valuePerReference.get(i));
+ ValueDescCollector collector = referenceValues.getValue();
+ ValueDesc mergedValue;
+ if (isAnd) {
+ mergedValue = intersect(context, reference, collector);
+ } else {
+ mergedValue = union(context, reference, collector);
}
-
- valuePerRefs.add(simplifiedValue);
+ valuePerRefs.add(mergedValue);
}
if (valuePerRefs.size() == 1) {
return valuePerRefs.get(0);
}
- // use UnknownValue to wrap different references
- return new UnknownValue(context, valuePerRefs, isAnd);
+ Expression reference =
SimplifyRange.INSTANCE.getCompoundExpression(context, valuePerRefs, isAnd);
+ return new CompoundValue(context, reference, valuePerRefs, isAnd);
}
- /**
- * value desc
- */
- public abstract static class ValueDesc {
- ExpressionRewriteContext context;
- Expression reference;
+ private ValueDesc intersect(ExpressionRewriteContext context, Expression
reference, ValueDescCollector collector) {
+ List<ValueDesc> resultValues = Lists.newArrayList();
- public ValueDesc(ExpressionRewriteContext context, Expression
reference) {
- this.context = context;
- this.reference = reference;
+ // merge all the range values
+ Range<ComparableLiteral> mergeRangeValue = null;
+ if (!collector.hasEmptyValue && !collector.rangeValues.isEmpty()) {
+ RangeValue mergeRangeValueDesc = null;
+ for (RangeValue rangeValue : collector.rangeValues) {
+ if (mergeRangeValueDesc == null) {
+ mergeRangeValueDesc = rangeValue;
+ } else {
+ ValueDesc combineValue =
mergeRangeValueDesc.intersect(rangeValue);
+ if (combineValue instanceof RangeValue) {
+ mergeRangeValueDesc = (RangeValue) combineValue;
+ } else {
+ collector.add(combineValue);
+ mergeRangeValueDesc = null;
+ // no need to process the lefts.
+ if (combineValue instanceof EmptyValue) {
+ break;
+ }
+ }
+ }
+ }
+ if (!collector.hasEmptyValue && mergeRangeValueDesc != null) {
+ mergeRangeValue = mergeRangeValueDesc.range;
+ }
}
- public Expression getReference() {
- return reference;
+ // merge all the discrete values
+ Set<ComparableLiteral> mergeDiscreteValues = null;
+ if (!collector.hasEmptyValue && !collector.discreteValues.isEmpty()) {
+ mergeDiscreteValues =
Sets.newLinkedHashSet(collector.discreteValues.get(0).values);
+ for (int i = 1; i < collector.discreteValues.size(); i++) {
+
mergeDiscreteValues.retainAll(collector.discreteValues.get(i).values);
+ }
+ if (mergeDiscreteValues.isEmpty()) {
+ collector.add(new EmptyValue(context, reference));
+ mergeDiscreteValues = null;
+ }
}
- public ExpressionRewriteContext getExpressionRewriteContext() {
- return context;
+ // merge all the not discrete values
+ Set<ComparableLiteral> mergeNotDiscreteValues =
Sets.newLinkedHashSet();
+ if (!collector.hasEmptyValue &&
!collector.notDiscreteValues.isEmpty()) {
+ for (NotDiscreteValue notDiscreteValue :
collector.notDiscreteValues) {
+ mergeNotDiscreteValues.addAll(notDiscreteValue.values);
+ }
+ if (mergeRangeValue != null) {
+ Range<ComparableLiteral> finalValue = mergeRangeValue;
+ mergeNotDiscreteValues.removeIf(value ->
!finalValue.contains(value));
+ }
+ if (mergeDiscreteValues != null) {
+ Set<ComparableLiteral> finalValues = mergeDiscreteValues;
+ mergeNotDiscreteValues.removeIf(value ->
!finalValues.contains(value));
+ mergeDiscreteValues.removeIf(mergeNotDiscreteValues::contains);
+ if (mergeDiscreteValues.isEmpty()) {
+ collector.add(new EmptyValue(context, reference));
+ mergeDiscreteValues = null;
+ }
+ }
+ }
+ if (!collector.hasEmptyValue) {
+ // merge range + discrete values
+ if (mergeRangeValue != null && mergeDiscreteValues != null) {
+ ValueDesc newMergeValue = new RangeValue(context, reference,
mergeRangeValue)
+ .intersect(new DiscreteValue(context, reference,
mergeDiscreteValues));
+ resultValues.add(newMergeValue);
+ } else if (mergeRangeValue != null) {
+ resultValues.add(new RangeValue(context, reference,
mergeRangeValue));
+ } else if (mergeDiscreteValues != null) {
+ resultValues.add(new DiscreteValue(context, reference,
mergeDiscreteValues));
+ }
+ if (!collector.hasEmptyValue && !mergeNotDiscreteValues.isEmpty())
{
+ resultValues.add(new NotDiscreteValue(context, reference,
mergeNotDiscreteValues));
+ }
}
- public abstract ValueDesc union(ValueDesc other);
-
- /** or */
- public static ValueDesc union(ExpressionRewriteContext context,
- RangeValue range, DiscreteValue discrete, boolean
reverseOrder) {
- if (discrete.values.stream().allMatch(x -> range.range.test(x))) {
- return range;
- }
- List<ValueDesc> sourceValues = reverseOrder
- ? ImmutableList.of(discrete, range)
- : ImmutableList.of(range, discrete);
- return new UnknownValue(context, sourceValues, false);
- }
-
- /** merge discrete and ranges only, no merge other value desc */
- public static List<ValueDesc>
unionDiscreteAndRange(ExpressionRewriteContext context,
- Expression reference, List<ValueDesc> valueDescs) {
- // Since in-predicate's options is a list, the discrete values
need to kept options' order.
- // If not keep options' order, the result in-predicate's option
list will not equals to
- // the input in-predicate, later nereids will need to simplify the
new in-predicate,
- // then cause dead loop.
- Set<ComparableLiteral> discreteValues = Sets.newLinkedHashSet();
- for (ValueDesc valueDesc : valueDescs) {
- if (valueDesc instanceof DiscreteValue) {
- discreteValues.addAll(((DiscreteValue)
valueDesc).getValues());
- }
- }
-
- // for 'a > 8 or a = 8', then range (8, +00) can convert to [8,
+00)
- RangeSet<ComparableLiteral> rangeSet = TreeRangeSet.create();
- for (ValueDesc valueDesc : valueDescs) {
- if (valueDesc instanceof RangeValue) {
- Range<ComparableLiteral> range = ((RangeValue)
valueDesc).range;
- rangeSet.add(range);
- if (range.hasLowerBound()
- && range.lowerBoundType() == BoundType.OPEN
- && discreteValues.contains(range.lowerEndpoint()))
{
- rangeSet.add(Range.singleton(range.lowerEndpoint()));
- }
- if (range.hasUpperBound()
- && range.upperBoundType() == BoundType.OPEN
- && discreteValues.contains(range.upperEndpoint()))
{
- rangeSet.add(Range.singleton(range.upperEndpoint()));
- }
- }
+ // process empty value
+ if (collector.hasEmptyValue) {
+ if (!reference.nullable()) {
+ return new UnknownValue(context, BooleanLiteral.FALSE);
}
+ resultValues.add(new EmptyValue(context, reference));
+ }
+
+ // process is null and is not null
+ // for non-nullable a: EmptyValue(a) = a is null and null
+ boolean hasIsNullValue = collector.hasIsNullValue ||
collector.hasEmptyValue && reference.nullable();
+ boolean hasIsNotNullValue = collector.isNotNullValueOpt.isPresent()
+ || collector.isGenerateNotNullValueOpt.isPresent()
+ || mergeRangeValue != null && !mergeRangeValue.hasLowerBound()
&& !mergeRangeValue.hasUpperBound();
+ if (hasIsNullValue && hasIsNotNullValue) {
+ return new UnknownValue(context, BooleanLiteral.FALSE);
+ }
+ // nullable's EmptyValue have contains IsNull, no need to add
+ if (!collector.hasEmptyValue && collector.hasIsNullValue) {
+ resultValues.add(new IsNullValue(context, reference));
+ }
+ collector.isNotNullValueOpt.ifPresent(resultValues::add);
+ collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add);
+ Optional<ValueDesc> shortCutResult = mergeCompoundValues(context,
reference, resultValues, collector, true);
+ if (shortCutResult.isPresent()) {
+ return shortCutResult.get();
+ }
+ // unknownValue should be empty
+ resultValues.addAll(collector.unknownValues);
- if (!rangeSet.isEmpty()) {
- discreteValues.removeIf(x -> rangeSet.contains(x));
+ Preconditions.checkArgument(!resultValues.isEmpty());
+ if (resultValues.size() == 1) {
+ return resultValues.get(0);
+ } else {
+ return new CompoundValue(context, reference, resultValues, true);
+ }
+ }
+
+ private ValueDesc union(ExpressionRewriteContext context, Expression
reference, ValueDescCollector collector) {
+ List<ValueDesc> resultValues =
Lists.newArrayListWithExpectedSize(collector.size() + 3);
+ // Since in-predicate's options is a list, the discrete values need to
kept options' order.
+ // If not keep options' order, the result in-predicate's option list
will not equals to
+ // the input in-predicate, later nereids will need to simplify the new
in-predicate,
+ // then cause dead loop.
+ Set<ComparableLiteral> discreteValues = Sets.newLinkedHashSet();
+ for (DiscreteValue discreteValue : collector.discreteValues) {
+ discreteValues.addAll(discreteValue.values);
+ }
+
+ // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00)
+ RangeSet<ComparableLiteral> rangeSet = TreeRangeSet.create();
+ for (RangeValue rangeValue : collector.rangeValues) {
+ Range<ComparableLiteral> range = rangeValue.range;
+ rangeSet.add(range);
+ if (range.hasLowerBound()
+ && range.lowerBoundType() == BoundType.OPEN
+ && discreteValues.contains(range.lowerEndpoint())) {
+ rangeSet.add(Range.singleton(range.lowerEndpoint()));
+ }
+ if (range.hasUpperBound()
+ && range.upperBoundType() == BoundType.OPEN
+ && discreteValues.contains(range.upperEndpoint())) {
+ rangeSet.add(Range.singleton(range.upperEndpoint()));
}
+ }
+
+ if (!rangeSet.isEmpty()) {
+ discreteValues.removeIf(rangeSet::contains);
+ }
- List<ValueDesc> result =
Lists.newArrayListWithExpectedSize(valueDescs.size());
+ Set<ComparableLiteral> mergeNotDiscreteValues =
Sets.newLinkedHashSet();
+ boolean hasRangeAll = false;
+ if (!collector.notDiscreteValues.isEmpty()) {
+
mergeNotDiscreteValues.addAll(collector.notDiscreteValues.get(0).values);
+ // a not in (1, 2) or a not in (1, 2, 3) => a not in (1, 2)
+ for (int i = 1; i < collector.notDiscreteValues.size(); i++) {
+
mergeNotDiscreteValues.retainAll(collector.notDiscreteValues.get(i).values);
+ }
+ // a not in (1, 2, 3) or a in (1, 2, 4) => a not in (3)
+ mergeNotDiscreteValues.removeIf(
+ value -> discreteValues.contains(value) ||
rangeSet.contains(value));
+ discreteValues.removeIf(mergeNotDiscreteValues::contains);
+ if (mergeNotDiscreteValues.isEmpty()) {
+ resultValues.add(new RangeValue(context, reference,
Range.all()));
+ } else {
+ resultValues.add(new NotDiscreteValue(context, reference,
mergeNotDiscreteValues));
+ }
+ } else {
if (!discreteValues.isEmpty()) {
- result.add(new DiscreteValue(context, reference,
discreteValues));
+ resultValues.add(new DiscreteValue(context, reference,
discreteValues));
}
for (Range<ComparableLiteral> range : rangeSet.asRanges()) {
- result.add(new RangeValue(context, reference, range));
+ hasRangeAll = hasRangeAll || !range.hasUpperBound() &&
!range.hasLowerBound();
+ resultValues.add(new RangeValue(context, reference, range));
}
- for (ValueDesc valueDesc : valueDescs) {
- if (!(valueDesc instanceof DiscreteValue) && !(valueDesc
instanceof RangeValue)) {
- result.add(valueDesc);
+ }
+
+ boolean hasIsNullValue = collector.hasIsNullValue ||
collector.hasEmptyValue && !reference.nullable();
+ boolean hasIsNotNullValue = collector.isNotNullValueOpt.isPresent()
+ || collector.isGenerateNotNullValueOpt.isPresent() ||
hasRangeAll;
+ if (hasIsNullValue && hasIsNotNullValue) {
+ return new UnknownValue(context, BooleanLiteral.TRUE);
+ } else if (collector.hasIsNullValue) {
+ resultValues.add(new IsNullValue(context, reference));
+ } else {
+ collector.isNotNullValueOpt.ifPresent(resultValues::add);
+ collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add);
+ }
+ Optional<ValueDesc> shortCutResult = mergeCompoundValues(context,
reference, resultValues, collector, false);
+ if (shortCutResult.isPresent()) {
+ return shortCutResult.get();
+ }
+ if (collector.hasEmptyValue) {
+ // for IsNotNull OR EmptyValue, need keep the EmptyValue
+ boolean ignoreEmptyValue = !resultValues.isEmpty() &&
!reference.nullable();
+ for (ValueDesc valueDesc : resultValues) {
+ if (valueDesc instanceof CompoundValue) {
+ ignoreEmptyValue = ignoreEmptyValue || !((CompoundValue)
valueDesc).hasNoneNullable;
+ } else if (valueDesc.nullable() || valueDesc instanceof
IsNullValue) {
+ ignoreEmptyValue = true;
}
+ if (ignoreEmptyValue) {
+ break;
+ }
+ }
+ if (!ignoreEmptyValue) {
+ resultValues.add(new EmptyValue(context, reference));
}
+ }
+ resultValues.addAll(collector.unknownValues);
+ Preconditions.checkArgument(!resultValues.isEmpty());
+ if (resultValues.size() == 1) {
+ return resultValues.get(0);
+ } else {
+ return new CompoundValue(context, reference, resultValues, false);
+ }
+ }
- return result;
+ private Optional<ValueDesc> mergeCompoundValues(ExpressionRewriteContext
context, Expression reference,
+ List<ValueDesc> resultValues, ValueDescCollector collector,
boolean isAnd) {
+ // for A and (B or C):
+ // if A and B is false/empty, then A and (B or C) = A and C
+ // if B's range is bigger than A, then A and (B or C) = A
+ // for A or (B and C):
+ // if A's range is bigger than B, then A or (B and C) = A
+ // if A or B is true/all, then A or (B and C) = A or C
+ for (CompoundValue compoundValue : collector.compoundValues) {
+ if (isAnd != compoundValue.isAnd
+ && compoundValue.reference.equals(reference)
+ // no process the compose value which reference different
+ &&
compoundValue.sourceValues.get(0).reference.equals(reference)) {
+ ImmutableList.Builder<ValueDesc> newSourceValuesBuilder
+ =
ImmutableList.builderWithExpectedSize(compoundValue.sourceValues.size());
+ boolean skipWholeCompoundValue = false;
+ boolean hasNullableSkipSourceValues = false;
+ for (ValueDesc innerValue : compoundValue.sourceValues) {
+ IntersectType intersectType = IntersectType.OTHERS;
+ UnionType unionType = UnionType.OTHERS;
+ for (ValueDesc outerValue : resultValues) {
+ if (isAnd) {
+ skipWholeCompoundValue = skipWholeCompoundValue ||
innerValue.containsAll(outerValue);
+ IntersectType type =
outerValue.getIntersectType(innerValue);
+ if (type == IntersectType.EMPTY_VALUE
+ && intersectType != IntersectType.FALSE
+ && outerValue.nullable()) {
+ intersectType = type;
+ }
+ if (type == IntersectType.FALSE) {
+ intersectType = type;
+ }
+ } else {
+ skipWholeCompoundValue = skipWholeCompoundValue ||
outerValue.containsAll(innerValue);
+ UnionType type =
outerValue.getUnionType(innerValue);
+ if (type == UnionType.TRUE) {
+ unionType = type;
+ }
+ }
+ }
+ if (skipWholeCompoundValue) {
+ break;
+ }
+ if (isAnd) {
+ if (intersectType == IntersectType.OTHERS) {
+ newSourceValuesBuilder.add(innerValue);
+ } else {
+ hasNullableSkipSourceValues =
hasNullableSkipSourceValues
+ || intersectType ==
IntersectType.EMPTY_VALUE;
+ }
+ } else {
+ if (unionType == UnionType.OTHERS) {
+ newSourceValuesBuilder.add(innerValue);
+ } else {
+ hasNullableSkipSourceValues =
hasNullableSkipSourceValues
+ || unionType == UnionType.RANGE_ALL;
+ }
+ }
+ }
+ if (!skipWholeCompoundValue) {
+ List<ValueDesc> newSourceValues =
newSourceValuesBuilder.build();
+ if (newSourceValues.isEmpty()) {
+ if (isAnd) {
+ if (!hasNullableSkipSourceValues) {
+ return Optional.of(new UnknownValue(context,
BooleanLiteral.FALSE));
+ }
+ resultValues.add(new EmptyValue(context,
reference));
+ } else {
+ if (!hasNullableSkipSourceValues) {
+ return Optional.of(new UnknownValue(context,
BooleanLiteral.TRUE));
+ }
+ resultValues.add(new RangeValue(context,
reference, Range.all()));
+ }
+ } else if (newSourceValues.size() == 1) {
+ resultValues.add(newSourceValues.get(0));
+ } else {
+ resultValues.add(new CompoundValue(context, reference,
newSourceValues, compoundValue.isAnd));
+ }
+ }
+ } else {
+ resultValues.add(compoundValue);
+ }
+ }
+
+ return Optional.empty();
+ }
+
+ /** value desc visitor */
+ public interface ValueDescVisitor<R, C> {
+ R visitEmptyValue(EmptyValue emptyValue, C context);
+
+ R visitRangeValue(RangeValue rangeValue, C context);
+
+ R visitDiscreteValue(DiscreteValue discreteValue, C context);
+
+ R visitNotDiscreteValue(NotDiscreteValue notDiscreteValue, C context);
+
+ R visitIsNullValue(IsNullValue isNullValue, C context);
+
+ R visitIsNotNullValue(IsNotNullValue isNotNullValue, C context);
+
+ R visitCompoundValue(CompoundValue compoundValue, C context);
+
+ R visitUnknownValue(UnknownValue unknownValue, C context);
+ }
+
+ private static class ValueDescCollector implements ValueDescVisitor<Void,
Void> {
+ // generated not is null != not is null
+ Optional<IsNotNullValue> isNotNullValueOpt = Optional.empty();
+ Optional<IsNotNullValue> isGenerateNotNullValueOpt = Optional.empty();
+
+ boolean hasIsNullValue = false;
+ boolean hasEmptyValue = false;
+ List<RangeValue> rangeValues = Lists.newArrayList();
+ List<DiscreteValue> discreteValues = Lists.newArrayList();
+ List<NotDiscreteValue> notDiscreteValues = Lists.newArrayList();
+ List<CompoundValue> compoundValues = Lists.newArrayList();
+ List<UnknownValue> unknownValues = Lists.newArrayList();
+
+ void add(ValueDesc value) {
+ value.accept(this, null);
+ }
+
+ int size() {
+ return rangeValues.size() + discreteValues.size() +
compoundValues.size() + unknownValues.size();
+ }
+
+ @Override
+ public Void visitEmptyValue(EmptyValue emptyValue, Void context) {
+ hasEmptyValue = true;
+ return null;
+ }
+
+ @Override
+ public Void visitRangeValue(RangeValue rangeValue, Void context) {
+ rangeValues.add(rangeValue);
+ return null;
}
- /** intersect */
- public abstract ValueDesc intersect(ValueDesc other);
+ @Override
+ public Void visitDiscreteValue(DiscreteValue discreteValue, Void
context) {
+ discreteValues.add(discreteValue);
+ return null;
+ }
+
+ @Override
+ public Void visitNotDiscreteValue(NotDiscreteValue notDiscreteValue,
Void context) {
+ notDiscreteValues.add(notDiscreteValue);
+ return null;
+ }
- /** intersect */
- public static ValueDesc intersect(ExpressionRewriteContext context,
RangeValue range, DiscreteValue discrete) {
- // Since in-predicate's options is a list, the discrete values
need to kept options' order.
- // If not keep options' order, the result in-predicate's option
list will not equals to
- // the input in-predicate, later nereids will need to simplify the
new in-predicate,
- // then cause dead loop.
- Set<ComparableLiteral> newValues =
discrete.values.stream().filter(x -> range.range.contains(x))
- .collect(Collectors.toCollection(
- () ->
Sets.newLinkedHashSetWithExpectedSize(discrete.values.size())));
- if (newValues.isEmpty()) {
- return new EmptyValue(context, range.reference);
+ @Override
+ public Void visitIsNullValue(IsNullValue isNullValue, Void context) {
+ hasIsNullValue = true;
+ return null;
+ }
+
+ @Override
+ public Void visitIsNotNullValue(IsNotNullValue isNotNullValue, Void
context) {
+ if (isNotNullValue.not.isGeneratedIsNotNull()) {
+ isGenerateNotNullValueOpt = Optional.of(isNotNullValue);
} else {
- return new DiscreteValue(context, range.reference, newValues);
+ isNotNullValueOpt = Optional.of(isNotNullValue);
}
+ return null;
}
- private static ValueDesc range(ExpressionRewriteContext context,
ComparisonPredicate predicate) {
- ComparableLiteral value = (ComparableLiteral) predicate.right();
- if (predicate instanceof EqualTo) {
- return new DiscreteValue(context, predicate.left(),
Sets.newHashSet(value));
- }
- Range<ComparableLiteral> range = null;
- if (predicate instanceof GreaterThanEqual) {
- range = Range.atLeast(value);
- } else if (predicate instanceof GreaterThan) {
- range = Range.greaterThan(value);
- } else if (predicate instanceof LessThanEqual) {
- range = Range.atMost(value);
- } else if (predicate instanceof LessThan) {
- range = Range.lessThan(value);
- }
+ @Override
+ public Void visitCompoundValue(CompoundValue compoundValue, Void
context) {
+ compoundValues.add(compoundValue);
+ return null;
+ }
- return new RangeValue(context, predicate.left(), range);
+ @Override
+ public Void visitUnknownValue(UnknownValue unknownValue, Void context)
{
+ unknownValues.add(unknownValue);
+ return null;
}
+ }
+
+ /** union two value result */
+ public enum UnionType {
+ TRUE, // equals TRUE
+ RANGE_ALL, // trueOrNull(reference)
+ OTHERS, // other case
+ }
+
+ /** intersect two value result */
+ public enum IntersectType {
+ FALSE, // equals FALSE
+ EMPTY_VALUE, // falseOrNull(reference)
+ OTHERS, // other case
+ }
+
+ /**
+ * value desc
+ */
+ public abstract static class ValueDesc {
+ protected final ExpressionRewriteContext context;
+ protected final Expression reference;
- private static ValueDesc discrete(ExpressionRewriteContext context,
InPredicate in) {
- // Since in-predicate's options is a list, the discrete values
need to kept options' order.
- // If not keep options' order, the result in-predicate's option
list will not equals to
- // the input in-predicate, later nereids will need to simplify the
new in-predicate,
- // then cause dead loop.
- // Set<ComparableLiteral> literals = (Set)
Utils.fastToImmutableSet(in.getOptions());
- Set<ComparableLiteral> literals = in.getOptions().stream()
- .map(ComparableLiteral.class::cast)
- .collect(Collectors.toCollection(
- () ->
Sets.newLinkedHashSetWithExpectedSize(in.getOptions().size())));
- return new DiscreteValue(context, in.getCompareExpr(), literals);
+ public ValueDesc(ExpressionRewriteContext context, Expression
reference) {
+ this.context = context;
+ this.reference = reference;
+ }
+
+ public ExpressionRewriteContext getExpressionRewriteContext() {
+ return context;
+ }
+
+ public Expression getReference() {
+ return reference;
+ }
+
+ public <R, C> R accept(ValueDescVisitor<R, C> visitor, C context) {
+ return visit(visitor, context);
}
+
+ protected abstract <R, C> R visit(ValueDescVisitor<R, C> visitor, C
context);
+
+ protected abstract boolean nullable();
+
+ // X containsAll Y, means:
+ // 1) when Y is TRUE, X is TRUE;
+ // 2) when Y is FALSE, X can be any;
+ // 3) when Y is null, X is null;
+ // then will have:
+ // use in 'A and (B or C)', if B containsAll A, then rewrite it to 'A',
+ // use in 'A or (B and C)', if A containsAll B, then rewrite it to 'A'.
+ @VisibleForTesting
+ public final boolean containsAll(ValueDesc other) {
+ return containsAll(other, 0);
+ }
+
+ protected abstract boolean containsAll(ValueDesc other, int depth);
+
+ // X, Y intersectWithIsEmpty, means 'X and Y' is:
+ // 1) FALSE && !X.nullable() && !Y.nullable();
+ // 2) EmptyValue && X.nullable() && Y.nullable()), the nullable check
no loss the null
+ // use in 'A and (B or C)', if A, B intersectWithIsEmpty, then rewrite
it to 'A and C'
+ @VisibleForTesting
+ public final IntersectType getIntersectType(ValueDesc other) {
+ return getIntersectType(other, 0);
+ }
+
+ protected abstract IntersectType getIntersectType(ValueDesc other, int
depth);
+
+ // X, Y unionWithIsAll, means 'X union Y' is:
+ // 1) TRUE && !X.nullable() && !Y.nullable();
+ // 2) Range.all() && X.nullable() && Y.nullable(), the nullable check
no loss the null;
+ // use in 'A or (B and C)', if A, B unionWithIsAll, then rewrite it to
'A or C'
+ @VisibleForTesting
+ public final UnionType getUnionType(ValueDesc other) {
+ return getUnionType(other, 0);
+ }
+
+ protected abstract UnionType getUnionType(ValueDesc other, int depth);
}
/**
@@ -332,13 +723,49 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
}
@Override
- public ValueDesc union(ValueDesc other) {
- return other;
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitEmptyValue(this, context);
}
@Override
- public ValueDesc intersect(ValueDesc other) {
- return this;
+ protected boolean nullable() {
+ return reference.nullable();
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
+ return other instanceof EmptyValue || (other instanceof
IsNullValue && !reference.nullable());
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue || other instanceof RangeValue
+ || other instanceof DiscreteValue || other instanceof
NotDiscreteValue
+ || other instanceof IsNullValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof IsNotNullValue) {
+ return IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
+ } else {
+ return IntersectType.OTHERS;
+ }
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ if (other instanceof RangeValue) {
+ if (((RangeValue) other).isRangeAll()) {
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ }
+ } else if (other instanceof IsNotNullValue) {
+ if (!reference.nullable()) {
+ return UnionType.TRUE;
+ }
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getUnionType(this, depth);
+ }
+ return UnionType.OTHERS;
}
}
@@ -348,7 +775,8 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
* a > 1 => (1...+∞)
*/
public static class RangeValue extends ValueDesc {
- Range<ComparableLiteral> range;
+
+ final Range<ComparableLiteral> range;
public RangeValue(ExpressionRewriteContext context, Expression
reference, Range<ComparableLiteral> range) {
super(context, reference);
@@ -360,54 +788,122 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
}
@Override
- public ValueDesc union(ValueDesc other) {
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitRangeValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return reference.nullable();
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
if (other instanceof EmptyValue) {
- return other.union(this);
+ return true;
+ } else if (other instanceof RangeValue) {
+ return range.encloses(((RangeValue) other).range);
+ } else if (other instanceof DiscreteValue) {
+ return range.containsAll(((DiscreteValue) other).values);
+ } else if (other instanceof NotDiscreteValue || other instanceof
IsNotNullValue) {
+ return isRangeAll();
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).isContainedAllBy(this, depth);
+ } else {
+ return false;
}
- if (other instanceof RangeValue) {
- RangeValue o = (RangeValue) other;
- if (range.isConnected(o.range)) {
- return new RangeValue(context, reference,
range.span(o.range));
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof RangeValue) {
+ if (intersect((RangeValue) other) instanceof EmptyValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
}
- return new UnknownValue(context, ImmutableList.of(this,
other), false);
- }
- if (other instanceof DiscreteValue) {
- return union(context, this, (DiscreteValue) other, false);
+ } else if (other instanceof DiscreteValue) {
+ if (intersect((DiscreteValue) other) instanceof EmptyValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ }
+ } else if (other instanceof IsNullValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
}
- return new UnknownValue(context, ImmutableList.of(this, other),
false);
+ return IntersectType.OTHERS;
}
@Override
- public ValueDesc intersect(ValueDesc other) {
- if (other instanceof EmptyValue) {
- return other.intersect(this);
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ if ((other instanceof EmptyValue || other instanceof
DiscreteValue) && isRangeAll()) {
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ } else if (other instanceof RangeValue) {
+ Range<ComparableLiteral> otherRange = ((RangeValue)
other).range;
+ if (range.isConnected(otherRange)) {
+ Range<ComparableLiteral> unionRange =
range.span(otherRange);
+ if (!unionRange.hasLowerBound() &&
!unionRange.hasUpperBound()) {
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ }
+ }
+ } else if (other instanceof NotDiscreteValue) {
+ Set<ComparableLiteral> notDiscreteValues = ((NotDiscreteValue)
other).values;
+ boolean succ = true;
+ for (ComparableLiteral value : notDiscreteValues) {
+ if (!range.contains(value)) {
+ succ = false;
+ break;
+ }
+ }
+ if (succ) {
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ }
+ } else if (other instanceof IsNullValue && !reference.nullable()
&& isRangeAll()) {
+ return UnionType.TRUE;
+ } else if (other instanceof IsNotNullValue) {
+ if (!reference.nullable()) {
+ return UnionType.TRUE;
+ }
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getUnionType(this, depth);
}
- if (other instanceof RangeValue) {
- RangeValue o = (RangeValue) other;
- if (range.isConnected(o.range)) {
- Range<ComparableLiteral> newRange =
range.intersection(o.range);
- if (!newRange.isEmpty()) {
- if (newRange.hasLowerBound() &&
newRange.hasUpperBound()
- &&
newRange.lowerEndpoint().compareTo(newRange.upperEndpoint()) == 0
- && newRange.lowerBoundType() ==
BoundType.CLOSED
- && newRange.lowerBoundType() ==
BoundType.CLOSED) {
- return new DiscreteValue(context, reference,
Sets.newHashSet(newRange.lowerEndpoint()));
- } else {
- return new RangeValue(context, reference,
newRange);
- }
+ return UnionType.OTHERS;
+ }
+
+ private ValueDesc intersect(RangeValue other) { // intersect two ranges: single closed point -> DiscreteValue, disjoint -> EmptyValue
+ if (range.isConnected(other.range)) {
+ Range<ComparableLiteral> newRange =
range.intersection(other.range);
+ if (!newRange.isEmpty()) {
+ if (newRange.hasLowerBound() && newRange.hasUpperBound()
+ &&
newRange.lowerEndpoint().compareTo(newRange.upperEndpoint()) == 0
+ && newRange.lowerBoundType() == BoundType.CLOSED
+ && newRange.upperBoundType() == BoundType.CLOSED) { // [v..v]: both bounds must be closed
+ return new DiscreteValue(context, reference,
Sets.newHashSet(newRange.lowerEndpoint()));
+ } else {
+ return new RangeValue(context, reference, newRange);
}
}
- return new EmptyValue(context, reference);
}
- if (other instanceof DiscreteValue) {
- return intersect(context, this, (DiscreteValue) other);
+ return new EmptyValue(context, reference);
+ }
+
+ private ValueDesc intersect(DiscreteValue other) {
+ Set<ComparableLiteral> intersectValues =
Sets.newLinkedHashSetWithExpectedSize(other.values.size());
+ for (ComparableLiteral value : other.values) {
+ if (range.contains(value)) {
+ intersectValues.add(value);
+ }
+ }
+ if (intersectValues.isEmpty()) {
+ return new EmptyValue(context, reference);
+ } else {
+ return new DiscreteValue(context, reference, intersectValues);
}
- return new UnknownValue(context, ImmutableList.of(this, other),
true);
}
- @Override
- public String toString() {
- return range == null ? "UnknownRange" : range.toString();
+ @VisibleForTesting
+ public boolean isRangeAll() {
+ return !range.hasLowerBound() && !range.hasUpperBound();
}
}
@@ -430,93 +926,326 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
}
@Override
- public ValueDesc union(ValueDesc other) {
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitDiscreteValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return reference.nullable();
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
if (other instanceof EmptyValue) {
- return other.union(this);
+ return true;
+ } else if (other instanceof DiscreteValue) {
+ return values.containsAll(((DiscreteValue) other).values);
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).isContainedAllBy(this, depth);
+ } else {
+ return false;
}
- if (other instanceof DiscreteValue) {
- Set<ComparableLiteral> newValues = Sets.newLinkedHashSet();
- newValues.addAll(((DiscreteValue) other).values);
- newValues.addAll(this.values);
- return new DiscreteValue(context, reference, newValues);
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof RangeValue) {
+ return ((RangeValue) other).getIntersectType(this, depth);
+ } else if (other instanceof DiscreteValue) {
+ Set<ComparableLiteral> otherValues = ((DiscreteValue)
other).values;
+ for (ComparableLiteral value : otherValues) {
+ if (values.contains(value)) {
+ return IntersectType.OTHERS;
+ }
+ }
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof IsNullValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
+ } else {
+ return IntersectType.OTHERS;
}
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
if (other instanceof RangeValue) {
- return union(context, (RangeValue) other, this, true);
+ return ((RangeValue) other).getUnionType(this, depth);
+ } else if (other instanceof NotDiscreteValue) {
+ boolean succ = true;
+ Set<ComparableLiteral> notDiscreteValues = ((NotDiscreteValue)
other).values;
+ for (ComparableLiteral value : notDiscreteValues) {
+ if (!values.contains(value)) {
+ succ = false;
+ break;
+ }
+ }
+ if (succ) {
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ }
+ } else if (other instanceof IsNotNullValue) {
+ if (!reference.nullable()) {
+ return UnionType.TRUE;
+ }
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getUnionType(this, depth);
}
- return new UnknownValue(context, ImmutableList.of(this, other),
false);
+ return UnionType.OTHERS;
+ }
+ }
+
+ /**
+ * for example:
+ * a not in (1,2,3) => [1,2,3]
+ */
+ public static class NotDiscreteValue extends ValueDesc {
+ final Set<ComparableLiteral> values;
+
+ public NotDiscreteValue(ExpressionRewriteContext context,
+ Expression reference, Set<ComparableLiteral> values) {
+ super(context, reference);
+ this.values = values;
}
@Override
- public ValueDesc intersect(ValueDesc other) {
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitNotDiscreteValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return reference.nullable();
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
if (other instanceof EmptyValue) {
- return other.intersect(this);
- }
- if (other instanceof DiscreteValue) {
- Set<ComparableLiteral> newValues = Sets.newLinkedHashSet();
- newValues.addAll(this.values);
- newValues.retainAll(((DiscreteValue) other).values);
- if (newValues.isEmpty()) {
- return new EmptyValue(context, reference);
- } else {
- return new DiscreteValue(context, reference, newValues);
+ return true;
+ } else if (other instanceof RangeValue) {
+ Range<ComparableLiteral> range = ((RangeValue) other).range;
+ for (ComparableLiteral value : values) {
+ if (range.contains(value)) {
+ return false;
+ }
+ }
+ return true;
+ } else if (other instanceof DiscreteValue) {
+ Set<ComparableLiteral> discreteValues = ((DiscreteValue)
other).values;
+ for (ComparableLiteral value : values) {
+ if (discreteValues.contains(value)) {
+ return false;
+ }
}
+ return true;
+ } else if (other instanceof NotDiscreteValue) {
+ return ((NotDiscreteValue) other).values.containsAll(values);
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).isContainedAllBy(this, depth);
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof DiscreteValue) {
+ if (values.containsAll(((DiscreteValue) other).values)) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ }
+ } else if (other instanceof IsNullValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
+ }
+ return IntersectType.OTHERS;
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ if (other instanceof RangeValue) {
+ return ((RangeValue) other).getUnionType(this, depth);
+ } else if (other instanceof DiscreteValue) {
+ return ((DiscreteValue) other).getUnionType(this, depth);
+ } else if (other instanceof NotDiscreteValue) {
+ Set<ComparableLiteral> notDiscreteValues = ((NotDiscreteValue)
other).values;
+ for (ComparableLiteral value : notDiscreteValues) {
+ if (values.contains(value)) {
+ return UnionType.OTHERS;
+ }
+ }
+ return reference.nullable() ? UnionType.RANGE_ALL :
UnionType.TRUE;
+ } else if (other instanceof IsNotNullValue) {
+ if (!reference.nullable()) {
+ return UnionType.TRUE;
+ }
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getUnionType(this, depth);
+ }
+ return UnionType.OTHERS;
+ }
+ }
+
+ /**
+ * a is null
+ */
+ public static class IsNullValue extends ValueDesc {
+
+ public IsNullValue(ExpressionRewriteContext context, Expression
reference) {
+ super(context, reference);
+ }
+
+ @Override
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitIsNullValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return false;
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue) {
+ return !reference.nullable();
+ } else if (other instanceof IsNullValue) {
+ return true;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).isContainedAllBy(this, depth);
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue || other instanceof RangeValue
+ || other instanceof DiscreteValue || other instanceof
NotDiscreteValue) {
+ return reference.nullable() ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
+ } else if (other instanceof IsNotNullValue) {
+ return IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
}
+ return IntersectType.OTHERS;
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
if (other instanceof RangeValue) {
- return intersect(context, (RangeValue) other, this);
+ return ((RangeValue) other).getUnionType(this, depth);
+ } else if (other instanceof IsNotNullValue) {
+ return UnionType.TRUE;
+ } else {
+ return UnionType.OTHERS;
+ }
+ }
+ }
+
+ /**
+ * a is not null
+ */
+ public static class IsNotNullValue extends ValueDesc {
+ final Not not;
+
+ public IsNotNullValue(ExpressionRewriteContext context, Expression
reference, Not not) {
+ super(context, reference);
+ this.not = not;
+ }
+
+ public Not getNotExpression() {
+ return this.not;
+ }
+
+ @Override
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitIsNotNullValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return false;
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
+ if (other instanceof IsNotNullValue) {
+ return not.isGeneratedIsNotNull() == ((IsNotNullValue)
other).not.isGeneratedIsNotNull();
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).isContainedAllBy(this, depth);
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue || other instanceof IsNullValue) {
+ return IntersectType.FALSE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getIntersectType(this, depth);
+ } else {
+ return IntersectType.OTHERS;
}
- return new UnknownValue(context, ImmutableList.of(this, other),
true);
}
@Override
- public String toString() {
- return values.toString();
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ if (other instanceof EmptyValue || other instanceof RangeValue
+ || other instanceof DiscreteValue || other instanceof
NotDiscreteValue) {
+ if (!reference.nullable()) {
+ return UnionType.TRUE;
+ }
+ } else if (other instanceof IsNullValue) {
+ return UnionType.TRUE;
+ } else if (other instanceof CompoundValue) {
+ return ((CompoundValue) other).getUnionType(this, depth);
+ }
+ return UnionType.OTHERS;
}
}
/**
- * Represents processing result.
+ * Represents processing compound predicate.
*/
- public static class UnknownValue extends ValueDesc {
+ public static class CompoundValue extends ValueDesc {
+ private static final int MAX_SEARCH_DEPTH = 1;
private final List<ValueDesc> sourceValues;
private final boolean isAnd;
+ private final Set<Class<? extends ValueDesc>> subClasses;
+ private final boolean hasNullable;
+ private final boolean hasNoneNullable;
- private UnknownValue(ExpressionRewriteContext context, Expression
expr) {
- super(context, expr);
- sourceValues = ImmutableList.of();
- isAnd = false;
- }
-
- private UnknownValue(ExpressionRewriteContext context,
+ /** constructor */
+ public CompoundValue(ExpressionRewriteContext context, Expression
reference,
List<ValueDesc> sourceValues, boolean isAnd) {
- super(context, getReference(context, sourceValues, isAnd));
+ super(context, reference);
this.sourceValues = ImmutableList.copyOf(sourceValues);
this.isAnd = isAnd;
- }
-
- // reference is used to simplify multiple ValueDescs.
- // when ValueDesc A op ValueDesc B, only A and B's references equals,
- // can reduce them, like A op B = A.
- // If A and B's reference not equal, A op B will always get
UnknownValue(A op B).
- //
- // for example:
- // 1. RangeValue(a < 10, reference=a) union RangeValue(a > 20,
reference=a)
- // = UnknownValue1(a < 10 or a > 20, reference=a)
- // 2. RangeValue(a < 10, reference=a) union RangeValue(b > 20,
reference=b)
- // = UnknownValue2(a < 10 or b > 20, reference=(a < 10 or b > 20))
- // then given EmptyValue(, reference=a) E,
- // 1. since E and UnknownValue1's reference equals, then
- // E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1,
- // 2. since E and UnknownValue2's reference not equals, then
- // E union UnknownValue2 = UnknownValue3(E union UnknownValue2,
reference=E union UnknownValue2)
- private static Expression getReference(ExpressionRewriteContext
context,
- List<ValueDesc> sourceValues, boolean isAnd) {
- Expression reference = sourceValues.get(0).reference;
- for (int i = 1; i < sourceValues.size(); i++) {
- if (!reference.equals(sourceValues.get(i).reference)) {
- return SimplifyRange.INSTANCE.getExpression(context,
sourceValues, isAnd);
+ this.subClasses = Sets.newHashSet();
+ this.subClasses.add(getClass());
+ boolean hasNullable = false;
+ boolean hasNonNullable = false;
+ for (ValueDesc sourceValue : sourceValues) {
+ if (sourceValue instanceof CompoundValue) {
+ CompoundValue compoundSource = (CompoundValue) sourceValue;
+ this.subClasses.addAll(compoundSource.subClasses);
+ hasNullable = hasNullable || compoundSource.hasNullable;
+ hasNonNullable = hasNonNullable ||
compoundSource.hasNoneNullable;
+ } else {
+ this.subClasses.add(sourceValue.getClass());
+ hasNullable = hasNullable || sourceValue.nullable();
+ hasNonNullable = hasNonNullable || !sourceValue.nullable();
}
}
- return reference;
+ this.hasNullable = hasNullable;
+ this.hasNoneNullable = hasNonNullable;
}
public List<ValueDesc> getSourceValues() {
@@ -528,23 +1257,174 @@ public class RangeInference extends
ExpressionVisitor<RangeInference.ValueDesc,
}
@Override
- public ValueDesc union(ValueDesc other) {
- // for RangeValue/DiscreteValue/UnknownValue, when union with
EmptyValue,
- // call EmptyValue.union(this) => this
- if (other instanceof EmptyValue) {
- return other.union(this);
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitCompoundValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return hasNullable;
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
+ // in practice, when merging value descs for the same reference,
+ // none of them should be an UnknownValue, so give up on UnknownValue
+ if (depth > MAX_SEARCH_DEPTH || other instanceof UnknownValue ||
subClasses.contains(UnknownValue.class)) {
+ return false;
+ }
+ if (!isAnd && (!other.nullable() || !hasNoneNullable)) {
+ // for an OR value desc, one source containing other is enough when:
+ // 1) if other not nullable, then no need to consider other is
null, this is null
+ // 2) if other is nullable, then when other is null, then the
reference is null,
+ // so if this OR has no non-nullable source, then this is null too.
+ for (ValueDesc valueDesc : sourceValues) {
+ if (valueDesc.containsAll(other, depth + 1)) {
+ return true;
+ }
+ }
+ return false;
+ } else { // AND, or an OR that needs every source to contain other
+ // when other is nullable, why OR should check all source
values containsAll ?
+ // give an example: for an OR: (c1 or c2 or c3), suppose c1
containsAll other,
+ // then when other is null, the OR = null or c2 or c3, it may
not be null.
+ // a example: 'a > 1 or a is null' not contains all 'a > 10',
even if 'a > 1' contains all 'a > 10'
+ for (ValueDesc valueDesc : sourceValues) {
+ if (!valueDesc.containsAll(other, depth + 1)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ }
+
+ // check whether other containsAll this
+ private boolean isContainedAllBy(ValueDesc other, int depth) {
+ // don't try to handle complicated cases,
+ // and in practice, when merging value descs for the same reference,
+ // none of the values should contain an UnknownValue.
+ if (depth > MAX_SEARCH_DEPTH || other instanceof UnknownValue ||
subClasses.contains(UnknownValue.class)) {
+ return false;
+ }
+ if (isAnd) {
+ // for C = c1 and c2 and c3, suppose other containsAll c1,
then will have:
+ // when c1 is true, other is true,
+ // when c1 is null, other is null,
+ // so, when C is true, then c1 is true, so other is true,
+ // when C is null, then the reference must be null, so, c1
is null too, then other is null
+ for (ValueDesc valueDesc : sourceValues) {
+ if (other.containsAll(valueDesc, depth)) {
+ return true;
+ }
+ }
+ return false;
+ } else {
+ // for C = c1 or c2 or c3, suppose other contains c1, c2, c3.
+ // so when C is true, then at least one ci is true, so other
is true.
+ // when C is null, then at least one ci is null, so other
is null.
+ // so other will contain all C
+ for (ValueDesc valueDesc : sourceValues) {
+ if (!other.containsAll(valueDesc, depth)) {
+ return false;
+ }
+ }
+ return true;
     }
- return new UnknownValue(context, ImmutableList.of(this, other),
false);
     }
@Override
- public ValueDesc intersect(ValueDesc other) {
- // for RangeValue/DiscreteValue/UnknownValue, when intersect with
EmptyValue,
- // call EmptyValue.intersect(this) => EmptyValue
- if (other instanceof EmptyValue) {
- return other.intersect(this);
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ if ((!nullable() && other.nullable()) || depth > MAX_SEARCH_DEPTH)
{
+ return IntersectType.OTHERS;
+ }
+ if (isAnd) {
+ boolean hasEmptyValueType = false;
+ for (ValueDesc valueDesc : sourceValues) {
+ IntersectType type = valueDesc.getIntersectType(other,
depth + 1);
+ if (type == IntersectType.FALSE) {
+ return type;
+ }
+ hasEmptyValueType = hasEmptyValueType || type ==
IntersectType.EMPTY_VALUE;
+ }
+ return hasEmptyValueType ? IntersectType.EMPTY_VALUE :
IntersectType.OTHERS;
+ } else {
+ boolean hasEmptyValueType = false;
+ for (ValueDesc valueDesc : sourceValues) {
+ IntersectType type = valueDesc.getIntersectType(other,
depth + 1);
+ if (type == IntersectType.OTHERS) {
+ return type;
+ }
+ hasEmptyValueType = hasEmptyValueType || type ==
IntersectType.EMPTY_VALUE;
+ }
+ return hasEmptyValueType ? IntersectType.EMPTY_VALUE :
IntersectType.FALSE;
}
- return new UnknownValue(context, ImmutableList.of(this, other),
true);
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ if ((!nullable() && other.nullable()) || depth > MAX_SEARCH_DEPTH)
{
+ return UnionType.OTHERS; // give up on a nullability mismatch or deep nesting
+ }
+ if (isAnd) { // AND: every conjunct must report the same non-OTHERS union type
+ UnionType resultType = null;
+ for (ValueDesc valueDesc : sourceValues) {
+ UnionType type = valueDesc.getUnionType(other, depth + 1);
+ if (type == UnionType.OTHERS) {
+ return type;
+ }
+ if (resultType == null) {
+ resultType = type;
+ }
+ if (resultType != type) {
+ return UnionType.OTHERS;
+ }
+ }
+ return resultType; // NOTE(review): null if sourceValues is empty; presumably never the case
+ } else { // OR: the first source with a definite (non-OTHERS) union type decides
+ for (ValueDesc valueDesc : sourceValues) {
+ UnionType type = valueDesc.getUnionType(other, depth + 1);
+ if (type != UnionType.OTHERS) {
+ return type;
+ }
+ }
+ return UnionType.OTHERS;
+ }
+ }
+ }
+
+ /**
+ * Represents an expression that range inference keeps opaque; SimplifyRange emits its reference unchanged.
+ */
+ public static class UnknownValue extends ValueDesc {
+
+ public UnknownValue(ExpressionRewriteContext context, Expression
expression) {
+ super(context, expression);
+ }
+
+ @Override
+ protected <R, C> R visit(ValueDescVisitor<R, C> visitor, C context) {
+ return visitor.visitUnknownValue(this, context);
+ }
+
+ @Override
+ protected boolean nullable() {
+ return reference.nullable();
+ }
+
+ @Override
+ protected boolean containsAll(ValueDesc other, int depth) {
+ // when merge all the value desc, the value desc's reference are
the same.
+ return other instanceof UnknownValue;
+ }
+
+ @Override
+ protected IntersectType getIntersectType(ValueDesc other, int depth) {
+ return IntersectType.OTHERS;
+ }
+
+ @Override
+ protected UnionType getUnionType(ValueDesc other, int depth) {
+ return UnionType.OTHERS;
}
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
index d47f5877a96..0f8a1d60fe8 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
@@ -21,11 +21,16 @@ import
org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.DiscreteValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDescVisitor;
import org.apache.doris.nereids.rules.rewrite.SkipSimpleExprs;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -34,6 +39,7 @@ import
org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
@@ -43,10 +49,9 @@ import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Range;
-import org.apache.commons.lang3.NotImplementedException;
import java.util.List;
-import java.util.stream.Collectors;
+import java.util.Set;
/**
* This class implements the function to simplify expression range.
@@ -69,7 +74,7 @@ import java.util.stream.Collectors;
* 2. for `Or` expression (similar to `And`).
* todo: support a > 10 and (a < 10 or a > 20 ) => a > 20
*/
-public class SimplifyRange implements ExpressionPatternRuleFactory {
+public class SimplifyRange implements ExpressionPatternRuleFactory,
ValueDescVisitor<Expression, Void> {
public static final SimplifyRange INSTANCE = new SimplifyRange();
@Override
@@ -83,34 +88,22 @@ public class SimplifyRange implements
ExpressionPatternRuleFactory {
}
/** rewrite */
- public static Expression rewrite(CompoundPredicate expr,
ExpressionRewriteContext context) {
+ public Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext
context) {
if (SkipSimpleExprs.isSimpleExpr(expr)) {
return expr;
}
ValueDesc valueDesc = (new RangeInference()).getValue(expr, context);
- return INSTANCE.getExpression(valueDesc);
+ return valueDesc.accept(this, null);
}
- private Expression getExpression(ValueDesc value) {
- if (value instanceof EmptyValue) {
- return getExpression((EmptyValue) value);
- } else if (value instanceof DiscreteValue) {
- return getExpression((DiscreteValue) value);
- } else if (value instanceof RangeValue) {
- return getExpression((RangeValue) value);
- } else if (value instanceof UnknownValue) {
- return getExpression((UnknownValue) value);
- } else {
- throw new NotImplementedException("not implements");
- }
- }
-
- private Expression getExpression(EmptyValue value) {
+ @Override
+ public Expression visitEmptyValue(EmptyValue value, Void context) {
Expression reference = value.getReference();
return ExpressionUtils.falseOrNull(reference);
}
- private Expression getExpression(RangeValue value) {
+ @Override
+ public Expression visitRangeValue(RangeValue value, Void context) {
Expression reference = value.getReference();
Range<ComparableLiteral> range = value.getRange();
List<Expression> result = Lists.newArrayList();
@@ -135,27 +128,51 @@ public class SimplifyRange implements
ExpressionPatternRuleFactory {
}
}
- private Expression getExpression(DiscreteValue value) {
- return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(),
-
value.getValues().stream().map(Literal.class::cast).collect(Collectors.toList()));
+ @Override
+ public Expression visitDiscreteValue(DiscreteValue value, Void context) {
+ return getDiscreteExpression(value.getReference(), value.values);
}
- private Expression getExpression(UnknownValue value) {
- List<ValueDesc> sourceValues = value.getSourceValues();
- if (sourceValues.isEmpty()) {
- return value.getReference();
- } else {
- return getExpression(value.getExpressionRewriteContext(),
sourceValues, value.isAnd());
+ @Override
+ public Expression visitNotDiscreteValue(NotDiscreteValue value, Void
context) {
+ return new Not(getDiscreteExpression(value.getReference(), // NOT of the equal-to / IN predicate
value.values));
+ }
+
+ @Override
+ public Expression visitIsNullValue(IsNullValue value, Void context) {
+ return new IsNull(value.getReference());
+ }
+
+ @Override
+ public Expression visitIsNotNullValue(IsNotNullValue value, Void context) {
+ return value.getNotExpression(); // reuse the original Not so its isGeneratedIsNotNull flag is kept
+ }
+
+ @Override
+ public Expression visitCompoundValue(CompoundValue value, Void context) {
+ return getCompoundExpression(value.getExpressionRewriteContext(), // rebuild each source value, then combine with AND / OR
value.getSourceValues(), value.isAnd());
+ }
+
+ @Override
+ public Expression visitUnknownValue(UnknownValue value, Void context) {
+ return value.getReference(); // unknown values are emitted unchanged
+ }
+
+ private Expression getDiscreteExpression(Expression reference,
Set<ComparableLiteral> values) {
+ ImmutableList.Builder<Expression> options =
ImmutableList.builderWithExpectedSize(values.size());
+ for (ComparableLiteral value : values) {
+ options.add((Expression) value);
}
+ return ExpressionUtils.toInPredicateOrEqualTo(reference,
options.build());
}
/** getExpression */
- public Expression getExpression(ExpressionRewriteContext context,
+ public Expression getCompoundExpression(ExpressionRewriteContext context,
List<ValueDesc> sourceValues, boolean isAnd) {
Preconditions.checkArgument(!sourceValues.isEmpty());
List<Expression> sourceExprs =
Lists.newArrayListWithExpectedSize(sourceValues.size());
for (ValueDesc sourceValue : sourceValues) {
- Expression expr = getExpression(sourceValue);
+ Expression expr = sourceValue.accept(this, null);
if (isAnd) {
sourceExprs.addAll(ExpressionUtils.extractConjunction(expr));
} else {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
index 94a1cff6c9a..8575ab9e5a2 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java
@@ -336,7 +336,8 @@ public class PullUpPredicates extends
PlanVisitor<ImmutableSet<Expression>, Void
genPredicates =
FoldConstantRuleOnFE.evaluate(genPredicates, rewriteContext);
if (isScalar) {
// Aggregation will return null if there are
no matching rows
- pullPredicates.add(new Or(new IsNull(slot),
genPredicates));
+ // SimplifyRange will put IsNull at the back
+ pullPredicates.add(new Or(genPredicates, new
IsNull(slot)));
} else {
pullPredicates.add(genPredicates);
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java
index 12c0252d3a3..e12276ff57f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java
@@ -50,13 +50,11 @@ public class Not extends Expression implements
UnaryExpression, ExpectsInputType
}
public Not(Expression child, boolean isGeneratedIsNotNull) {
- super(ImmutableList.of(child));
- this.isGeneratedIsNotNull = isGeneratedIsNotNull;
+ this(ImmutableList.of(child), isGeneratedIsNotNull);
}
private Not(List<Expression> child, boolean isGeneratedIsNotNull) {
- super(child);
- this.isGeneratedIsNotNull = isGeneratedIsNotNull;
+ this(child, isGeneratedIsNotNull, false);
}
public boolean isGeneratedIsNotNull() {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 3d7496d1ba8..0efa98ddb0b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -644,26 +644,6 @@ public class ExpressionUtils {
return true;
}
- /** matchNumericType */
- public static boolean matchNumericType(List<Expression> children) {
- for (Expression child : children) {
- if (!child.getDataType().isNumericType()) {
- return false;
- }
- }
- return true;
- }
-
- /** matchDateLikeType */
- public static boolean matchDateLikeType(List<Expression> children) {
- for (Expression child : children) {
- if (!child.getDataType().isDateLikeType()) {
- return false;
- }
- }
- return true;
- }
-
/** hasNullLiteral */
public static boolean hasNullLiteral(List<Expression> children) {
for (Expression child : children) {
@@ -674,16 +654,6 @@ public class ExpressionUtils {
return false;
}
- /** hasOnlyMetricType */
- public static boolean hasOnlyMetricType(List<Expression> children) {
- for (Expression child : children) {
- if (child.getDataType().isOnlyMetricType()) {
- return true;
- }
- }
- return false;
- }
-
/**
* canInferNotNullForMarkSlot
*/
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 43e44255f66..638d24a3949 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -41,12 +41,16 @@ import
org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
+import org.apache.doris.nereids.trees.plans.RelationId;
+import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import org.junit.jupiter.api.Test;
import java.math.BigDecimal;
@@ -56,6 +60,14 @@ import java.math.BigDecimal;
*/
class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
+ public ExpressionRewriteTest() {
+ super();
+ LogicalFilter<?> filter = new
LogicalFilter<LogicalEmptyRelation>(ImmutableSet.of(),
+ new LogicalEmptyRelation(new RelationId(1),
ImmutableList.of()));
+ // AddMinMax run in filter plan
+ context = new ExpressionRewriteContext(filter, cascadesContext);
+ }
+
@Test
void testNotRewrite() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
@@ -389,26 +401,26 @@ class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
assertRewriteAfterTypeCoercion("ISNULL(TA) and TA between 20 and 10",
"ISNULL(TA) and null");
// assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10",
"ISNULL(TA) and null"); // should be, but not support now
assertRewriteAfterTypeCoercion("ISNULL(TA) and TA > 10 and null",
"ISNULL(TA) and null");
- assertRewriteAfterTypeCoercion("ISNULL(TA) or TA > 10", "ISNULL(TA) or
TA > 10");
+ assertRewriteAfterTypeCoercion("ISNULL(TA) or TA > 10", "TA > 10 or
ISNULL(TA)");
// assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA between
20 and 10", "TA IS NULL AND NULL"); // should be, but not support because
flatten and
assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) and TA is null
and null", "TA IS NULL AND NULL");
assertRewriteAfterTypeCoercion("(TA < 30 or TA > 40) or TA between 20
and 10", "TA < 30 or TA > 40");
assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30
and 40 or TA between 60 and 50",
- "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40");
+ "TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)");
// should be, but not support yet, because 'TA is null and null' =>
UnknownValue(EmptyValue(TA) and null)
//assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between
30 and 40 or TA is null and null",
// "(TA <= 20 or TA >= 30) and TA >= 10 and TA <= 40");
assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30
and 40 or TA is null and null",
- "(TA <= 20 or TA >= 30 or TA is null and null) and TA >= 10
and TA <= 40");
+ "TA >= 10 and TA <= 40 and (TA <= 20 or TA >= 30)");
assertRewriteAfterTypeCoercion("TA between 10 and 20 or TA between 30
and 40 or TA is null",
"TA >= 10 and TA <= 20 or TA >= 30 and TA <= 40 or TA is
null");
assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20
or TA between 30 and 40 or TA between 60 and 50)",
- "ISNULL(TB) and ((TA <= 20 or TA >= 30) and TA >= 10 and TA <=
40)");
+ "ISNULL(TB) and TA >= 10 and TA <= 40 and (TA <= 20 or TA >=
30)");
assertRewriteAfterTypeCoercion("ISNULL(TB) and (TA between 10 and 20
or TA between 30 and 40 or TA is null)",
"ISNULL(TB) and (TA >= 10 and TA <= 20 or TA >= 30 and TA <=
40 or TA is null)");
assertRewriteAfterTypeCoercion("TB between 20 and 10 and (TA between
10 and 20 or TA between 30 and 40 or TA between 60 and 50)",
- "TB IS NULL AND NULL and (TA <= 20 or TA >= 30) and TA >= 10
and TA <= 40");
+ "TB IS NULL AND NULL and TA >= 10 and TA <= 40 and (TA <= 20
or TA >= 30)");
assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10
and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50
and TB between 60 and 50",
"(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null
and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40");
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
index 5fff6e51ace..fabafa2399f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
@@ -23,7 +23,11 @@ import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.rules.RangeInference;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.CompoundValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNotNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.IsNullValue;
+import
org.apache.doris.nereids.rules.expression.rules.RangeInference.NotDiscreteValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
import
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
@@ -58,33 +62,54 @@ public class SimplifyRangeTest extends ExpressionRewrite {
private static final NereidsParser PARSER = new NereidsParser();
private ExpressionRuleExecutor executor;
private ExpressionRewriteContext context;
+ private final Map<String, Slot> commonMem;
public SimplifyRangeTest() {
CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(
new UnboundRelation(new RelationId(1),
ImmutableList.of("tbl")));
context = new ExpressionRewriteContext(cascadesContext);
+ commonMem = Maps.newHashMap();
}
@Test
public void testRangeInference() {
ValueDesc valueDesc = getValueDesc("TA IS NULL");
+ Assertions.assertInstanceOf(IsNullValue.class, valueDesc);
+ Assertions.assertEquals("TA", valueDesc.getReference().toSql());
+
+ valueDesc = getValueDesc("NULL");
Assertions.assertInstanceOf(UnknownValue.class, valueDesc);
- List<ValueDesc> sourceValues = ((UnknownValue)
valueDesc).getSourceValues();
- Assertions.assertEquals(0, sourceValues.size());
- Assertions.assertEquals("TA IS NULL",
valueDesc.getReference().toSql());
+ Assertions.assertEquals("NULL", valueDesc.getReference().toSql());
+
+ valueDesc = getValueDesc("TA IS NOT NULL");
+ Assertions.assertInstanceOf(IsNotNullValue.class, valueDesc);
+ Assertions.assertEquals("TA", valueDesc.getReference().toSql());
+
+ valueDesc = getValueDesc("TA != 10");
+ Assertions.assertInstanceOf(NotDiscreteValue.class, valueDesc);
+ Assertions.assertEquals("TA", valueDesc.getReference().toSql());
+
+ valueDesc = getValueDesc("TA IS NULL AND NULL");
+ Assertions.assertInstanceOf(EmptyValue.class, valueDesc);
+ Assertions.assertEquals("TA", valueDesc.getReference().toSql());
+
+ valueDesc = getValueDesc("TA IS NOT NULL OR NULL");
+ Assertions.assertInstanceOf(RangeValue.class, valueDesc);
+ Assertions.assertEquals("TA", valueDesc.getReference().toSql());
+ Assertions.assertTrue(((RangeValue) valueDesc).isRangeAll());
valueDesc = getValueDesc("TA IS NULL AND TB IS NULL AND NULL");
- Assertions.assertInstanceOf(UnknownValue.class, valueDesc);
- sourceValues = ((UnknownValue) valueDesc).getSourceValues();
- Assertions.assertEquals(3, sourceValues.size());
+ Assertions.assertInstanceOf(CompoundValue.class, valueDesc);
+ List<ValueDesc> sourceValues = ((CompoundValue)
valueDesc).getSourceValues();
+ Assertions.assertEquals(2, sourceValues.size());
Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(0));
Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(1));
Assertions.assertEquals("TA",
sourceValues.get(0).getReference().toSql());
Assertions.assertEquals("TB",
sourceValues.get(1).getReference().toSql());
valueDesc = getValueDesc("L + RANDOM(1, 10) > 8 AND L + RANDOM(1, 10)
< 1");
- Assertions.assertInstanceOf(UnknownValue.class, valueDesc);
- sourceValues = ((UnknownValue) valueDesc).getSourceValues();
+ Assertions.assertInstanceOf(CompoundValue.class, valueDesc);
+ sourceValues = ((CompoundValue) valueDesc).getSourceValues();
Assertions.assertEquals(2, sourceValues.size());
for (ValueDesc value : sourceValues) {
Assertions.assertInstanceOf(RangeValue.class, value);
@@ -93,11 +118,109 @@ public class SimplifyRangeTest extends ExpressionRewrite {
}
@Test
- public void testSimplify() {
+ public void testValueDescContainsAll() {
+ SlotReference xa = new SlotReference("xa", IntegerType.INSTANCE,
false);
+
+ checkContainsAll(true, "TA is null and null", "TA is null and null");
+ checkContainsAll(false, "TA is null and null", "TA > 1");
+ checkContainsAll(false, "TA is null and null", "TA = 1");
+ checkContainsAll(false, "TA is null and null", "TA != 1");
+ checkContainsAll(false, "TA is null and null", "TA is null");
+ // XA is null and null will rewrite to 'FALSE'
+ // checkContainsAll(true, "XA is null and null", "XA is null");
+ Assertions.assertTrue(new EmptyValue(context, xa).containsAll(new
IsNullValue(context, xa)));
+ checkContainsAll(false, "TA is null and null", "TA = 1 or TA > 10");
+
+ checkContainsAll(true, "TA > 1", "TA is null and null");
+ checkContainsAll(true, "TA > 1", "TA > 10");
+ checkContainsAll(false, "TA > 1", "TA > 0");
+ checkContainsAll(true, "TA >= 1", "TA > 1");
+ checkContainsAll(false, "TA > 1", "TA >= 1");
+ checkContainsAll(true, "TA > 1", "TA > 1");
+ checkContainsAll(true, "TA > 1", "TA > 1 and TA < 10");
+ checkContainsAll(false, "TA > 1", "TA >= 1 and TA < 10");
+ checkContainsAll(true, "TA > 0", "TA in (1, 2, 3)");
+ checkContainsAll(false, "TA > 0", "TA in (-1, 1, 2, 3)");
+ checkContainsAll(false, "TA > 1", "TA != 0");
+ checkContainsAll(false, "TA > 1", "TA != 1");
+ checkContainsAll(false, "TA > 1", "TA != 2");
+ checkContainsAll(true, "TA is not null or null", "TA != 2");
+ checkContainsAll(false, "TA is not null or null", "TA is null");
+ checkContainsAll(true, "TA is not null or null", "TA is not null");
+ checkContainsAll(true, "TA is not null or null", "TA is null and
null");
+ checkContainsAll(false, "TA > 1", "TA is null");
+ checkContainsAll(false, "TA > 1", "TA is not null");
+ checkContainsAll(true, "TA > 1", "(TA > 2 and TA < 5) or (TA > 7 and
TA < 9)");
+ checkContainsAll(false, "TA > 1", "(TA >= 1 and TA < 5) or (TA > 7
and TA < 9)");
+ checkContainsAll(true, "TA > 1", "TA > 5 and TA is not null");
+ checkContainsAll(true, "TA > 1", "(TA > 5 and TA < 8) and TA is not
null");
+ checkContainsAll(true, "TA > 1", "TA > 5 and TA != 0");
+ checkContainsAll(false, "TA > 1", "TA > 5 or TA is not null");
+ checkContainsAll(false, "TA > 1", "TA > 5 or TA != 0");
+
+ checkContainsAll(true, "TA in (1, 2, 3)", "TA is null and null");
+ checkContainsAll(false, "TA in (1, 2, 3, 4)", "TA between 2 and 3");
+ checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2)");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA in (1, 2, 4)");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA not in (1, 2)");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA not in (5, 6)");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA is null");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA is not null");
+ checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2) and TA is not
null");
+ checkContainsAll(true, "TA in (1, 2, 3)", "TA in (1, 2) and TA is
null");
+ checkContainsAll(false, "TA in (1, 2, 3)", "TA != 1 and TA is not
null");
+ checkContainsAll(false, "TA in (0, 1, 2, 3)", "TA between 1 and 2 and
TA is not null");
+
+ checkContainsAll(true, "TA not in (1, 2)", "TA is null and null");
+ checkContainsAll(false, "TA not in (1, 2, 3, 4, 5)", "TA between 2 and
4");
+ checkContainsAll(false, "TA not in (1, 2, 3, 4, 5)", "TA is not null
or null");
+ checkContainsAll(false, "TA not in (1, 2)", "TA in (1)");
+ checkContainsAll(false, "TA not in (1, 2)", "TA in (1, 2)");
+ checkContainsAll(true, "TA not in (1, 2)", "TA in (3, 4)");
+ checkContainsAll(false, "TA not in (1, 2, 3)", "TA in (1, 4)");
+ checkContainsAll(false, "TA not in (1, 2, 3)", "TA is null");
+ checkContainsAll(false, "TA not in (1, 2, 3)", "TA is not null");
+ checkContainsAll(false, "TA not in (1, 2, 3)", "TA is not null or
null");
+ checkContainsAll(true, "TA not in (1, 2)", "(TA is not null or null)
and (TA is null or TA > 10)");
+
+ checkContainsAll(false, "TA is null", "TA is null and null");
+ checkContainsAll(false, "TA is null", "TA > 10");
+ checkContainsAll(false, "TA is null", "TA = 10");
+ checkContainsAll(false, "TA is null", "TA != 10");
+ checkContainsAll(true, "TA is null", "TA is null");
+ checkContainsAll(false, "TA is null", "TA is not null");
+ checkContainsAll(true, "TA is null", "TA is null and (TA > 10)");
+ checkContainsAll(false, "TA is null", "TA is null or (TA > 10)");
+
+ checkContainsAll(false, "TA is not null", "TA is null and null");
+ checkContainsAll(false, "TA is not null", "TA > 10");
+ checkContainsAll(false, "TA is not null", "TA = 10");
+ checkContainsAll(false, "TA is not null", "TA != 10");
+ checkContainsAll(false, "TA is not null", "TA is null");
+ checkContainsAll(true, "TA is not null", "TA is not null");
+ checkContainsAll(false, "TA is not null", "TA is not null or null");
+ checkContainsAll(true, "TA is not null", "TA is not null and (TA >
10)");
+ checkContainsAll(false, "TA is not null", "TA is not null or (TA >
10)");
+
+ checkContainsAll(true, "TA < 1 or TA > 10", "TA is null and null");
+ checkContainsAll(true, "TA < 1 or TA > 10", "TA < 0");
+ checkContainsAll(false, "TA < 1 or TA > 10", "TA <= 1");
+ checkContainsAll(true, "TA < 1 or TA > 10", "TA = 0");
+ checkContainsAll(false, "TA < 1 or TA > 10", "TA in (0, 1)");
+ checkContainsAll(true, "TA not in (1, 2, 13) or TA > 10", "TA not in
(1, 2, 13, 15)");
+ }
+
+ private void checkContainsAll(boolean isContains, String expr1, String
expr2) {
+ Assertions.assertEquals(isContains, // expected result of valueDesc(expr1).containsAll(valueDesc(expr2))
getValueDesc(expr1).containsAll(getValueDesc(expr2)));
+ }
+
+ @Test
+ public void testSimplifyNumeric() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(SimplifyRange.INSTANCE)
));
assertRewrite("TA", "TA");
+ assertRewrite("TA > 10 and (TA > 20 or TA < 10)", "TA > 20");
assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL");
assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL");
assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL");
@@ -107,6 +230,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
"TA in (11, 12) OR TA <= 10 OR TA >= 13");
assertRewrite("TA > 3 or TA <> null", "TA > 3 or null");
assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null");
+ assertRewrite("TA >= 0 and TA <= 3", "TA >= 0 and TA <= 3");
assertRewrite("(TA < 1 or TA > 2) or (TA >= 0 and TA <= 3)", "TA IS
NOT NULL OR NULL");
assertRewrite("TA between 10 and 20 or TA between 100 and 120 or TA
between 15 and 25 or TA between 115 and 125",
"TA >= 10 and TA <= 25 or TA >= 100 and TA <= 125");
@@ -128,10 +252,51 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewriteNotNull("TA = 1 and TA > 10", "FALSE");
assertRewrite("TA = 1 and TA > 10", "TA is null and null");
assertRewrite("TA >= 1 and TA <= 1", "TA = 1");
+ assertRewrite("TA = 1 and TA = 2", "TA IS NULL AND NULL");
+ assertRewriteNotNull("TA = 1 and TA = 2", "FALSE");
+ assertRewrite("TA not in (1) and TA not in (1)", "TA != 1");
+ assertRewrite("TA not in (1, 2, 3) and TA not in (1, 4, 5)", "TA not
in (1, 2, 3, 4, 5)");
+ assertRewrite("TA = 1 and TA not in (2)", "TA = 1");
+ assertRewrite("TA = 1 and TA not in (1, 2)", "TA is null and null");
+ assertRewriteNotNull("TA = 1 and TA not in (1, 2)", "FALSE");
+ assertRewrite("TA > 10 and TA not in (1, 2, 3)", "TA > 10");
+ assertRewrite("TA > 10 and TA not in (1, 2, 3, 11)", "TA > 10 and TA
!= 11");
+ assertRewrite("TA > 10 and TA not in (1, 2, 3, 11, 12)", "TA > 10 and
TA NOT IN (11, 12)");
+ assertRewrite("TA is null", "TA is null");
+ assertRewriteNotNull("TA is null", "TA is null");
+ assertRewrite("TA is not null", "TA is not null");
+ assertRewrite("TA is null and TA is not null", "FALSE");
+ assertRewriteNotNull("TA is null and TA is not null", "FALSE");
+ assertRewrite("TA = 1 and TA != 1 and TA is null", "TA is null and
null");
+ assertRewriteNotNull("TA = 1 and TA != 1 and TA is null", "FALSE");
+ assertRewrite("TA = 1 and TA != 1 and TA is not null", "FALSE");
+ assertRewriteNotNull("TA = 1 and TA != 1 and TA is not null", "FALSE");
+ assertRewrite("TA = 1 and TA != 1 and (TA > 10 or TA < 5)", "TA is
null and null");
+ assertRewriteNotNull("TA = 1 and TA != 1 and (TA > 10 or TA < 5)",
"FALSE");
+ assertRewrite("TA = 1 and TA != 1 and (TA > 10 or TA is not null)",
"TA is null and null");
+ assertRewrite("TA = 1 and TA != 1 and (TA > 10 or (TA < 5 and TA is
not null))", "TA is null and null");
+ assertRewrite("TA = 1 and TA != 1 and (TA > 10 or (TA < 5 and TA is
not null) or (TA > 7 and TA is not null))",
+ "TA is null and null");
assertRewrite("TA > 5 or TA < 1", "TA < 1 or TA > 5");
assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null");
assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE");
+ assertRewrite("TA != 1 or TA != 1", "TA != 1");
+ assertRewrite("TA != 1 or TA != 2", "TA is not null or null");
+ assertRewriteNotNull("TA != 1 or TA != 2", "TRUE");
+ assertRewrite("TA not in (1, 2, 3) or TA not in (1, 2, 4)", "TA not in
(1, 2)");
+ assertRewrite("TA not in (1, 2) or TA in (2, 1)", "TA is not null or
null");
+ assertRewrite("TA not in (1, 2) or TA in (1)", "TA != 2");
+ assertRewrite("TA not in (1, 2) or TA in (1, 2, 3)", "TA is not null
or null");
+ assertRewrite("TA not in (1, 3) or TA < 2", "TA != 3");
+ assertRewrite("TA is null and null", "TA is null and null");
+ assertRewrite("TA is null", "TA is null");
+ assertRewrite("TA is null and null or TA = 1", "TA = 1");
+ assertRewrite("TA is null and null or TA is null", "TA is null");
+ assertRewrite("TA is null and null or (TA is null and TA > 10) ", "(TA
> 10 and TA is null) or TA is null and null");
+ assertRewrite("TA is null and null or TA is not null", "TA is not null
or TA is null and null");
+ assertRewriteNotNull("TA != 1 or TA != 2", "TRUE");
+ assertRewrite("TA is null or TA is not null", "TRUE");
assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
assertRewrite("TA > 1 or TA < 1", "TA < 1 or TA > 1");
@@ -149,12 +314,14 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA >
10 and TB > 20");
assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB
> 20)", "TB > 30 and TA > 40");
assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA >
10 and TB > 10 or TB > 20");
- assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB >
20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))");
+ assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB >
20 or TB < 10))", "(TA > 5 and TB > 10) or TB > 20");
assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE");
assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null");
assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
assertRewrite("TA in (1,2,3) and TA > 1", "TA IN (2, 3)");
assertRewrite("TA in (1,2,3) or TA >= 1", "TA >= 1");
+ assertRewrite("TA is null and (TA = 4 or TA = 5)", "TA in (4, 5) and
TA is null");
+ assertRewrite("(TA != 3 or TA is null) and (TA = 4 or TA = 5)", "TA in
(4, 5)");
assertRewrite("TA in (1)", "TA in (1)");
assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE");
@@ -179,7 +346,47 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewrite("(TA > 3 and TA < 1) and (TB < 5 and TB = 6)", "TA is
null and null and TB is null");
assertRewrite("TA > 3 and TB < 5 and TA < 1", "TA is null and null and
TB < 5");
assertRewrite("(TA > 3 and TA < 1) or TB < 5", "(TA is null and null)
or TB < 5");
- assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA
= 1 AND SC ='1') OR SC = '1212') AND IA =1");
+
+ // A and (B or C) = A
+ assertRewrite("TA > 10 and (TA > 5 or (TA is not null and TA > 1))",
"TA > 10");
+ assertRewrite("TA > 10 and (TA != 4 or (TA is not null and TA > 1))",
"TA > 10");
+ assertRewrite("TA = 5 and (TA != 4 or (TA is not null and TA > 1))",
"TA = 5");
+ assertRewrite("TA = 5 and (TA in (1, 2, 5) or (TA is not null and TA >
1))", "TA = 5");
+ assertRewrite("TA = 5 and (TA > 3 or (TA is not null and TA > 1))",
"TA = 5");
+ assertRewrite("TA not in (1, 2) and (TA not in (1) or (TA is not null
and TA > 1))", "TA not in (1, 2)");
+ assertRewrite("TA not in (1, 2) and (TA not in (1, 2) or (TA is not
null and TA > 1))", "TA not in (1, 2)");
+ assertRewrite("TA not in (1, 2) and (TA not in (2, 3) or (TA is not
null and TA > 1))", "TA not in (1, 2) and (TA not in (2, 3) or (TA > 1 and TA
is not null))");
+ assertRewrite("TA is null and null and (TA = 10 or (TA is not null and
TA > 1))", "TA is null and null");
+ assertRewrite("TA is null and null and (TA != 10 or (TA is not null
and TA > 1))", "TA is null and null");
+ assertRewrite("TA is null and null and (TA > 20 or (TA is not null and
TA > 1))", "TA is null and null");
+ assertRewrite("TA is null and null and (TA is null and null or (TA is
not null and TA > 1))", "TA is null and null");
+ assertRewrite("TA is null and null and (TA is null or (TA is not null
and TA > 1))", "TA is null and null");
+ assertRewrite("TA is null and (TA is null or (TA is not null and TA >
1))", "TA is null");
+ assertRewrite("TA is not null and (TA is not null or (TA is not null
and TA > 1))", "TA is not null");
+
+ assertRewrite("TA is null and null", "TA is null and null");
+ assertRewriteNotNull("TA is null and null", "FALSE");
+ assertRewrite("TA is null", "TA is null");
+ assertRewriteNotNull("TA is null", "TA is null");
+ assertRewrite("TA is not null", "TA is not null");
+ assertRewriteNotNull("TA is not null", "TA is not null");
+ assertRewrite("TA is null and null or TA is null", "TA is null");
+ assertRewriteNotNull("TA is null and null or TA is null", "TA is
null");
+ assertRewrite("TA is null and null or TA is not null", "TA is not null
or TA is null and null");
+ assertRewriteNotNull("TA is null and null or TA is not null", "not TA
is null");
+ assertRewrite("TA is null or TA is not null", "TRUE");
+ assertRewriteNotNull("TA is null or TA is not null", "TRUE");
+ assertRewrite("(TA is null and null) and TA is null", "TA is null and
null");
+ assertRewriteNotNull("(TA is null and null) and TA is null", "FALSE");
+ assertRewrite("TA is null and null and TA is not null", "FALSE");
+ assertRewriteNotNull("TA is null and null and TA is not null",
"FALSE");
+ assertRewrite("TA is null and TA is not null", "FALSE");
+ assertRewriteNotNull("TA is null and TA is not null", "FALSE");
+
+ assertRewrite("(TA is not null or null) and TA > 10", "TA > 10");
+ assertRewrite("(TA is not null or null) or TA > 10", "TA is not null
or null");
+ assertRewrite("(TA is not null or null) or TA is null", "TRUE");
+ assertRewrite("TA is not null or null or TA is not null", "TA is not
null or null");
assertRewrite("TA + TC", "TA + TC");
assertRewrite("(TA + TC >= 1 and TA + TC <=3 ) or (TA + TC > 5 and TA
+ TC < 7)", "(TA + TC >= 1 and TA + TC <=3 ) or (TA + TC > 5 and TA + TC < 7)");
@@ -208,7 +415,8 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewrite("(TA + TC > 10 or TA + TC > 20) and (TB > 10 and TB >
20)", "TA + TC > 10 and TB > 20");
assertRewrite("((TB > 30 and TA + TC > 40) and TA + TC > 20) and (TB >
10 and TB > 20)", "TB > 30 and TA + TC > 40");
assertRewrite("(TA + TC > 10 and TB > 10) or (TB > 10 and TB > 20)",
"TA + TC > 10 and TB > 10 or TB > 20");
- assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10
and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB < 10
or TB > 20))");
+ assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10
and (TB > 20 or TB < 10))",
+ "(TA + TC > 5 and TB > 10) or TB > 20");
assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC > 10", "FALSE");
assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "(TA + TC) is
null and null");
assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in
(1,2,3)");
@@ -234,21 +442,27 @@ public class SimplifyRangeTest extends ExpressionRewrite {
assertRewriteNotNull("(TA + TC > 3 and TA + TC < 1) and TB < 5",
"FALSE");
assertRewrite("(TA + TC > 3 and TA + TC < 1) and TB < 5", "(TA + TC)
is null and null and TB < 5");
assertRewrite("(TA + TC > 3 and TA + TC < 1) or TB < 5", "((TA + TC)
is null and null) OR TB < 5");
-
assertRewrite("(TA + TC > 3 OR TA < 1) AND TB = 2 AND IA =1", "(TA +
TC > 3 OR TA < 1) AND TB = 2 AND IA =1");
- assertRewrite("SA = '20250101' and SA < '20200101'", "SA is null and
null");
- assertRewrite("SA > '20250101' and SA > '20260110'", "SA >
'20260110'");
// random is non-foldable, so the two random(1, 10) are distinct,
cann't merge range for them.
Expression expr = rewriteExpression("X + random(1, 10) > 10 AND X +
random(1, 10) < 1", true);
Assertions.assertEquals("AND[((X + random(1, 10)) > 10),((X +
random(1, 10)) < 1)]", expr.toSql());
-
expr = rewrite("TA + random(1, 10) between 10 and 20",
Maps.newHashMap());
Assertions.assertEquals("AND[((cast(TA as BIGINT) + random(1, 10)) >=
10),((cast(TA as BIGINT) + random(1, 10)) <= 20)]", expr.toSql());
expr = rewrite("TA + random(1, 10) between 20 and 10",
Maps.newHashMap());
Assertions.assertEquals("AND[(cast(TA as BIGINT) + random(1, 10)) IS
NULL,NULL]", expr.toSql());
}
+ @Test
+ public void testSimplifyString() {
+ executor = new ExpressionRuleExecutor(ImmutableList.of(
+ bottomUp(SimplifyRange.INSTANCE)
+ ));
+ assertRewrite("SA = '20250101' and SA < '20200101'", "SA is null and
null");
+ assertRewrite("SA > '20250101' and SA > '20260110'", "SA >
'20260110'");
+ assertRewrite("((IA = 1 AND SC ='1') OR SC = '1212') AND IA =1", "((IA
= 1 AND SC ='1') OR SC = '1212') AND IA =1");
+ }
+
@Test
public void testSimplifyDate() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
@@ -415,8 +629,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
}
private ValueDesc getValueDesc(String expression) {
- Map<String, Slot> mem = Maps.newHashMap();
- Expression parseExpression =
replaceUnboundSlot(PARSER.parseExpression(expression), mem);
+ Expression parseExpression =
replaceUnboundSlot(PARSER.parseExpression(expression), commonMem);
parseExpression = typeCoercion(parseExpression);
return (new RangeInference()).getValue(parseExpression, context);
}
@@ -466,7 +679,8 @@ public class SimplifyRangeTest extends ExpressionRewrite {
}
if (expression instanceof UnboundSlot) {
String name = ((UnboundSlot) expression).getName();
- mem.putIfAbsent(name, new SlotReference(name,
getType(name.charAt(0))));
+ boolean notNullable = name.charAt(0) == 'X' || name.length() >= 2
&& name.charAt(1) == 'X';
+ mem.putIfAbsent(name, new SlotReference(name,
getType(name.charAt(0)), !notNullable));
return mem.get(name);
}
return hasNewChildren ? expression.withChildren(children) : expression;
diff --git
a/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
index 348814d39f1..cdf32d50f56 100644
---
a/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
+++
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
@@ -127,7 +127,7 @@ PhysicalResultSink
-- !correlate_notop_notnullable_agg_scalar_subquery_shape --
PhysicalResultSink
--PhysicalDistribute[DistributionSpecGather]
-----PhysicalProject[AND[any_value(count(x)) IS NULL,NULL] AS `x > 10 and x <
1`, a AS `a`, assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate
scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS
NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`,
assert_true(OR[count(*) IS NULL,(count(*) <= 1)], 'correlate scalar subquery
must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)],
'correlate scalar subquery must ret [...]
+----PhysicalProject[AND[any_value(count(x)) IS NULL,NULL] AS `x > 10 and x <
1`, a AS `a`, assert_true(OR[(count(*) <= 1),count(*) IS NULL], 'correlate
scalar subquery must return only 1 row') AS `assert_true(OR[count(*) IS
NULL,(count(*) <= 1)], 'correlate scalar subquery must return only 1 row')`,
assert_true(OR[(count(*) <= 1),count(*) IS NULL], 'correlate scalar subquery
must return only 1 row') AS `assert_true(OR[count(*) IS NULL,(count(*) <= 1)],
'correlate scalar subquery must ret [...]
------hashJoin[LEFT_OUTER_JOIN broadcast] hashCondition=((expr_(cast(a as
BIGINT) + cast(b as BIGINT)) = cast(x as BIGINT))) otherCondition=()
--------PhysicalProject[(cast(a as BIGINT) + cast(b as BIGINT)) AS
`expr_(cast(a as BIGINT) + cast(b as BIGINT))`, test_subquery_nullable_t1.a]
----------PhysicalOlapScan[test_subquery_nullable_t1]
diff --git
a/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out
b/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out
index 370abe414b5..359c8aace75 100644
---
a/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out
+++
b/regression-test/data/nereids_rules_p0/filter_push_down/push_down_filter_other_condition.out
@@ -229,7 +229,7 @@ PhysicalResultSink
-- !pushdown_left_outer_join_subquery_outer --
PhysicalResultSink
---NestedLoopJoin[INNER_JOIN]OR[(t1.id = t2.id),AND[id IS NULL,(id > 1)]]
+--NestedLoopJoin[INNER_JOIN]OR[(t1.id = t2.id),AND[(id > 1),id IS NULL]]
----PhysicalOlapScan[t1]
----PhysicalAssertNumRows
------PhysicalOlapScan[t2]
diff --git
a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out
b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out
index 0794088f2d0..6351649fbec 100644
--- a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out
+++ b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out
@@ -354,14 +354,14 @@ PhysicalResultSink
PhysicalResultSink
--hashJoin[INNER_JOIN] hashCondition=((t12.id = t34.id)) otherCondition=()
----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=()
-------filter(( not (id = 3)) and ( not (id = 4)) and (t1.id < 9) and (t1.id >
1))
+------filter(( not id IN (3, 4)) and (t12.id < 9) and (t12.id > 1))
--------PhysicalOlapScan[t1]
-------filter(( not (id = 3)) and ( not (id = 4)) and (t2.id < 9) and (t2.id >
1))
+------filter(( not id IN (3, 4)) and (t2.id < 9) and (t2.id > 1))
--------PhysicalOlapScan[t2]
----hashJoin[INNER_JOIN] hashCondition=((t3.id = t4.id)) otherCondition=()
-------filter(( not (id = 3)) and ( not (id = 4)) and (t34.id < 9) and (t34.id
> 1))
+------filter(( not id IN (3, 4)) and (t3.id < 9) and (t3.id > 1))
--------PhysicalOlapScan[t3]
-------filter(( not (id = 3)) and ( not (id = 4)) and (t4.id < 9) and (t4.id >
1))
+------filter(( not id IN (3, 4)) and (t4.id < 9) and (t4.id > 1))
--------PhysicalOlapScan[t4]
-- !infer8 --
diff --git
a/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out
b/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out
index ffb68feb194..7785b1127f6 100644
---
a/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out
+++
b/regression-test/data/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.out
@@ -4,14 +4,14 @@ PhysicalResultSink
--hashJoin[INNER_JOIN] hashCondition=((t1.c_int = t2.c_int)) otherCondition=()
----PhysicalOlapScan[tbl_adjust_virtual_slot_nullable_1(t1)]
----PhysicalProject
-------filter(OR[( not dayofmonth(c_date) IN (1, 3)),( not dayofmonth(c_date)
IN (2, 3))])
+------filter(OR[( not dayofmonth(c_date) IN (1, 3)),( not
(cast(dayofmonth(c_date) as INT) = c_int))])
--------PhysicalOlapScan[tbl_adjust_virtual_slot_nullable_2(t2)]
-- !left_join_result --
-1 2020-01-01 1 2022-02-01
1 2020-01-01 1 2022-02-02
-1 2020-01-02 1 2022-02-01
+1 2020-01-01 1 2022-02-03
1 2020-01-02 1 2022-02-02
-1 2020-01-03 1 2022-02-01
+1 2020-01-02 1 2022-02-03
1 2020-01-03 1 2022-02-02
+1 2020-01-03 1 2022-02-03
diff --git
a/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy
b/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy
index 281e18d68a2..6be70892232 100644
--- a/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy
+++ b/regression-test/suites/mv_p0/where/k123_nereids/k123_nereids.groovy
@@ -83,7 +83,7 @@ suite ("k123p_nereids") {
qt_select_mv_constant """select bitmap_empty() from d_table where true;"""
- mv_rewrite_success_without_check_chosen("select k2 from d_table where k1=1
and (k1>2 or k1 < 0) order by k2;", "kwh1")
+ mv_rewrite_success_without_check_chosen("select k2 from d_table where k1=1
and (k1>2 or k1 * k1 > 10) order by k2;", "kwh1")
qt_select_mv "select k2 from d_table where k1=1 and (k1>2 or k1 < 0) order
by k2;"
diff --git
a/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy
b/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy
index df2a3087906..04273e48403 100644
---
a/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy
+++
b/regression-test/suites/query_p0/virtual_slot_ref/adjust_virtual_slot_nullable.groovy
@@ -51,7 +51,7 @@ suite("adjust_virtual_slot_nullable") {
NOT (
day(t2.c_date) IN (1, 3)
AND
- day(t2.c_date) IN (2, 3, 3)
+ day(t2.c_date) = t2.c_int
);
"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]