yujun777 commented on code in PR #49982: URL: https://github.com/apache/doris/pull/49982#discussion_r2209658171
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java: ########## @@ -0,0 +1,601 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.properties.DataTrait; +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Match; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; +import 
org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalApply; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.ImmutableEqualSet; +import org.apache.doris.nereids.util.ImmutableEqualSet.Builder; +import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PredicateInferUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.hadoop.util.Lists; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +/** + * constant propagation, like: a = 10 and a + b > 30 => a 
= 10 and 10 + b > 30, + * when processing a plan, it will collect all its children's equal sets and constants uniforms, + * then use them and the plan's expressions to infer more equal sets and constants uniforms, + * finally use the combine uniforms to replace this plan's expression's slot with literals. + */ +public class ConstantPropagation extends DefaultPlanRewriter<ExpressionRewriteContext> implements CustomRewriter { + + private final ExpressionNormalizationAndOptimization exprNormalAndOpt + = new ExpressionNormalizationAndOptimization(false); + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + // logical apply uniform maybe not correct. + if (plan.anyMatch(LogicalApply.class::isInstance)) { + return plan; + } + ExpressionRewriteContext context = new ExpressionRewriteContext(jobContext.getCascadesContext()); + return plan.accept(this, context); + } + + @Override + public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, ExpressionRewriteContext context) { + filter = visitChildren(this, filter, context); + Expression oldPredicate = filter.getPredicate(); + Expression newPredicate = replaceConstantsAndRewriteExpr(filter, oldPredicate, true, context); + if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) { + return filter; + } else { + Set<Expression> newConjuncts = Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(newPredicate)); + return filter.withConjunctsAndChild(newConjuncts, filter.child()); + } + } + + @Override + public Plan visitLogicalHaving(LogicalHaving<? 
extends Plan> having, ExpressionRewriteContext context) { + having = visitChildren(this, having, context); + Expression oldPredicate = having.getPredicate(); + Expression newPredicate = replaceConstantsAndRewriteExpr(having, oldPredicate, true, context); + if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) { + return having; + } else { + Set<Expression> newConjuncts = Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(newPredicate)); + return having.withConjunctsAndChild(newConjuncts, having.child()); + } + } + + @Override + public Plan visitLogicalProject(LogicalProject<? extends Plan> project, ExpressionRewriteContext context) { + project = visitChildren(this, project, context); + Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = + getChildEqualSetAndConstants(project, context); + List<NamedExpression> newProjects = project.getProjects().stream() + .map(expr -> replaceNameExpressionConstants( + expr, context, childEqualTrait.first, childEqualTrait.second)) + .collect(ImmutableList.toImmutableList()); + return newProjects.equals(project.getProjects()) ? project : project.withProjects(newProjects); + } + + @Override + public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, ExpressionRewriteContext context) { + sort = visitChildren(this, sort, context); + Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = getChildEqualSetAndConstants(sort, context); + // for be, order key must be a column, not a literal, so `order by 100#xx` is ok, + // but `order by 100` will make be core. + // so after replaced, we need to remove the constant expr. 
+ List<OrderKey> newOrderKeys = sort.getOrderKeys().stream() + .map(key -> key.withExpression( + replaceConstants(key.getExpr(), false, context, childEqualTrait.first, childEqualTrait.second))) + .filter(key -> !key.getExpr().isConstant()) + .collect(ImmutableList.toImmutableList()); + if (newOrderKeys.isEmpty()) { + return sort.child(); + } else if (!newOrderKeys.equals(sort.getOrderKeys())) { + return sort.withOrderKeys(newOrderKeys); + } else { + return sort; + } + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, ExpressionRewriteContext context) { + aggregate = visitChildren(this, aggregate, context); + Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = + getChildEqualSetAndConstants(aggregate, context); + + List<Expression> oldGroupByExprs = aggregate.getGroupByExpressions(); + List<Expression> newGroupByExprs = Lists.newArrayListWithExpectedSize(oldGroupByExprs.size()); + for (Expression expr : oldGroupByExprs) { + Expression newExpr = replaceConstants(expr, false, context, childEqualTrait.first, childEqualTrait.second); + if (!newExpr.isConstant()) { + newGroupByExprs.add(newExpr); + } + } + + // group by with literal and empty group by are different. + // the former can return 0 row, the latter return at least 1 row. + // when met all group by expression are constant, + // 'eliminateGroupByConstant' will put a project(alias constant as slot) below the agg, + // but this rule cann't put a project below the agg, otherwise this rule may cause a dead loop, + // so when all replaced group by expression are constant, just let new group by add an origin group by. 
+ if (newGroupByExprs.isEmpty() && !oldGroupByExprs.isEmpty()) { + newGroupByExprs.add(oldGroupByExprs.iterator().next()); + } + Set<Expression> newGroupByExprSet = Sets.newHashSet(newGroupByExprs); + + List<NamedExpression> oldOutputExprs = aggregate.getOutputExpressions(); + List<NamedExpression> newOutputExprs = Lists.newArrayListWithExpectedSize(oldOutputExprs.size()); + ImmutableList.Builder<NamedExpression> projectBuilder + = ImmutableList.builderWithExpectedSize(oldOutputExprs.size()); + + boolean containsConstantOutput = false; + + // after normal agg, group by expressions and output expressions are slots, + // after this rule, they may rewrite to literal, since literal are not slot, + // we need eliminate the rewritten literals. + for (NamedExpression expr : oldOutputExprs) { + // ColumnPruning will also add all group by expression into output expressions + // agg output need contains group by expression + Expression replacedExpr = replaceConstants(expr, false, context, + childEqualTrait.first, childEqualTrait.second); + Expression newOutputExpr = newGroupByExprSet.contains(expr) ? expr : replacedExpr; + if (newOutputExpr instanceof NamedExpression) { + newOutputExprs.add((NamedExpression) newOutputExpr); + } + + if (replacedExpr.isConstant()) { + projectBuilder.add(new Alias(expr.getExprId(), replacedExpr, expr.getName())); + containsConstantOutput = true; + } else { + Preconditions.checkArgument(newOutputExpr instanceof NamedExpression, newOutputExpr); + projectBuilder.add(((NamedExpression) newOutputExpr).toSlot()); + } + } + + if (newGroupByExprs.equals(oldGroupByExprs) && newOutputExprs.equals(oldOutputExprs)) { + return aggregate; + } + + aggregate = aggregate.withGroupByAndOutput(newGroupByExprs, newOutputExprs); + if (containsConstantOutput) { + return PlanUtils.projectOrSelf(projectBuilder.build(), aggregate); + } else { + return aggregate; + } + } + + @Override + public Plan visitLogicalRepeat(LogicalRepeat<? 
extends Plan> repeat, ExpressionRewriteContext context) { + repeat = visitChildren(this, repeat, context); + // TODO: process the repeat like agg ? + return repeat; + } + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, ExpressionRewriteContext context) { + join = visitChildren(this, join, context); + // TODO: need rewrite the mark conditions ? Review Comment: The mark conditions have already been rewritten, so this TODO is no longer needed. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
