yujun777 commented on code in PR #49982:
URL: https://github.com/apache/doris/pull/49982#discussion_r2209105822


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java:
##########
@@ -0,0 +1,601 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.properties.DataTrait;
+import org.apache.doris.nereids.properties.OrderKey;
+import 
org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Match;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.ImmutableEqualSet;
+import org.apache.doris.nereids.util.ImmutableEqualSet.Builder;
+import org.apache.doris.nereids.util.JoinUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+import org.apache.doris.nereids.util.PredicateInferUtils;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.util.Lists;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Stream;
+
+/**
+ * constant propagation, like: a = 10 and a + b > 30 => a = 10 and 10 + b > 30,
+ * when processing a plan, it will collect all its children's equal sets and 
constants uniforms,
+ * then use them and the plan's expressions to infer more equal sets and 
constants uniforms,
+ * finally use the combined uniforms to replace the slots in this plan's 
expressions with literals.
+ */
+public class ConstantPropagation extends 
DefaultPlanRewriter<ExpressionRewriteContext> implements CustomRewriter {
+
+    private final ExpressionNormalizationAndOptimization exprNormalAndOpt
+            = new ExpressionNormalizationAndOptimization(false);
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        // the uniform info of a logical apply may not be correct, so skip the rewrite.
+        if (plan.anyMatch(LogicalApply.class::isInstance)) {
+            return plan;
+        }
+        ExpressionRewriteContext context = new 
ExpressionRewriteContext(jobContext.getCascadesContext());
+        return plan.accept(this, context);
+    }
+
+    @Override
+    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, 
ExpressionRewriteContext context) {
+        filter = visitChildren(this, filter, context);
+        Expression oldPredicate = filter.getPredicate();
+        Expression newPredicate = replaceConstantsAndRewriteExpr(filter, 
oldPredicate, true, context);
+        if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) {
+            return filter;
+        } else {
+            Set<Expression> newConjuncts = 
Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(newPredicate));
+            return filter.withConjunctsAndChild(newConjuncts, filter.child());
+        }
+    }
+
+    @Override
+    public Plan visitLogicalHaving(LogicalHaving<? extends Plan> having, 
ExpressionRewriteContext context) {
+        having = visitChildren(this, having, context);
+        Expression oldPredicate = having.getPredicate();
+        Expression newPredicate = replaceConstantsAndRewriteExpr(having, 
oldPredicate, true, context);
+        if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) {
+            return having;
+        } else {
+            Set<Expression> newConjuncts = 
Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(newPredicate));
+            return having.withConjunctsAndChild(newConjuncts, having.child());
+        }
+    }
+
+    @Override
+    public Plan visitLogicalProject(LogicalProject<? extends Plan> project, 
ExpressionRewriteContext context) {
+        project = visitChildren(this, project, context);
+        Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait =
+                getChildEqualSetAndConstants(project, context);
+        List<NamedExpression> newProjects = project.getProjects().stream()
+                .map(expr -> replaceNameExpressionConstants(
+                        expr, context, childEqualTrait.first, 
childEqualTrait.second))
+                .collect(ImmutableList.toImmutableList());
+        return newProjects.equals(project.getProjects()) ? project : 
project.withProjects(newProjects);
+    }
+
+    @Override
+    public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, 
ExpressionRewriteContext context) {
+        sort = visitChildren(this, sort, context);
+        Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = 
getChildEqualSetAndConstants(sort, context);
+        // for be, order key must be a column, not a literal, so `order by 
100#xx` is ok,
+        // but `order by 100` will make be core.
+        // so after replaced, we need to remove the constant expr.
+        List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
+                .map(key -> key.withExpression(
+                        replaceConstants(key.getExpr(), false, context, 
childEqualTrait.first, childEqualTrait.second)))
+                .filter(key -> !key.getExpr().isConstant())
+                .collect(ImmutableList.toImmutableList());
+        if (newOrderKeys.isEmpty()) {
+            return sort.child();
+        } else if (!newOrderKeys.equals(sort.getOrderKeys())) {
+            return sort.withOrderKeys(newOrderKeys);
+        } else {
+            return sort;
+        }
+    }
+
+    @Override
+    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> 
aggregate, ExpressionRewriteContext context) {
+        aggregate = visitChildren(this, aggregate, context);
+        Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait =
+                getChildEqualSetAndConstants(aggregate, context);
+
+        List<Expression> oldGroupByExprs = aggregate.getGroupByExpressions();
+        List<Expression> newGroupByExprs = 
Lists.newArrayListWithExpectedSize(oldGroupByExprs.size());
+        for (Expression expr : oldGroupByExprs) {
+            Expression newExpr = replaceConstants(expr, false, context, 
childEqualTrait.first, childEqualTrait.second);
+            if (!newExpr.isConstant()) {
+                newGroupByExprs.add(newExpr);
+            }
+        }
+
+        // group by with literal and empty group by are different.
+        // the former can return 0 row, the latter return at least 1 row.
+        // when met all group by expression are constant,
+        // 'eliminateGroupByConstant' will put a project(alias constant as 
slot) below the agg,
+        // but this rule can't put a project below the agg, otherwise this 
rule may cause a dead loop,
+        // so when all replaced group by expression are constant, just let new 
group by add an origin group by.
+        if (newGroupByExprs.isEmpty() && !oldGroupByExprs.isEmpty()) {
+            newGroupByExprs.add(oldGroupByExprs.iterator().next());
+        }
+        Set<Expression> newGroupByExprSet = Sets.newHashSet(newGroupByExprs);
+
+        List<NamedExpression> oldOutputExprs = 
aggregate.getOutputExpressions();
+        List<NamedExpression> newOutputExprs = 
Lists.newArrayListWithExpectedSize(oldOutputExprs.size());
+        ImmutableList.Builder<NamedExpression> projectBuilder
+                = ImmutableList.builderWithExpectedSize(oldOutputExprs.size());
+
+        boolean containsConstantOutput = false;
+
+        // after normal agg, group by expressions and output expressions are 
slots,
+        // after this rule, they may rewrite to literal, since literal are not 
slot,
+        // we need eliminate the rewritten literals.
+        for (NamedExpression expr : oldOutputExprs) {
+            // ColumnPruning will also add all group by expression into output 
expressions
+            // agg output need contains group by expression
+            Expression replacedExpr = replaceConstants(expr, false, context,
+                    childEqualTrait.first, childEqualTrait.second);
+            Expression newOutputExpr = newGroupByExprSet.contains(expr) ? expr 
: replacedExpr;
+            if (newOutputExpr instanceof NamedExpression) {
+                newOutputExprs.add((NamedExpression) newOutputExpr);
+            }
+
+            if (replacedExpr.isConstant()) {
+                projectBuilder.add(new Alias(expr.getExprId(), replacedExpr, 
expr.getName()));
+                containsConstantOutput = true;
+            } else {
+                Preconditions.checkArgument(newOutputExpr instanceof 
NamedExpression, newOutputExpr);
+                projectBuilder.add(((NamedExpression) newOutputExpr).toSlot());
+            }
+        }
+
+        if (newGroupByExprs.equals(oldGroupByExprs) && 
newOutputExprs.equals(oldOutputExprs)) {
+            return aggregate;
+        }
+
+        aggregate = aggregate.withGroupByAndOutput(newGroupByExprs, 
newOutputExprs);
+        if (containsConstantOutput) {
+            return PlanUtils.projectOrSelf(projectBuilder.build(), aggregate);
+        } else {
+            return aggregate;
+        }
+    }
+
+    @Override
+    public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, 
ExpressionRewriteContext context) {
+        repeat = visitChildren(this, repeat, context);
+        // TODO: process the repeat like agg ?
+        return repeat;
+    }
+
+    @Override
+    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, ExpressionRewriteContext context) {
+        join = visitChildren(this, join, context);
+        // TODO: need rewrite the mark conditions ?
+        List<Expression> allJoinConjuncts = 
Stream.concat(join.getHashJoinConjuncts().stream(),
+                        join.getOtherJoinConjuncts().stream())
+                .collect(ImmutableList.toImmutableList());
+        Expression oldPredicate = ExpressionUtils.and(allJoinConjuncts);
+        Expression newPredicate = replaceConstantsAndRewriteExpr(join, 
oldPredicate, true, context);
+        if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) {
+            return join;
+        }
+
+        // TODO: code from FindHashConditionForJoin
+        Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(
+                join.left().getOutput(), join.right().getOutput(), 
ExpressionUtils.extractConjunction(newPredicate));
+
+        List<Expression> newHashJoinConjuncts = pair.first;
+        List<Expression> newOtherJoinConjuncts = pair.second;
+        JoinType joinType = join.getJoinType();
+        if (joinType == JoinType.CROSS_JOIN && 
!newHashJoinConjuncts.isEmpty()) {
+            joinType = JoinType.INNER_JOIN;
+        }
+
+        if (newHashJoinConjuncts.equals(join.getHashJoinConjuncts())
+                && newOtherJoinConjuncts.equals(join.getOtherJoinConjuncts())) 
{
+            return join;
+        }
+
+        return new LogicalJoin<>(joinType,
+                newHashJoinConjuncts,
+                newOtherJoinConjuncts,
+                join.getMarkJoinConjuncts(),
+                join.getDistributeHint(),
+                join.getMarkJoinSlotReference(),
+                join.children(), join.getJoinReorderContext());
+    }
+
+    // for sql: create table t as select cast('1' as varchar(30))
+    // the select will add a parent plan: result sink. the result sink 
contains a output slot reference, and its
+    // data type is varchar(30),  but if replace the slot reference with a 
varchar literal '1', then the data type info
+    // varchar(30) will be lost, because varchar literal '1' data type is always 
varchar(1), so t's column will get
+    // a wrong type.
+    // so we don't rewrite logical sink then.
+    // @Override
+    // public Plan visitLogicalSink(LogicalSink<? extends Plan> sink, 
ExpressionRewriteContext context) {
+    //     sink = visitChildren(this, sink, context);
+    //     Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait
+    //     = getChildEqualSetAndConstants(sink, context);
+    //     List<NamedExpression> newOutputExprs = 
sink.getOutputExprs().stream()
+    //             .map(expr ->
+    //                     replaceNameExpressionConstants(expr, context, 
childEqualTrait.first, childEqualTrait.second))
+    //             .collect(ImmutableList.toImmutableList());
+    //     return newOutputExprs.equals(sink.getOutputExprs()) ? sink : 
sink.withOutputExprs(newOutputExprs);
+    // }
+
+    /**
+     * replace constants and rewrite expression.
+     */
+    @VisibleForTesting
+    public Expression replaceConstantsAndRewriteExpr(LogicalPlan plan, 
Expression expression,
+            boolean useInnerInfer, ExpressionRewriteContext context) {
+        // for expression `a = 1 and a + b = 2 and b + c = 2 and c + d =2 and 
...`:
+        // propagate constant `a = 1`, then get `1 + b = 2`, after rewrite 
this expression, will get `b = 1`;
+        // then propagate constant `b = 1`, then get `1 + c = 2`, after 
rewrite this expression, will get `c = 1`,
+        // ...
+        // so constant propagate and rewrite expression need to do in a loop.
+        Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = 
getChildEqualSetAndConstants(plan, context);
+        Expression afterExpression = expression;
+        for (int i = 0; i < 100; i++) {
+            Expression beforeExpression = afterExpression;
+            afterExpression = replaceConstants(beforeExpression, 
useInnerInfer, context,
+                    childEqualTrait.first, childEqualTrait.second);
+            if (isExprEqualIgnoreOrder(beforeExpression, afterExpression)) {
+                break;
+            }
+            if (afterExpression.isLiteral()) {
+                break;
+            }
+            beforeExpression = afterExpression;
+            afterExpression = exprNormalAndOpt.rewrite(beforeExpression, 
context);
+        }
+        return afterExpression;
+    }
+
+    // process NamedExpression
+    private NamedExpression replaceNameExpressionConstants(NamedExpression 
expr, ExpressionRewriteContext context,
+            ImmutableEqualSet<Slot> equalSet, Map<Slot, Literal> constants) {
+
+        // if a project item is a slot reference, and the slot equals to a 
constant value, don't rewrite it.
+        // because rule `EliminateUnnecessaryProject ` can eliminate a project 
when the project's output slots equal to
+        // its child's output slots.
+        // for example, for `sink -> ... -> project(a, b, c) -> filter(a = 10)`
+        // if rewrite project to project(alias 10 as a, b, c), later other 
rule may prune `alias 10 as a`, and project
+        // will become project(b, c), so project and filter's output slot will 
not equal,
+        // then the project cannot be eliminated.
+        // so we don't replace SlotReference.
+        // for safety reason, we only replace Alias
+        if (!(expr instanceof Alias)) {
+            return expr;
+        }
+
+        // PushProjectThroughUnion require projection is a slot reference, or 
like (cast slot reference as xx);
+        // TODO: if PushProjectThroughUnion support projection like  `literal 
as xx`, then delete this check.
+        if (ExpressionUtils.getExpressionCoveredByCast(expr.child(0)) 
instanceof SlotReference) {
+            return expr;
+        }
+
+        Expression newExpr = replaceConstants(expr, false, context, equalSet, 
constants);
+        if (newExpr instanceof NamedExpression) {
+            return (NamedExpression) newExpr;
+        } else {
+            return new Alias(expr.getExprId(), newExpr, expr.getName());
+        }
+    }
+
+    private Expression replaceConstants(Expression expression, boolean 
useInnerInfer, ExpressionRewriteContext context,
+            ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> 
parentConstants) {
+        if (expression instanceof And) {
+            return replaceAndConstants((And) expression, useInnerInfer, 
context, parentEqualSet, parentConstants);
+        } else if (expression instanceof Or) {
+            return replaceOrConstants((Or) expression, useInnerInfer, context, 
parentEqualSet, parentConstants);
+        } else if (!parentConstants.isEmpty()
+                && expression.anyMatch(e -> e instanceof Slot && 
parentConstants.containsKey(e))) {
+            Expression newExpr = ExpressionUtils.replaceIf(expression, 
parentConstants, this::canReplaceExpression);
+            if (!newExpr.equals(expression)) {
+                newExpr = FoldConstantRule.evaluate(newExpr, context);
+            }
+            return newExpr;
+        } else {
+            return expression;
+        }
+    }
+
+    // process AND expression
+    private Expression replaceAndConstants(And expression, boolean 
useInnerInfer, ExpressionRewriteContext context,
+            ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> 
parentConstants) {
+        List<Expression> conjunctions = 
ExpressionUtils.extractConjunction(expression);
+        Optional<Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>>> 
equalAndConstantOptions =
+                expandEqualSetAndConstants(conjunctions, useInnerInfer, 
parentEqualSet, parentConstants);
+        // infer conflict constants like a = 10 and a = 30
+        if (!equalAndConstantOptions.isPresent()) {
+            return BooleanLiteral.FALSE;
+        }
+        Set<Slot> inputSlots = expression.getInputSlots();
+        ImmutableEqualSet<Slot> newEqualSet = 
equalAndConstantOptions.get().first;
+        Map<Slot, Literal> newConstants = equalAndConstantOptions.get().second;
+        // myInferConstantSlots : the slots that are inferred by this 
expression, not inferred by parent
+        // myInferConstantSlots[slot] = true means the expression already contains 
the conjunct `slot = constant`
+        Map<Slot, Boolean> myInferConstantSlots = 
Maps.newLinkedHashMapWithExpectedSize(
+                Math.max(0, newConstants.size() - parentConstants.size()));
+        for (Slot slot : newConstants.keySet()) {
+            if (!parentConstants.containsKey(slot)) {
+                myInferConstantSlots.put(slot, false);
+            }
+        }
+        ImmutableList.Builder<Expression> builder = 
ImmutableList.builderWithExpectedSize(conjunctions.size());
+        for (Expression child : conjunctions) {
+            Expression newChild = child;
+            // for expression, `a = 10 and a > b` will infer constant relation 
`a = 10`,
+            // need to replace a with 10 to this expression,
+            // for the first conjunction `a = 10`, no need to replace because 
after replace will get `10 = 10`,
+            // for the second conjunction `a > b`, need replace and got `10 > 
b`
+            if (needReplaceWithConstant(newChild, newConstants, 
myInferConstantSlots)) {
+                newChild = replaceConstants(newChild, useInnerInfer, context, 
newEqualSet, newConstants);
+            }
+            if (newChild.equals(BooleanLiteral.FALSE)) {
+                return BooleanLiteral.FALSE;
+            }
+            if (newChild instanceof And) {
+                builder.addAll(ExpressionUtils.extractConjunction(newChild));
+            } else {
+                builder.add(newChild);
+            }
+        }
+        // if the expression infer `slot = constant`, but not contains 
conjunct `slot = constant`, need to add it
+        for (Map.Entry<Slot, Boolean> entry : myInferConstantSlots.entrySet()) 
{
+            // if this expression doesn't contain the slot, don't add it, to avoid 
making the expression too long
+            if (!entry.getValue() && inputSlots.contains(entry.getKey())) {
+                Slot slot = entry.getKey();
+                EqualTo equal = new EqualTo(slot, newConstants.get(slot), 
true);
+                
builder.add(TypeCoercionUtils.processComparisonPredicate(equal));
+            }
+        }
+        return expression.withChildren(builder.build());
+    }
+
+    // process OR expression
+    private Expression replaceOrConstants(Or expression, boolean 
useInnerInfer, ExpressionRewriteContext context,
+            ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> 
parentConstants) {
+        List<Expression> disjunctions = 
ExpressionUtils.extractDisjunction(expression);
+        ImmutableList.Builder<Expression> builder = 
ImmutableList.builderWithExpectedSize(disjunctions.size());
+        for (Expression child : disjunctions) {
+            Expression newChild = replaceConstants(child, useInnerInfer, 
context, parentEqualSet, parentConstants);
+            if (newChild.equals(BooleanLiteral.TRUE)) {
+                return BooleanLiteral.TRUE;
+            }
+            builder.add(newChild);
+        }
+        return expression.withChildren(builder.build());
+    }
+
+    private boolean needReplaceWithConstant(Expression expression, Map<Slot, 
Literal> constants,
+            Map<Slot, Boolean> myInferConstantSlots) {
+        if (expression instanceof EqualTo && expression.child(0) instanceof 
Slot) {
+            Slot slot = (Slot) expression.child(0);
+            if (myInferConstantSlots.get(slot) == Boolean.FALSE

Review Comment:
   This has been fixed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to