This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 0f81ecf41528e2f2bb39dba5956302057463fff2
Author: 谢健 <[email protected]>
AuthorDate: Thu Jan 25 14:01:34 2024 +0800

    [feat](Nereids): eliminate inner join by pk fk when comparing mv (#30258)
---
 .../doris/nereids/jobs/executor/Rewriter.java      |   2 +-
 .../joinorder/hypergraph/node/StructInfoNode.java  |  23 --
 .../rules/exploration/mv/HyperGraphComparator.java |  53 +++-
 .../nereids/rules/rewrite/EliminateJoinByFK.java   | 319 +++++----------------
 .../rules/rewrite/EliminateJoinByUnique.java       |   1 -
 .../nereids/rules/rewrite/ForeignKeyContext.java   | 184 ++++++++++++
 .../org/apache/doris/nereids/util/JoinUtils.java   |  42 ++-
 .../rules/exploration/mv/EliminateJoinTest.java    | 103 +++++++
 8 files changed, 443 insertions(+), 284 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 4304933f752..4b488b6cfee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -303,7 +303,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
 
             // this rule should invoke after infer predicate and push down distinct, and before push down limit
             topic("eliminate join according unique or foreign key",
-                custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),
+                bottomUp(new EliminateJoinByFK()),
                 topDown(new EliminateJoinByUnique())
             ),
 
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java
index ba31ff823a6..93e13e59da5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/node/StructInfoNode.java
@@ -19,7 +19,6 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph.node;
 
 import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
-import org.apache.doris.nereids.properties.FunctionalDependencies;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.Plan;
@@ -29,11 +28,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
-import org.apache.doris.nereids.util.JoinUtils;
 import org.apache.doris.nereids.util.Utils;
 
-import com.google.common.base.Supplier;
-import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableSet;
@@ -51,36 +47,17 @@ import javax.annotation.Nullable;
 public class StructInfoNode extends AbstractNode {
     private final List<Set<Expression>> expressions;
     private final Set<CatalogRelation> relationSet;
-    private final Supplier<Boolean> eliminateSupplier;
 
     public StructInfoNode(int index, Plan plan, List<Edge> edges) {
         super(extractPlan(plan), index, edges);
         relationSet = plan.collect(CatalogRelation.class::isInstance);
         expressions = collectExpressions(plan);
-        eliminateSupplier = Suppliers.memoize(this::computeElimination);
     }
 
     public StructInfoNode(int index, Plan plan) {
         this(index, plan, new ArrayList<>());
     }
 
-    private boolean computeElimination() {
-        if (getJoinEdges().isEmpty()) {
-            return false;
-        }
-        return getJoinEdges().stream().allMatch(e -> {
-            if (e.getRightExtendedNodes() == getNodeMap()) {
-                return JoinUtils.canEliminateByLeft(e.getJoin(), FunctionalDependencies.EMPTY_FUNC_DEPS,
-                        plan.getLogicalProperties().getFunctionalDependencies());
-            }
-            return false;
-        });
-    }
-
-    public boolean canEliminate() {
-        return eliminateSupplier.get();
-    }
-
     private @Nullable List<Set<Expression>> collectExpressions(Plan plan) {
 
         Pair<Boolean, Builder<Set<Expression>>> collector = Pair.of(true, ImmutableList.builder());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
index 1e65ad74913..3339d009c79 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/HyperGraphComparator.java
@@ -29,7 +29,9 @@ import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughJoin;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.JoinUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -143,10 +145,67 @@ public class HyperGraphComparator {
     }
 
+    /**
+     * Check whether a view join edge that references eliminated nodes still permits their elimination.
+     * A left outer join may drop its single eliminated right node if JoinUtils.canEliminateByLeft allows it;
+     * a simple inner join may drop its single eliminated side if JoinUtils.canEliminateByFk accepts it.
+     * Every other join type or elimination shape is refused.
+     */
+    private boolean canEliminateByJoinEdge(JoinEdge joinEdge) {
+        long eliminatedLeft =
+                LongBitmap.newBitmapIntersect(joinEdge.getLeftExtendedNodes(), eliminateViewNodesMap);
+        long eliminatedRight =
+                LongBitmap.newBitmapIntersect(joinEdge.getRightExtendedNodes(), eliminateViewNodesMap);
+        // eliminate by unique
+        if (joinEdge.getJoinType().isLeftOuterJoin()) {
+            if (LongBitmap.getCardinality(eliminatedRight) != 1) {
+                return false;
+            }
+            // fetch the eliminated node itself, not just the lowest node of the right side
+            Plan rightPlan = viewHyperGraph.getNode(LongBitmap.lowestOneIndex(eliminatedRight)).getPlan();
+            return JoinUtils.canEliminateByLeft(joinEdge.getJoin(),
+                    rightPlan.getLogicalProperties().getFunctionalDependencies());
+        }
+        // eliminate by pk fk
+        if (joinEdge.getJoinType().isInnerJoin() && joinEdge.isSimple()) {
+            if (LongBitmap.getCardinality(eliminatedLeft) == 0
+                    && LongBitmap.getCardinality(eliminatedRight) == 1) {
+                Plan foreign = viewHyperGraph
+                        .getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
+                Plan primary = viewHyperGraph
+                        .getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
+                return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
+            }
+            if (LongBitmap.getCardinality(eliminatedLeft) == 1
+                    && LongBitmap.getCardinality(eliminatedRight) == 0) {
+                Plan foreign = viewHyperGraph
+                        .getNode(LongBitmap.lowestOneIndex(joinEdge.getRightExtendedNodes())).getPlan();
+                Plan primary = viewHyperGraph
+                        .getNode(LongBitmap.lowestOneIndex(joinEdge.getLeftExtendedNodes())).getPlan();
+                return JoinUtils.canEliminateByFk(joinEdge.getJoin(), primary, foreign);
+            }
+        }
+        // other join types, non-simple inner joins and two-sided eliminations are not supported yet
+        return false;
+    }
+
     private boolean tryEliminateNodesAndEdge() {
-        for (int i : LongBitmap.getIterator(eliminateViewNodesMap)) {
-            if (!((StructInfoNode) viewHyperGraph.getNode(i)).canEliminate()) {
+        boolean hasFilterEdgeAbove = viewHyperGraph.getFilterEdges().stream()
+                .filter(e -> LongBitmap.getCardinality(e.getReferenceNodes()) == 1)
+                .anyMatch(e -> LongBitmap.isSubset(e.getReferenceNodes(), eliminateViewNodesMap));
+        if (hasFilterEdgeAbove) {
+            // If there is some filter edge above the eliminated node, we should rebuild a plan
+            // Right now, just refuse it.
+            return false;
+        }
+        // Every join edge touching an eliminated node must allow the elimination,
+        // so keep scanning instead of deciding on the first overlapping edge.
+        for (JoinEdge joinEdge : viewHyperGraph.getJoinEdges()) {
+            if (!LongBitmap.isOverlap(joinEdge.getReferenceNodes(), eliminateViewNodesMap)) {
+                continue;
+            }
+            if (!canEliminateByJoinEdge(joinEdge)) {
                 return false;
             }
         }
         return true;
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
index b4a6eac207b..c8eba2e4f31 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
@@ -17,297 +17,121 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
-import org.apache.doris.catalog.Column;
-import org.apache.doris.catalog.TableIf;
-import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
 import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
-import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
-import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
 import org.apache.doris.nereids.util.ImmutableEqualSet;
+import org.apache.doris.nereids.util.JoinUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
-import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
-import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 
 /**
  * Eliminate join by foreign.
  */
-public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
-    @Override
-    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
-        EliminateJoinByFKHelper helper = new EliminateJoinByFKHelper();
-        return helper.rewriteRoot(plan, jobContext);
-    }
-
-    private static class EliminateJoinByFKHelper
-            extends DefaultPlanRewriter<ForeignKeyContext> implements CustomRewriter {
-
-        @Override
-        public Plan rewriteRoot(Plan plan, JobContext jobContext) {
-            return plan.accept(this, new ForeignKeyContext());
-        }
-
-        @Override
-        public Plan visit(Plan plan, ForeignKeyContext context) {
-            Plan newPlan = visitChildren(this, plan, context);
-            // always expire primary key except filter, project and join.
-            // always keep foreign key alive
-            context.expirePrimaryKey(plan);
-            return newPlan;
-        }
-
-        @Override
-        public Plan visitLogicalRelation(LogicalRelation relation, ForeignKeyContext context) {
-            if (!(relation instanceof LogicalCatalogRelation)) {
-                return relation;
-            }
-            context.putAllForeignKeys(((LogicalCatalogRelation) relation).getTable());
-            relation.getOutput().stream()
-                    .filter(SlotReference.class::isInstance)
-                    .map(SlotReference.class::cast)
-                    .forEach(context::putSlot);
-            return relation;
-        }
-
-        private boolean canEliminate(LogicalJoin<?, ?> join, Map<Slot, Slot> primaryToForeign,
-                ForeignKeyContext context) {
-            if (!join.getOtherJoinConjuncts().isEmpty()) {
-                return false;
-            }
-            if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isSemiJoin()) {
-                return false;
-            }
-            return context.satisfyConstraint(primaryToForeign, join);
-        }
+public class EliminateJoinByFK extends OneRewriteRuleFactory {
 
-        private @Nullable Map<Expression, Expression> tryMapOutputToForeignPlan(Plan foreignPlan,
-                Set<Slot> output, Map<Slot, Slot> primaryToForeign) {
-            Set<Slot> residualPrimary = Sets.difference(output, foreignPlan.getOutputSet());
-            ImmutableMap.Builder<Expression, Expression> builder = new ImmutableMap.Builder<>();
-            for (Slot slot : residualPrimary) {
-                if (primaryToForeign.containsKey(slot)) {
-                    builder.put(slot, primaryToForeign.get(slot));
-                } else {
-                    return null;
-                }
-            }
-            return builder.build();
-        }
-
-        private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots) {
-            Set<Expression> predicates = childSlots.stream()
-                    .filter(ExpressionTrait::nullable)
-                    .map(s -> new Not(new IsNull(s)))
-                    .collect(ImmutableSet.toImmutableSet());
-            if (predicates.isEmpty()) {
-                return child;
-            }
-            return new LogicalFilter<>(predicates, child);
-        }
-
-        private Plan tryEliminatePrimaryPlan(LogicalProject<LogicalJoin<?, ?>> project,
-                Plan foreignPlan, Set<Slot> foreignKeys,
-                Map<Slot, Slot> primaryToForeign, ForeignKeyContext context) {
-            Set<Slot> output = project.getInputSlots();
-            Map<Expression, Expression> outputToForeign =
-                    tryMapOutputToForeignPlan(foreignPlan, output, primaryToForeign);
-            if (outputToForeign != null && canEliminate(project.child(), primaryToForeign, context)) {
-                List<NamedExpression> newProjects = project.getProjects().stream()
-                        .map(e -> outputToForeign.containsKey(e)
-                                ? new Alias(e.getExprId(), outputToForeign.get(e), e.toSql())
-                                : (NamedExpression) e.rewriteUp(s -> outputToForeign.getOrDefault(s, s)))
-                        .collect(ImmutableList.toImmutableList());
-                return project.withProjects(newProjects)
-                        .withChildren(applyNullCompensationFilter(foreignPlan, foreignKeys));
-            }
-            return project;
-        }
-
-        private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet,
-                Set<Slot> foreignKeys) {
-            ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
-            for (Slot foreignSlot : foreignKeys) {
-                Set<Slot> primarySlots = equivalenceSet.calEqualSet(foreignSlot);
-                if (primarySlots.size() != 1) {
-                    return null;
-                }
-                builder.put(primarySlots.iterator().next(), foreignSlot);
-            }
-            return builder.build();
-        }
-
-        // Right now we only support eliminate inner join, which should meet the following condition:
-        // 1. only contain null-reject equal condition, and which all meet fk-pk constraint
-        // 2. only output foreign table output or can be converted to foreign table output
-        // 4. if foreign key is null, add a isNotNull predicate for null-reject join condition
-        private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) {
-            LogicalJoin<?, ?> join = project.child();
+    // Eliminate the primary-key side of an inner join by a pk-fk constraint. Right now the join should meet:
+    // 1. it only contains null-rejecting equal conditions, all of which satisfy the fk-pk constraint
+    // 2. the project only uses the foreign side's output, or slots that can be replaced by it
+    // 3. nullable foreign keys get an IS NOT NULL filter to keep the null-rejecting join semantics
+    @Override
+    public Rule build() {
+        return logicalProject(
+                logicalJoin().when(join -> join.getJoinType().isInnerJoin())
+        ).then(project -> {
+            LogicalJoin<Plan, Plan> join = project.child();
             ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
-            Set<Slot> leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet());
-            Set<Slot> rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet());
-            if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) {
-                Map<Slot, Slot> primaryToForeignSlot = mapPrimaryToForeign(equalSet, leftSlots);
-                if (primaryToForeignSlot != null) {
-                    return tryEliminatePrimaryPlan(project, join.left(), leftSlots, primaryToForeignSlot, context);
-                }
-            } else if (context.isForeignKey(rightSlots) && context.isPrimaryKey(leftSlots)) {
-                Map<Slot, Slot> primaryToForeignSlot = mapPrimaryToForeign(equalSet, rightSlots);
-                if (primaryToForeignSlot != null) {
-                    return tryEliminatePrimaryPlan(project, join.right(), rightSlots, primaryToForeignSlot, context);
-                }
-            }
-            return project;
-        }
-
-        @Override
-        public Plan visitLogicalProject(LogicalProject<?> project, ForeignKeyContext context) {
-            project = visitChildren(this, project, context);
-            for (NamedExpression expression : project.getProjects()) {
-                if (expression instanceof Alias && expression.child(0) instanceof Slot) {
-                    context.putAlias(expression.toSlot(), (Slot) expression.child(0));
-                }
+            Set<Slot> residualSlot = Sets.difference(project.getInputSlots(), equalSet.getAllItemSet());
+            Plan res = null;
+            if (join.left().getOutputSet().containsAll(residualSlot)) {
+                res = tryEliminatePrimary(project, equalSet, join.right(), join.left());
             }
-            if (project.child() instanceof LogicalJoin<?, ?>) {
-                return eliminateJoin((LogicalProject<LogicalJoin<?, ?>>) project, context);
+            if (res == null && join.right().getOutputSet().containsAll(residualSlot)) {
+                res = tryEliminatePrimary(project, equalSet, join.left(), join.right());
             }
-            return project;
-        }
-
-        @Override
-        public Plan visitLogicalJoin(LogicalJoin<?, ?> join, ForeignKeyContext context) {
-            Plan plan = visitChildren(this, join, context);
-            context.addJoin(join);
-            return plan;
-        }
-
-        @Override
-        public Plan visitLogicalFilter(LogicalFilter<?> filter, ForeignKeyContext context) {
-            Plan plan = visitChildren(this, filter, context);
-            context.addFilter(filter);
-            return plan;
-        }
+            return res;
+        }).toRule(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY);
     }
 
-    private static class ForeignKeyContext {
-        Set<Map<Column, Column>> constraints = new HashSet<>();
-        Set<Column> foreignKeys = new HashSet<>();
-        Set<Column> primaryKeys = new HashSet<>();
-        Map<Slot, Column> slotToColumn = new HashMap<>();
-        Map<Slot, Set<LogicalJoin<?, ?>>> slotWithJoin = new HashMap<>();
-        Map<Slot, Set<Expression>> slotWithPredicates = new HashMap<>();
-
-        public void putAllForeignKeys(TableIf table) {
-            table.getForeignKeyConstraints().forEach(c -> {
-                Map<Column, Column> constraint = c.getForeignToPrimary(table);
-                constraints.add(c.getForeignToPrimary(table));
-                foreignKeys.addAll(constraint.keySet());
-                primaryKeys.addAll(constraint.values());
-            });
-        }
-
-        public boolean isForeignKey(Set<Slot> key) {
-            return foreignKeys.containsAll(
-                    key.stream().map(s -> slotToColumn.get(s)).collect(Collectors.toSet()));
-        }
-
-        public boolean isPrimaryKey(Set<Slot> key) {
-            return primaryKeys.containsAll(
-                    key.stream().map(s -> slotToColumn.get(s)).collect(Collectors.toSet()));
-        }
-
-        public void putSlot(SlotReference slot) {
-            if (!slot.getColumn().isPresent()) {
-                return;
-            }
-            Column c = slot.getColumn().get();
-            slotToColumn.put(slot, c);
-        }
-
-        public void putAlias(Slot newSlot, Slot originSlot) {
-            if (slotToColumn.containsKey(originSlot)) {
-                slotToColumn.put(newSlot, slotToColumn.get(originSlot));
-            }
-        }
-
-        public void addFilter(LogicalFilter<?> filter) {
-            filter.getOutput().stream()
-                    .filter(slotToColumn::containsKey)
-                    .forEach(slot -> {
-                        slotWithPredicates.computeIfAbsent(slot, v -> new HashSet<>());
-                        slotWithPredicates.get(slot).addAll(filter.getConjuncts());
-                    });
-        }
-
-        public void addJoin(LogicalJoin<?, ?> join) {
-            join.getOutput().stream()
-                    .filter(slotToColumn::containsKey)
-                    .forEach(slot ->
-                            slotWithJoin.computeIfAbsent(slot, v -> Sets.newHashSet((join))));
-        }
-
-        public void expirePrimaryKey(Plan plan) {
-            plan.getOutput().stream()
-                    .filter(slotToColumn::containsKey)
-                    .map(s -> slotToColumn.get(s))
-                    .forEach(primaryKeys::remove);
-        }
+    /**
+     * Try to drop {@code primary} from the join under {@code project}, keeping only {@code foreign}.
+     *
+     * @return the rewritten project over {@code foreign}, with IS NOT NULL filters on its nullable join
+     *         slots; or null when the pk-fk constraint does not hold or some slot used by {@code project}
+     *         has no equal slot in {@code foreign}, so that the caller can still try the opposite side
+     */
+    private @Nullable Plan tryEliminatePrimary(LogicalProject<LogicalJoin<Plan, Plan>> project,
+            ImmutableEqualSet<Slot> equalSet, Plan primary, Plan foreign) {
+        if (!JoinUtils.canEliminateByFk(project.child(), primary, foreign)) {
+            return null;
+        }
+        Set<Slot> output = project.getInputSlots();
+        Set<Slot> foreignKeys = Sets.intersection(foreign.getOutputSet(), equalSet.getAllItemSet());
+        Map<Expression, Expression> outputToForeign =
+                tryMapOutputToForeignPlan(foreign, output, equalSet);
+        if (outputToForeign != null) {
+            List<NamedExpression> newProjects = project.getProjects().stream()
+                    .map(e -> outputToForeign.containsKey(e)
+                            ? new Alias(e.getExprId(), outputToForeign.get(e), e.toSql())
+                            : (NamedExpression) e.rewriteUp(s -> outputToForeign.getOrDefault(s, s)))
+                    .collect(ImmutableList.toImmutableList());
+            return project.withProjects(newProjects)
+                    .withChildren(applyNullCompensationFilter(foreign, foreignKeys));
+        }
+        return null;
+    }
 
-        public boolean satisfyConstraint(Map<Slot, Slot> primaryToForeign, LogicalJoin<?, ?> join) {
-            Map<Column, Column> foreignToPrimary = primaryToForeign.entrySet().stream()
-                    .collect(ImmutableMap.toImmutableMap(
-                            e -> slotToColumn.get(e.getValue()),
-                            e -> slotToColumn.get(e.getKey())));
-            // The primary key can only contain join that may be eliminated
-            if (!primaryToForeign.keySet().stream().allMatch(p ->
-                    slotWithJoin.get(p).size() == 1 && slotWithJoin.get(p).iterator().next() == join)) {
-                return false;
-            }
-            // The foreign key's filters must contain primary filters
-            if (!isPredicateCompatible(primaryToForeign)) {
-                return false;
+    /**
+     * Map every slot of {@code output} that {@code foreignPlan} does not produce to an equal slot (per
+     * {@code equalSet}) that it does produce; return null if any such slot cannot be mapped.
+     */
+    private @Nullable Map<Expression, Expression> tryMapOutputToForeignPlan(Plan foreignPlan,
+            Set<Slot> output, ImmutableEqualSet<Slot> equalSet) {
+        Set<Slot> residualPrimary = Sets.difference(output, foreignPlan.getOutputSet());
+        ImmutableMap.Builder<Expression, Expression> builder = new ImmutableMap.Builder<>();
+        for (Slot primarySlot : residualPrimary) {
+            Optional<Slot> replacedForeign = equalSet.calEqualSet(primarySlot).stream()
+                    .filter(foreignPlan.getOutputSet()::contains)
+                    .findFirst();
+            if (!replacedForeign.isPresent()) {
+                return null;
             }
-            return constraints.contains(foreignToPrimary);
+            builder.put(primarySlot, replacedForeign.get());
         }
+        return builder.build();
+    }
 
-        // When predicates of foreign keys is a subset of that of primary keys
-        private boolean isPredicateCompatible(Map<Slot, Slot> primaryToForeign) {
-            return primaryToForeign.entrySet().stream().allMatch(pf -> {
-                // There is no predicate in primary key
-                if (!slotWithPredicates.containsKey(pf.getKey()) || slotWithPredicates.get(pf.getKey()).isEmpty()) {
-                    return true;
-                }
-                // There are some predicates in primary key but there is no predicate in foreign key
-                if (slotWithPredicates.containsKey(pf.getValue()) && slotWithPredicates.get(pf.getValue()).isEmpty()) {
-                    return false;
-                }
-                Set<Expression> primaryPredicates = slotWithPredicates.get(pf.getKey()).stream()
-                        .map(e -> e.rewriteUp(
-                                s -> s instanceof Slot ? primaryToForeign.getOrDefault(s, (Slot) s) : s))
-                        .collect(Collectors.toSet());
-                return slotWithPredicates.get(pf.getValue()).containsAll(primaryPredicates);
-            });
+    /**
+     * Wrap {@code child} with IS NOT NULL predicates on its nullable {@code childSlots}, standing in for the
+     * null rejection of the eliminated equal join conditions; return {@code child} itself if none is nullable.
+     */
+    private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots) {
+        Set<Expression> predicates = childSlots.stream()
+                .filter(ExpressionTrait::nullable)
+                .map(s -> new Not(new IsNull(s)))
+                .collect(ImmutableSet.toImmutableSet());
+        if (predicates.isEmpty()) {
+            return child;
         }
+        return new LogicalFilter<>(predicates, child);
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUnique.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUnique.java
index 009b5efc552..25f157c7bdb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUnique.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByUnique.java
@@ -37,7 +37,6 @@ public class EliminateJoinByUnique extends OneRewriteRuleFactory {
                 return project;
             }
             if (!JoinUtils.canEliminateByLeft(join,
-                    join.left().getLogicalProperties().getFunctionalDependencies(),
                     join.right().getLogicalProperties().getFunctionalDependencies())) {
                 return project;
             }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ForeignKeyContext.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ForeignKeyContext.java
new file mode 100644
index 00000000000..9cbac5cf09c
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ForeignKeyContext.java
@@ -0,0 +1,231 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.TableIf;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
+
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Records the foreign-key constraints, valid primary-key columns and filter predicates found in plans,
+ * so that a join between a foreign-key side and a primary-key side can be checked for pk-fk elimination.
+ */
+public class ForeignKeyContext {
+    // Foreign-to-primary column mapping of every foreign-key constraint declared by a visited table.
+    Set<Map<Column, Column>> constraints = new HashSet<>();
+    // Foreign-key columns of those constraints; they are never expired.
+    Set<Column> foreignKeys = new HashSet<>();
+    // Referenced primary-key columns that have not been expired by an operator above them.
+    Set<Column> primaryKeys = new HashSet<>();
+    // Catalog column behind each visited slot, also for slots that are plain aliases of such a slot.
+    Map<Slot, Column> slotToColumn = new HashMap<>();
+    // Conjuncts of every filter whose output contains the slot.
+    Map<Slot, Set<Expression>> slotWithPredicates = new HashMap<>();
+
+    /**
+     * Collect foreign-key constraints, slot-to-column mappings and filter predicates from the given plan.
+     * Relations, projects and filters keep primary keys alive; other operators are expected to expire the
+     * primary keys they output. Foreign keys are never expired. Results accumulate across calls.
+     * Primary-key columns are registered only when the table declaring the foreign key is visited, so
+     * collect the foreign-key side before the primary-key side, otherwise its expirations have no effect.
+     *
+     * @param plan the plan to traverse
+     * @return this context
+     */
+    public ForeignKeyContext collectForeignKeyConstraint(Plan plan) {
+        plan.accept(new DefaultPlanVisitor<Void, ForeignKeyContext>() {
+            @Override
+            public Void visit(Plan plan, ForeignKeyContext context) {
+                super.visit(plan, context);
+                // Operators without a dedicated override below (anything but a relation, project or
+                // filter, e.g. a join) are assumed to reach here and expire their primary keys.
+                // Foreign keys always stay alive.
+                context.expirePrimaryKey(plan);
+                return null;
+            }
+
+            @Override
+            public Void visitLogicalRelation(LogicalRelation relation, ForeignKeyContext context) {
+                if (relation instanceof LogicalCatalogRelation) {
+                    context.putAllForeignKeys(((LogicalCatalogRelation) relation).getTable());
+                    relation.getOutput().stream()
+                            .filter(SlotReference.class::isInstance)
+                            .map(SlotReference.class::cast)
+                            .forEach(context::putSlot);
+                }
+                return null;
+            }
+
+            @Override
+            public Void visitLogicalProject(LogicalProject<?> project, ForeignKeyContext context) {
+                // only visit the children: a project keeps primary keys alive
+                super.visit(project, context);
+                for (NamedExpression expression : project.getProjects()) {
+                    // follow plain renames such as `a AS b` so the new slot maps to the same column
+                    if (expression instanceof Alias && expression.child(0) instanceof Slot) {
+                        context.putAlias(expression.toSlot(), (Slot) expression.child(0));
+                    }
+                }
+                return null;
+            }
+
+            @Override
+            public Void visitLogicalFilter(LogicalFilter<?> filter, ForeignKeyContext context) {
+                // only visit the children: a filter keeps primary keys alive, its predicates are
+                // checked by satisfyConstraint instead
+                super.visit(filter, context);
+                context.addFilter(filter);
+                return null;
+            }
+        }, this);
+        return this;
+    }
+
+    /**
+     * Register every foreign-key constraint declared by {@code table} together with its foreign-key and
+     * referenced primary-key columns.
+     */
+    void putAllForeignKeys(TableIf table) {
+        table.getForeignKeyConstraints().forEach(c -> {
+            Map<Column, Column> constraint = c.getForeignToPrimary(table);
+            constraints.add(constraint);
+            foreignKeys.addAll(constraint.keySet());
+            primaryKeys.addAll(constraint.values());
+        });
+    }
+
+    /**
+     * Check whether every slot of {@code key} maps to a foreign-key column (true for an empty key).
+     */
+    public boolean isForeignKey(Set<Slot> key) {
+        return foreignKeys.containsAll(columnsOf(key));
+    }
+
+    /**
+     * Check whether every slot of {@code key} maps to a primary-key column that is still valid
+     * (true for an empty key).
+     */
+    public boolean isPrimaryKey(Set<Slot> key) {
+        return primaryKeys.containsAll(columnsOf(key));
+    }
+
+    // Columns behind the given slots; a slot with no known column contributes null, which matches no key.
+    private Set<Column> columnsOf(Set<Slot> slots) {
+        return slots.stream().map(slotToColumn::get).collect(Collectors.toSet());
+    }
+
+    void putSlot(SlotReference slot) {
+        slot.getColumn().ifPresent(column -> slotToColumn.put(slot, column));
+    }
+
+    void putAlias(Slot newSlot, Slot originSlot) {
+        Column column = slotToColumn.get(originSlot);
+        if (column != null) {
+            slotToColumn.put(newSlot, column);
+        }
+    }
+
+    // Attach all conjuncts of the filter to every known slot it outputs, whether or not a conjunct
+    // references that slot; this is conservative for the containment check in isPredicateCompatible.
+    private void addFilter(LogicalFilter<?> filter) {
+        filter.getOutput().stream()
+                .filter(slotToColumn::containsKey)
+                .forEach(slot -> slotWithPredicates.computeIfAbsent(slot, k -> new HashSet<>())
+                        .addAll(filter.getConjuncts()));
+    }
+
+    // Remove the columns output by the plan from the valid primary keys.
+    private void expirePrimaryKey(Plan plan) {
+        plan.getOutput().stream()
+                .filter(slotToColumn::containsKey)
+                .map(slotToColumn::get)
+                .forEach(primaryKeys::remove);
+    }
+
+    /**
+     * Check whether the given join-key mapping matches a declared foreign-key constraint exactly and the
+     * foreign side is filtered at least as strictly as the primary side.
+     *
+     * @param primaryToForeign mapping from primary-side join slots to the foreign-side slots they equal
+     * @return false if the mapping is empty, contains a slot with no known column, maps one foreign
+     *         column to two different primary columns, or fails either check above
+     */
+    public boolean satisfyConstraint(Map<Slot, Slot> primaryToForeign) {
+        if (primaryToForeign.isEmpty()) {
+            return false;
+        }
+        Map<Column, Column> foreignToPrimary = new HashMap<>();
+        for (Map.Entry<Slot, Slot> entry : primaryToForeign.entrySet()) {
+            Column primaryColumn = slotToColumn.get(entry.getKey());
+            Column foreignColumn = slotToColumn.get(entry.getValue());
+            if (primaryColumn == null || foreignColumn == null) {
+                // not backed by a catalog column, so no declared constraint can cover it
+                return false;
+            }
+            Column previous = foreignToPrimary.putIfAbsent(foreignColumn, primaryColumn);
+            if (previous != null && !previous.equals(primaryColumn)) {
+                return false;
+            }
+        }
+        // The foreign key's filters must contain primary filters
+        if (!isPredicateCompatible(primaryToForeign)) {
+            return false;
+        }
+        return constraints.contains(foreignToPrimary);
+    }
+
+    // Every predicate on a primary-side slot, rewritten onto the foreign-side slots, must also be present
+    // on the matching foreign-side slot: the foreign side must be filtered at least as strictly.
+    private boolean isPredicateCompatible(Map<Slot, Slot> primaryToForeign) {
+        return primaryToForeign.entrySet().stream().allMatch(pf -> {
+            Set<Expression> primaryFilters = slotWithPredicates.getOrDefault(pf.getKey(), ImmutableSet.of());
+            // There is no predicate in primary key
+            if (primaryFilters.isEmpty()) {
+                return true;
+            }
+            // There are some predicates in primary key but none (no entry, or an empty one) in foreign key
+            Set<Expression> foreignFilters = slotWithPredicates.getOrDefault(pf.getValue(), ImmutableSet.of());
+            if (foreignFilters.isEmpty()) {
+                return false;
+            }
+            Set<Expression> primaryPredicates = primaryFilters.stream()
+                    .map(e -> e.rewriteUp(
+                            s -> s instanceof Slot ? primaryToForeign.getOrDefault(s, (Slot) s) : s))
+                    .collect(Collectors.toSet());
+            return foreignFilters.containsAll(primaryPredicates);
+        });
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
index 8ed077ba759..9beaa29c433 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java
@@ -25,6 +25,7 @@ import 
org.apache.doris.nereids.properties.DistributionSpecHash;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
 import org.apache.doris.nereids.properties.DistributionSpecReplicated;
 import org.apache.doris.nereids.properties.FunctionalDependencies;
+import org.apache.doris.nereids.rules.rewrite.ForeignKeyContext;
 import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -43,7 +44,9 @@ import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.SessionVariable;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 import java.util.HashSet;
 import java.util.List;
@@ -279,11 +282,46 @@ public class JoinUtils {
                 .collect(ImmutableList.toImmutableList());
     }
 
+    // Map each foreign slot to the only slot it is equal to; empty map if any has no or several equal slots.
+    private static Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet, Set<Slot> foreignKeys) {
+        ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>();
+        for (Slot foreignSlot : foreignKeys) {
+            Set<Slot> primarySlots = equivalenceSet.calEqualSet(foreignSlot);
+            if (primarySlots.size() != 1) {
+                return ImmutableMap.of();
+            }
+            builder.put(primarySlots.iterator().next(), foreignSlot);
+        }
+        return builder.build();
+    }
+
+    /**
+     * Check whether this inner equi-join can be eliminated by the pk-fk constraint from foreignPlan to primaryPlan.
+     */
+    public static boolean canEliminateByFk(LogicalJoin<?, ?> join, Plan primaryPlan, Plan foreignPlan) {
+        if (!join.getJoinType().isInnerJoin() || !join.getOtherJoinConjuncts().isEmpty() || join.isMarkJoin()) {
+            return false;
+        }
+        // Collect the foreign side first: it registers the referenced pk columns, so that operators on the
+        // primary side (anything but a relation, filter or project) can expire them afterwards.
+        ForeignKeyContext context = new ForeignKeyContext();
+        context.collectForeignKeyConstraint(foreignPlan);
+        context.collectForeignKeyConstraint(primaryPlan);
+        ImmutableEqualSet<Slot> equalSet = join.getEqualSlots();
+        Set<Slot> primaryKey = Sets.intersection(equalSet.getAllItemSet(), primaryPlan.getOutputSet());
+        Set<Slot> foreignKey = Sets.intersection(equalSet.getAllItemSet(), foreignPlan.getOutputSet());
+        if (!context.isForeignKey(foreignKey) || !context.isPrimaryKey(primaryKey)) {
+            return false;
+        }
+
+        Map<Slot, Slot> primaryToForeignKey = mapPrimaryToForeign(equalSet, foreignKey);
+        return context.satisfyConstraint(primaryToForeignKey);
+    }
+
     /**
      * can this join be eliminated by its left child
      */
-    public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, 
FunctionalDependencies leftFuncDeps,
-            FunctionalDependencies rightFuncDeps) {
+    public static boolean canEliminateByLeft(LogicalJoin<?, ?> join, 
FunctionalDependencies rightFuncDeps) {
         if (join.getJoinType().isLeftOuterJoin()) {
             Pair<Set<Slot>, Set<Slot>> njHashKeys = 
join.extractNullRejectHashKeys();
             if (!join.getOtherJoinConjuncts().isEmpty() || njHashKeys == null) 
{
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
index f76d6334b28..3e245741178 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/EliminateJoinTest.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.util.PlanChecker;
 
 import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
 class EliminateJoinTest extends SqlTestBase {
@@ -55,6 +56,7 @@ class EliminateJoinTest extends SqlTestBase {
         HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
         ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, 
constructContext(p1, p2));
         Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertTrue(res.getViewExpressions().isEmpty());
     }
 
     @Test
@@ -83,6 +85,117 @@ class EliminateJoinTest extends SqlTestBase {
         HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
         ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
         Assertions.assertTrue(!res.isInvalid());
+        Assertions.assertTrue(res.getViewExpressions().isEmpty());
+        dropConstraint("alter table T2 drop constraint uk");
+    }
+
+    @Test
+    void testInnerJoinWithPKFK() throws Exception {
+        connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        addConstraint("alter table T2 add constraint pk primary key (id)");
+        addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
+        // drop the constraints even if an assertion fails, so they are not left on the shared tables
+        try {
+            CascadesContext c2 = createCascadesContext(
+                    "select * from T1 inner join T2 "
+                            + "on T1.id = T2.id ",
+                    connectContext
+            );
+            Plan p2 = PlanChecker.from(c2)
+                    .analyze()
+                    .rewrite()
+                    .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                    .getAllPlan().get(0).child(0);
+            HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
+            HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
+            ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
+            Assertions.assertTrue(!res.isInvalid());
+            Assertions.assertTrue(res.getViewExpressions().isEmpty());
+        } finally {
+            dropConstraint("alter table T2 drop constraint pk");
+        }
+    }
+
+    @Disabled
+    @Test
+    void testLOJWithPKFKAndUK1() throws Exception {
+        connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        addConstraint("alter table T2 add constraint pk primary key (id)");
+        addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
+        addConstraint("alter table T3 add constraint uk unique (id)");
+        try {
+            CascadesContext c2 = createCascadesContext(
+                    "select * from (select T1.*, T3.id as id3 from T1 left outer join T3 on T1.id = T3.id) T1 "
+                            + "inner join T2 on T1.id = T2.id ",
+                    connectContext
+            );
+            Plan p2 = PlanChecker.from(c2)
+                    .analyze()
+                    .rewrite()
+                    .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                    .getAllPlan().get(0).child(0);
+            HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
+            HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
+            ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
+            Assertions.assertTrue(!res.isInvalid());
+            Assertions.assertTrue(res.getViewExpressions().isEmpty());
+        } finally {
+            dropConstraint("alter table T2 drop constraint pk");
+            dropConstraint("alter table T3 drop constraint uk");
+        }
+    }
+
+    @Disabled
+    @Test
+    void testLOJWithPKFKAndUK2() throws Exception {
+        connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES");
+        CascadesContext c1 = createCascadesContext(
+                "select * from T1",
+                connectContext
+        );
+        Plan p1 = PlanChecker.from(c1)
+                .analyze()
+                .rewrite()
+                .getPlan().child(0);
+        addConstraint("alter table T2 add constraint pk primary key (id)");
+        addConstraint("alter table T1 add constraint fk foreign key (id) references T2(id)");
+        addConstraint("alter table T3 add constraint uk unique (id)");
+        try {
+            CascadesContext c2 = createCascadesContext(
+                    "select * from (select T1.*, T2.id as id2 from T1 inner join T2 on T1.id = T2.id) T1 "
+                            + "left outer join T3 on T1.id = T3.id ",
+                    connectContext
+            );
+            Plan p2 = PlanChecker.from(c2)
+                    .analyze()
+                    .rewrite()
+                    .applyExploration(RuleSet.BUSHY_TREE_JOIN_REORDER)
+                    .getAllPlan().get(0).child(0);
+            HyperGraph h1 = HyperGraph.builderForMv(p1).buildAll().get(0);
+            HyperGraph h2 = HyperGraph.builderForMv(p2).buildAll().get(0);
+            ComparisonResult res = HyperGraphComparator.isLogicCompatible(h1, h2, constructContext(p1, p2));
+            Assertions.assertTrue(!res.isInvalid());
+            Assertions.assertTrue(res.getViewExpressions().isEmpty());
+        } finally {
+            dropConstraint("alter table T2 drop constraint pk");
+            dropConstraint("alter table T3 drop constraint uk");
+        }
     }
 
     LogicalCompatibilityContext constructContext(Plan p1, Plan p2) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to