morrySnow commented on code in PR #56469: URL: https://github.com/apache/doris/pull/56469#discussion_r2385050915
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java: ########## @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.CaseWhen; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; + +import java.util.List; +import 
java.util.Map; +import java.util.Set; + +/** + * For nested CaseWhen/IF expression, replace the inner CaseWhen/IF condition with TRUE/FALSE literal + * when the condition also exists in the outer CaseWhen/IF conditions. + * + * on the nested CASE/IF path, a condition may exist in multiple CASE/IF branches, + * for any inner case when or if condition, its boolean value is determined by the outermost CASE/IF branch, + * that is the first occurrence of the condition on the nested CASE/IF path. + * + * <br> + * 1. if it exists in outer case's current branch condition, replace it with TRUE + * e.g. + * case when A then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with TRUE: + * case when A then + * (case when TRUE then 1 else 2 end) + * ... + * end + * <br> + * 2. if it exists in outer case's previous branch condition, replace it with FALSE + * e.g. + * case when A then ... + * when B then + * (case when A then 1 else 2 end) + * ... + * end + * then inner case condition A will replace with FALSE: + * case when A then ... + * when B then + * (case when FALSE then 1 else 2 end) + * ... + * end + * <br> + */ +public class NestedCaseWhenCondToLiteral implements ExpressionPatternRuleFactory { + + public static final NestedCaseWhenCondToLiteral INSTANCE = new NestedCaseWhenCondToLiteral(); + + @Override + public List<ExpressionPatternMatcher<? 
extends Expression>> buildRules() { + return ImmutableList.of( + root(Expression.class) + .when(this::needRewrite) + .thenApply(ctx -> rewrite(ctx.expr, ctx.rewriteContext)) + .toRule(ExpressionRuleType.NESTED_CASE_WHEN_COND_TO_LITERAL) + ); + } + + private boolean needRewrite(Expression expression) { + return expression.containsType(CaseWhen.class, If.class); + } + + private Expression rewrite(Expression expression, ExpressionRewriteContext context) { + return expression.accept(new NestedCondReplacer(), context); + } + + private static class NestedCondReplacer extends DefaultExpressionRewriter<ExpressionRewriteContext> { + + // condition literals is used to record the boolean literal for a condition expression, + // 1. if a condition, if it exists in outer case/if conditions, it will be replaced with the literal. + // 2. otherwise it's the first time occur, then: + // a) when enter a case/if branch, set this condition to TRUE literal + // b) when leave a case/if branch, set this condition to FALSE literal + // c) when leave the whole case/if statement, remove this condition literal + private final Map<Expression, BooleanLiteral> conditionLiterals = Maps.newHashMap(); Review Comment: I think you should add a unit test to check that this map works as you expect under all conditions (entering a branch, leaving a branch, and leaving the whole CASE/IF statement). ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java: ########## @@ -565,57 +566,68 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { CaseWhen originCaseWhen = caseWhen; caseWhen = rewriteChildren(caseWhen, context); - Expression newDefault = null; - boolean foundNewDefault = false; - - List<WhenClause> whenClauses = new ArrayList<>(); + final Expression oldDefault = caseWhen.getDefaultValue().orElse(null); + Expression newDefault = oldDefault; + ImmutableList.Builder<WhenClause> whenClausesBuilder + = 
ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size()); + Set<Expression> uniqueOperands = Sets.newHashSet(); for (WhenClause whenClause : caseWhen.getWhenClauses()) { Expression whenOperand = whenClause.getOperand(); - - if (!(whenOperand.isLiteral())) { - whenClauses.add(new WhenClause(whenOperand, whenClause.getResult())); + if (!whenOperand.isLiteral() && uniqueOperands.add(whenOperand)) { + whenClausesBuilder.add(new WhenClause(whenOperand, whenClause.getResult())); } else if (BooleanLiteral.TRUE.equals(whenOperand)) { - foundNewDefault = true; newDefault = whenClause.getResult(); break; } } - - Expression defaultResult = null; - if (caseWhen.getDefaultValue().isPresent()) { - defaultResult = caseWhen.getDefaultValue().get(); - } - if (foundNewDefault) { - defaultResult = newDefault; + List<WhenClause> newWhenClauses = whenClausesBuilder.build(); + Expression realTypeCoercionDefault = newDefault == null + ? new NullLiteral(caseWhen.getDataType()) + : TypeCoercionUtils.ensureSameResultType(originCaseWhen, newDefault, context); + boolean allThenEqualsDefault = true; + for (WhenClause whenClause : newWhenClauses) { + Expression typeCoercionThen = TypeCoercionUtils.ensureSameResultType( + originCaseWhen, whenClause.getResult(), context); Review Comment: Why do we need to call `ensureSameResultType` here? After analysis, every WHEN clause's result and the default value should already have the same type. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
