yujun777 commented on code in PR #56469:
URL: https://github.com/apache/doris/pull/56469#discussion_r2390206540
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java:
##########
@@ -565,57 +566,68 @@ public Expression visitBinaryArithmetic(BinaryArithmetic
binaryArithmetic, Expre
public Expression visitCaseWhen(CaseWhen caseWhen,
ExpressionRewriteContext context) {
CaseWhen originCaseWhen = caseWhen;
caseWhen = rewriteChildren(caseWhen, context);
- Expression newDefault = null;
- boolean foundNewDefault = false;
-
- List<WhenClause> whenClauses = new ArrayList<>();
+ final Expression oldDefault = caseWhen.getDefaultValue().orElse(null);
+ Expression newDefault = oldDefault;
+ ImmutableList.Builder<WhenClause> whenClausesBuilder
+ =
ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size());
+ Set<Expression> uniqueOperands = Sets.newHashSet();
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
Expression whenOperand = whenClause.getOperand();
-
- if (!(whenOperand.isLiteral())) {
- whenClauses.add(new WhenClause(whenOperand,
whenClause.getResult()));
+ if (!whenOperand.isLiteral() && uniqueOperands.add(whenOperand)) {
+ whenClausesBuilder.add(new WhenClause(whenOperand,
whenClause.getResult()));
} else if (BooleanLiteral.TRUE.equals(whenOperand)) {
- foundNewDefault = true;
newDefault = whenClause.getResult();
break;
}
}
-
- Expression defaultResult = null;
- if (caseWhen.getDefaultValue().isPresent()) {
- defaultResult = caseWhen.getDefaultValue().get();
- }
- if (foundNewDefault) {
- defaultResult = newDefault;
+ List<WhenClause> newWhenClauses = whenClausesBuilder.build();
+ Expression realTypeCoercionDefault = newDefault == null
+ ? new NullLiteral(caseWhen.getDataType())
+ : TypeCoercionUtils.ensureSameResultType(originCaseWhen,
newDefault, context);
+ boolean allThenEqualsDefault = true;
+ for (WhenClause whenClause : newWhenClauses) {
+ Expression typeCoercionThen =
TypeCoercionUtils.ensureSameResultType(
+ originCaseWhen, whenClause.getResult(), context);
Review Comment:
update
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NestedCaseWhenCondToLiteral.java:
##########
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * For nested CaseWhen/IF expression, replace the inner CaseWhen/IF condition
with TRUE/FALSE literal
+ * when the condition also exists in the outer CaseWhen/IF conditions.
+ *
+ * On a nested CASE/IF path, a condition may appear in multiple CASE/IF branches.
+ * For any inner CASE WHEN or IF condition, its boolean value is determined by
+ * the outermost CASE/IF branch containing it, that is, the first occurrence
+ * of the condition on the nested CASE/IF path.
+ *
+ * <br>
+ * 1. if it exists in outer case's current branch condition, replace it with
TRUE
+ * e.g.
+ * case when A then
+ * (case when A then 1 else 2 end)
+ * ...
+ * end
+ * then the inner case condition A will be replaced with TRUE:
+ * case when A then
+ * (case when TRUE then 1 else 2 end)
+ * ...
+ * end
+ * <br>
+ * 2. if it exists in outer case's previous branch condition, replace it with
FALSE
+ * e.g.
+ * case when A then ...
+ * when B then
+ * (case when A then 1 else 2 end)
+ * ...
+ * end
+ * then the inner case condition A will be replaced with FALSE:
+ * case when A then ...
+ * when B then
+ * (case when FALSE then 1 else 2 end)
+ * ...
+ * end
+ * <br>
+ */
+public class NestedCaseWhenCondToLiteral implements
ExpressionPatternRuleFactory {
+
+ public static final NestedCaseWhenCondToLiteral INSTANCE = new
NestedCaseWhenCondToLiteral();
+
+ @Override
+ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+ return ImmutableList.of(
+ root(Expression.class)
+ .when(this::needRewrite)
+ .thenApply(ctx -> rewrite(ctx.expr,
ctx.rewriteContext))
+
.toRule(ExpressionRuleType.NESTED_CASE_WHEN_COND_TO_LITERAL)
+ );
+ }
+
+ private boolean needRewrite(Expression expression) {
+ return expression.containsType(CaseWhen.class, If.class);
+ }
+
+ private Expression rewrite(Expression expression, ExpressionRewriteContext
context) {
+ return expression.accept(new NestedCondReplacer(), context);
+ }
+
+ private static class NestedCondReplacer extends
DefaultExpressionRewriter<ExpressionRewriteContext> {
+
+ // conditionLiterals records the boolean literal for each condition expression:
+ // 1. if a condition already exists in the outer case/if conditions,
+ //    it will be replaced with the recorded literal.
+ // 2. otherwise this is its first occurrence, then:
+ // a) when entering a case/if branch, set this condition to the TRUE literal
+ // b) when leaving a case/if branch, set this condition to the FALSE literal
+ // c) when leaving the whole case/if statement, remove this condition's literal
+ private final Map<Expression, BooleanLiteral> conditionLiterals =
Maps.newHashMap();
Review Comment:
add ut
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]