This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 615ec034c3f [feat](case when) replace null with false for case when 
condition (#56424)
615ec034c3f is described below

commit 615ec034c3f196828368347e02e6748ec2ecdf37
Author: yujun <[email protected]>
AuthorDate: Mon Sep 29 14:49:26 2025 +0800

    [feat](case when) replace null with false for case when condition (#56424)
    
    For a CASE WHEN condition, evaluating to NULL or to FALSE has the same
    effect: the branch is not taken.
    
    In most cases, NULL cannot be folded inside a logical expression; for
    example, neither `null and a = 1` nor `null or a = 1` can be folded.
    FALSE, however, can be folded in a logical expression: `false and a = 1`
    folds to `false`, and `false or a = 1` folds to `a = 1`.
    
    So if we replace NULL with FALSE in a CASE WHEN condition, the
    expression can often be folded into a simpler form.
    
    More precisely, in a CASE WHEN / IF condition, a NULL literal can be
    replaced with FALSE when it is the root of the condition, or when all of
    its ancestors are AND / OR / CASE WHEN or IF conditions. This rewrite does
    not change whether the branch is taken.
    
    For example, for the SQL 'case when null and a > 1 then ...':

    1. this rule rewrites it to 'case when false and a > 1 then ...';
    2. constant folding then rewrites it to 'case when false then ...';
    3. finally, the branch can be removed because its condition is FALSE.
---
 .../rules/expression/ExpressionOptimization.java   |   2 +
 .../rules/expression/ExpressionRuleType.java       |   1 +
 .../rules/ReplaceNullWithFalseForCond.java         | 131 ++++++++++++++++++
 .../nereids/rules/rewrite/EliminateFilter.java     |  38 +-----
 .../trees/expressions/functions/scalar/If.java     |  15 +++
 .../apache/doris/nereids/util/ExpressionUtils.java |  16 ---
 .../expression/ExpressionRewriteTestHelper.java    |  13 +-
 .../rules/ReplaceNullWithFalseForCondTest.java     | 146 +++++++++++++++++++++
 .../nereids/rules/rewrite/EliminateFilterTest.java |  24 ----
 9 files changed, 310 insertions(+), 76 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
index 489af4b331c..a7e6f805516 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
@@ -26,6 +26,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
 import org.apache.doris.nereids.rules.expression.rules.LikeToEqualRewrite;
 import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual;
+import 
org.apache.doris.nereids.rules.expression.rules.ReplaceNullWithFalseForCond;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyConflictCompound;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate;
@@ -57,6 +58,7 @@ public class ExpressionOptimization extends ExpressionRewrite 
{
 
                     DateFunctionRewrite.INSTANCE,
                     ArrayContainToArrayOverlap.INSTANCE,
+                    ReplaceNullWithFalseForCond.INSTANCE,
                     CaseWhenToIf.INSTANCE,
                     TopnToMax.INSTANCE,
                     NullSafeEqualToEqual.INSTANCE,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
index fa29d64888c..1cb43a3113d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
@@ -44,6 +44,7 @@ public enum ExpressionRuleType {
     MEDIAN_CONVERT,
     NORMALIZE_BINARY_PREDICATES,
     NULL_SAFE_EQUAL_TO_EQUAL,
+    REPLACE_NULL_WITH_FALSE_FOR_COND,
     REPLACE_VARIABLE_BY_LITERAL,
     SIMPLIFY_ARITHMETIC_COMPARISON,
     SIMPLIFY_ARITHMETIC,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCond.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCond.java
new file mode 100644
index 00000000000..43a63a905c1
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCond.java
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.types.DataType;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Replace null literal with false literal for condition expression.
+ * Because in nereids, we use boolean type to represent three-value logic,
+ * so we need to replace null literal with false literal for condition 
expression.
+ * For example, in filter, join condition, case when predicate, etc.
+ *
+ * rule: if(null and a > 1, ...) => if(false and a > 1, ...)
+ *       case when null and a > 1 then ... => case when false and a > 1 then 
...
+ *       null or (null and a > 1) or not(null) => false or (false and a > 1) 
or not(null)
+ */
+public class ReplaceNullWithFalseForCond implements 
ExpressionPatternRuleFactory {
+
+    public static final ReplaceNullWithFalseForCond INSTANCE = new 
ReplaceNullWithFalseForCond();
+
+    @Override
+    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+        return ImmutableList.of(
+                matchesTopType(CaseWhen.class).then(this::rewrite)
+                        
.toRule(ExpressionRuleType.REPLACE_NULL_WITH_FALSE_FOR_COND),
+                matchesTopType(If.class).then(this::rewrite)
+                        
.toRule(ExpressionRuleType.REPLACE_NULL_WITH_FALSE_FOR_COND)
+        );
+    }
+
+    protected Expression rewrite(Expression expression) {
+        return replace(expression, false);
+    }
+
+    /**
+     * replace null which its ancestors are all AND/OR/CASE WHEN/IF CONDITION.
+     * NOTICE: NOT's type is boolean too, if replace null to false in NOT, 
will get NOT(NULL) = NOT(FALSE) = TRUE,
+     * but it is wrong,  NOT(NULL) = NULL. For null, only under the AND / OR, 
can rewrite it as FALSE.
+     */
+    public static Expression replace(Expression expression, boolean 
replaceCaseThen) {
+        if (!expression.containsType(NullLiteral.class)) {
+            return expression;
+        }
+        if (expression.isNullLiteral()) {
+            DataType dataType = expression.getDataType();
+            if (dataType.isBooleanType() || dataType.isNullType()) {
+                return BooleanLiteral.FALSE;
+            }
+        } else if (expression instanceof CompoundPredicate) {
+            // process AND / OR
+            ImmutableList.Builder<Expression> builder
+                    = 
ImmutableList.builderWithExpectedSize(expression.children().size());
+            for (Expression child : expression.children()) {
+                builder.add(replace(child, replaceCaseThen));
+            }
+            List<Expression> newChildren = builder.build();
+            if (!newChildren.equals(expression.children())) {
+                return expression.withChildren(builder.build());
+            }
+        } else if (expression instanceof CaseWhen) {
+            CaseWhen caseWhen = (CaseWhen) expression;
+            ImmutableList.Builder<WhenClause> whenClausesBuilder
+                    = 
ImmutableList.builderWithExpectedSize(caseWhen.getWhenClauses().size());
+            for (WhenClause whenClause : caseWhen.getWhenClauses()) {
+                Expression newOperand = replace(whenClause.getOperand(), true);
+                Expression newResult = whenClause.getResult();
+                if (replaceCaseThen) {
+                    newResult = replace(whenClause.getResult(), true);
+                }
+                whenClausesBuilder.add(new WhenClause(newOperand, newResult));
+            }
+            List<WhenClause> newWhenClauses = whenClausesBuilder.build();
+            Optional<Expression> newDefaultValueOpt = 
caseWhen.getDefaultValue();
+            if (caseWhen.getDefaultValue().isPresent() && replaceCaseThen) {
+                newDefaultValueOpt = 
Optional.of(replace(caseWhen.getDefaultValue().get(), true));
+            }
+            if (!newWhenClauses.equals(caseWhen.getWhenClauses())
+                    || !newDefaultValueOpt.equals(caseWhen.getDefaultValue())) 
{
+                return newDefaultValueOpt
+                        .map(defaultValue -> new CaseWhen(newWhenClauses, 
defaultValue))
+                        .orElseGet(() -> new CaseWhen(newWhenClauses));
+            }
+        } else if (expression instanceof If) {
+            If ifExpr = (If) expression;
+            Expression newCondition = replace(ifExpr.getCondition(), true);
+            Expression newTrueValue = ifExpr.getTrueValue();
+            Expression newFalseValue = ifExpr.getFalseValue();
+            if (replaceCaseThen) {
+                newTrueValue = replace(newTrueValue, true);
+                newFalseValue = replace(newFalseValue, true);
+            }
+            if (!newCondition.equals(ifExpr.getCondition())
+                    || !newTrueValue.equals(ifExpr.getTrueValue())
+                    || !newFalseValue.equals(ifExpr.getFalseValue())) {
+                return new If(newCondition, newTrueValue, newFalseValue);
+            }
+        }
+
+        return expression;
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java
index 6d28c7d030f..402df5195a1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java
@@ -22,7 +22,7 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
 import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
-import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import 
org.apache.doris.nereids.rules.expression.rules.ReplaceNullWithFalseForCond;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@@ -34,7 +34,6 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
@@ -55,7 +54,9 @@ public class EliminateFilter implements RewriteRuleFactory {
                     ImmutableSet.Builder<Expression> newConjuncts = 
ImmutableSet.builder();
                     ExpressionRewriteContext context = new 
ExpressionRewriteContext(ctx.cascadesContext);
                     for (Expression expression : filter.getConjuncts()) {
-                        expression = 
FoldConstantRule.evaluate(eliminateNullLiteral(expression), context);
+                        expression = FoldConstantRule.evaluate(
+                                
ReplaceNullWithFalseForCond.replace(expression, true),
+                                context);
                         if (expression == BooleanLiteral.FALSE || 
expression.isNullLiteral()) {
                             return new 
LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
                                     filter.getOutput());
@@ -87,7 +88,9 @@ public class EliminateFilter implements RewriteRuleFactory {
         ExpressionRewriteContext context = new 
ExpressionRewriteContext(cascadesContext);
         for (Expression expression : filter.getConjuncts()) {
             Expression newExpr = ExpressionUtils.replace(expression, 
replaceMap);
-            Expression foldExpression = 
FoldConstantRule.evaluate(eliminateNullLiteral(newExpr), context);
+            Expression foldExpression = FoldConstantRule.evaluate(
+                    ReplaceNullWithFalseForCond.replace(newExpr, true),
+                    context);
 
             if (foldExpression == BooleanLiteral.FALSE || 
expression.isNullLiteral()) {
                 return new LogicalEmptyRelation(
@@ -104,31 +107,4 @@ public class EliminateFilter implements RewriteRuleFactory 
{
             return new LogicalFilter<>(conjuncts, filter.child());
         }
     }
-
-    @VisibleForTesting
-    public static Expression eliminateNullLiteral(Expression expression) {
-        if (!expression.anyMatch(e -> ((Expression) e).isNullLiteral())) {
-            return expression;
-        }
-
-        return replaceNullToFalse(expression);
-    }
-
-    // only replace null which its ancestors are all and/or
-    // NOTICE: NOT's type is boolean too, if replace null to false in NOT, 
will got NOT(NULL) = NOT(FALSE) = TRUE,
-    // but it is wrong,  NOT(NULL) = NULL. For a filter, only the AND / OR, 
can keep NULL as FALSE.
-    private static Expression replaceNullToFalse(Expression expression) {
-        if (expression.isNullLiteral()) {
-            return BooleanLiteral.FALSE;
-        }
-
-        if (expression instanceof CompoundPredicate) {
-            ImmutableList.Builder<Expression> builder = 
ImmutableList.builderWithExpectedSize(
-                    expression.children().size());
-            expression.children().forEach(e -> 
builder.add(replaceNullToFalse(e)));
-            return expression.withChildren(builder.build());
-        }
-
-        return expression;
-    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java
index cf65aa68a06..34be9567dea 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/If.java
@@ -97,4 +97,19 @@ public class If extends ScalarFunction
             return null;
         }
     }
+
+    /** get condition */
+    public Expression getCondition() {
+        return child(0);
+    }
+
+    /** get true value */
+    public Expression getTrueValue() {
+        return child(1);
+    }
+
+    /** get false value */
+    public Expression getFalseValue() {
+        return child(2);
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index cd04627ee30..e6d130421ad 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -574,22 +574,6 @@ public class ExpressionUtils {
         return result.build();
     }
 
-    private static class ExpressionReplacer
-            extends DefaultExpressionRewriter<Map<? extends Expression, ? 
extends Expression>> {
-        public static final ExpressionReplacer INSTANCE = new 
ExpressionReplacer();
-
-        private ExpressionReplacer() {
-        }
-
-        @Override
-        public Expression visit(Expression expr, Map<? extends Expression, ? 
extends Expression> replaceMap) {
-            if (replaceMap.containsKey(expr)) {
-                return replaceMap.get(expr);
-            }
-            return super.visit(expr, replaceMap);
-        }
-    }
-
     /**
      * merge arguments into an expression array
      *
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java
index 05946071a74..8912e911bfd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java
@@ -49,6 +49,7 @@ import org.junit.jupiter.api.Assertions;
 
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 
 public abstract class ExpressionRewriteTestHelper extends ExpressionRewrite {
     protected static final NereidsParser PARSER = new NereidsParser();
@@ -62,7 +63,7 @@ public abstract class ExpressionRewriteTestHelper extends 
ExpressionRewrite {
         context = new ExpressionRewriteContext(cascadesContext);
     }
 
-    protected final void assertRewrite(String expression, String expected) {
+    protected void assertRewrite(String expression, String expected) {
         Map<String, Slot> mem = Maps.newHashMap();
         Expression needRewriteExpression = 
replaceUnboundSlot(PARSER.parseExpression(expression), mem);
         Expression expectedExpression = 
replaceUnboundSlot(PARSER.parseExpression(expected), mem);
@@ -90,12 +91,14 @@ public abstract class ExpressionRewriteTestHelper extends 
ExpressionRewrite {
     }
 
     protected void assertRewriteAfterTypeCoercion(String expression, String 
expected) {
+        assertRewriteAfterConvert(expression, expected, 
ExpressionRewriteTestHelper::typeCoercion);
+    }
+
+    protected void assertRewriteAfterConvert(String expression, String 
expected, Function<Expression, Expression> converter) {
         Map<String, Slot> mem = Maps.newHashMap();
-        Expression needRewriteExpression = PARSER.parseExpression(expression);
-        needRewriteExpression = 
typeCoercion(replaceUnboundSlot(needRewriteExpression, mem));
-        Expression expectedExpression = PARSER.parseExpression(expected);
+        Expression needRewriteExpression = 
converter.apply(replaceUnboundSlot(PARSER.parseExpression(expression), mem));
         Expression rewrittenExpression = 
executor.rewrite(needRewriteExpression, context);
-        expectedExpression = 
typeCoercion(replaceUnboundSlot(expectedExpression, mem));
+        Expression expectedExpression = 
converter.apply(replaceUnboundSlot(PARSER.parseExpression(expected), mem));
         Assertions.assertEquals(expectedExpression.toSql(), 
rewrittenExpression.toSql());
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java
new file mode 100644
index 00000000000..14953b21cfa
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/ReplaceNullWithFalseForCondTest.java
@@ -0,0 +1,146 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.analyzer.Scope;
+import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.And;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+import java.util.function.Function;
+
+class ReplaceNullWithFalseForCondTest extends ExpressionRewriteTestHelper {
+
+    private final ReplaceNullWithFalseForCond replaceCaseThenInstance = new 
ReplaceNullWithFalseForCond() {
+        @Override
+        protected Expression rewrite(Expression expression) {
+            return replace(expression, true);
+        }
+    };
+
+    @Test
+    void testCaseWhen() {
+        executor = new ExpressionRuleExecutor(ImmutableList.of(
+                bottomUp(ReplaceNullWithFalseForCond.INSTANCE)
+        ));
+
+        String sql = "case when null then null"
+                + " when null and a = 1 and not(null) or "
+                + " (case when a = 2 and null then null "
+                + "       when null then not(null) "
+                + "       else null or a=3"
+                + "  end) "
+                + " then (case when null then null else null end) "
+                + " else null end";
+
+        String expectedSql = "case when false then null"
+                + " when false and a = 1 and not(null) or "
+                + " (case when a = 2 and false then false "
+                + "       when false then not(null) "
+                + "       else false or a=3"
+                + "  end) "
+                + " then (case when false then null else null end) "
+                + " else null end";
+
+        assertRewrite(sql, expectedSql);
+
+        executor = new ExpressionRuleExecutor(ImmutableList.of(
+                bottomUp(replaceCaseThenInstance)
+        ));
+
+        expectedSql = "case when false then false"
+                + " when false and a = 1 and not(null) or "
+                + " (case when a = 2 and false then false "
+                + "       when false then not(null) "
+                + "       else false or a=3"
+                + "  end) "
+                + " then (case when false then false else false end) "
+                + " else false end";
+
+        assertRewrite(sql, expectedSql);
+    }
+
+    @Test
+    void testIf() {
+        executor = new ExpressionRuleExecutor(ImmutableList.of(
+                bottomUp(ReplaceNullWithFalseForCond.INSTANCE)
+        ));
+
+        String sql = "if("
+                + " null and not(null) and if(null and not(null), null, null),"
+                + " null and not(null),"
+                + " if(a = 1 and null, null, null)"
+                + ")";
+
+        String expectedSql = "if("
+                + " false and not(null) and if(false and not(null), false, 
false),"
+                + " null and not(null),"
+                + " if(a = 1 and false, null, null)"
+                + ")";
+
+        assertRewrite(sql, expectedSql);
+
+        executor = new ExpressionRuleExecutor(ImmutableList.of(
+                bottomUp(replaceCaseThenInstance, SimplifyCastRule.INSTANCE)
+        ));
+
+        expectedSql = "if("
+                + " false and not(null) and if(false and not(null), false, 
false),"
+                + " false and not(null),"
+                + " if(a = 1 and false, false, false)"
+                + ")";
+
+        assertRewrite(sql, expectedSql);
+    }
+
+    @Override
+    protected void assertRewrite(String sql, String expectedSql) {
+        Function<Expression, Expression> converter = expr -> new 
ExpressionAnalyzer(
+                null, new Scope(ImmutableList.of()), null, false, false
+        ) {
+            // ExpressionAnalyzer will rewrite 'false and xxx' to 'false', but 
we want to keep the structure of the expression,
+            @Override
+            public Expression visitAnd(And and, ExpressionRewriteContext 
context) {
+                return new And(
+                        ExpressionUtils.extractConjunction(and)
+                                .stream()
+                                .map(e -> e.accept(this, context))
+                                .collect(ImmutableList.toImmutableList()));
+            }
+
+            @Override
+            public Expression visitOr(Or or, ExpressionRewriteContext context) 
{
+                return new Or(
+                        ExpressionUtils.extractDisjunction(or)
+                                .stream()
+                                .map(e -> e.accept(this, context))
+                                .collect(ImmutableList.toImmutableList()));
+            }
+        }.analyze(expr, null);
+
+        assertRewriteAfterConvert(sql, expectedSql, converter);
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java
index 692f6532541..a295b42d6ca 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateFilterTest.java
@@ -24,13 +24,11 @@ import 
org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Or;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
@@ -38,7 +36,6 @@ import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
@@ -186,25 +183,4 @@ class EliminateFilterTest implements 
MemoPatternMatchSupported {
                         logicalFilter(logicalOlapScan()).when(f -> 
f.getPredicate() instanceof GreaterThan)
                 );
     }
-
-    @Test
-    void testEliminateNullLiteral() {
-        Expression a = new SlotReference("a", IntegerType.INSTANCE);
-        Expression b = new SlotReference("b", IntegerType.INSTANCE);
-        Expression one = Literal.of(1);
-        Expression two = Literal.of(2);
-        Expression expression = new And(Arrays.asList(
-               new And(new GreaterThan(a, one), new 
NullLiteral(IntegerType.INSTANCE)),
-               new Or(Arrays.asList(new GreaterThan(b, two), new 
NullLiteral(IntegerType.INSTANCE),
-                       new EqualTo(a, new NullLiteral(IntegerType.INSTANCE)))),
-               new Not(new And(new GreaterThan(a, one), new 
NullLiteral(IntegerType.INSTANCE)))
-        ));
-        Expression expectExpression = new And(Arrays.asList(
-                new And(new GreaterThan(a, one), BooleanLiteral.FALSE),
-                new Or(Arrays.asList(new GreaterThan(b, two), 
BooleanLiteral.FALSE,
-                        new EqualTo(a, new 
NullLiteral(IntegerType.INSTANCE)))),
-                new Not(new And(new GreaterThan(a, one), new 
NullLiteral(IntegerType.INSTANCE)))
-        ));
-        Assertions.assertEquals(expectExpression, new 
EliminateFilter().eliminateNullLiteral(expression));
-    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to