This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 223454d1db7 [feature](Nereids) elimite inner join by foreign key (#28486)
223454d1db7 is described below

commit 223454d1db71866ae52e925a000c4b20ac745c6c
Author: 谢健 <[email protected]>
AuthorDate: Thu Dec 21 12:58:55 2023 +0800

    [feature](Nereids) elimite inner join by foreign key (#28486)
---
 .../java/org/apache/doris/catalog/TableIf.java     |  36 +++
 .../catalog/constraint/ForeignKeyConstraint.java   |   9 +
 .../catalog/constraint/PrimaryKeyConstraint.java   |   5 +
 .../doris/catalog/constraint/UniqueConstraint.java |   7 +
 .../doris/nereids/jobs/executor/Rewriter.java      |   3 +
 .../org/apache/doris/nereids/rules/RuleType.java   |   1 +
 .../nereids/rules/rewrite/EliminateJoinByFK.java   | 333 +++++++++++++++++++++
 .../plans/logical/LogicalCatalogRelation.java      |  25 +-
 .../nereids/trees/plans/logical/LogicalJoin.java   |  20 ++
 .../nereids/util/ImmutableEquivalenceSet.java      | 101 +++++++
 .../rules/rewrite/EliminateJoinByFkTest.java       | 121 ++++++++
 .../apache/doris/utframe/TestWithFeService.java    |  18 ++
 12 files changed, 678 insertions(+), 1 deletion(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java
index eb472d8884f..a188d8f7ae2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java
@@ -165,6 +165,42 @@ public interface TableIf {
         throw new RuntimeException(String.format("Not implemented constraint for table %s", this));
     }
 
+    default Set<ForeignKeyConstraint> getForeignKeyConstraints() {
+        readLock();
+        try {
+            return getConstraintsMap().values().stream()
+                    .filter(ForeignKeyConstraint.class::isInstance)
+                    .map(ForeignKeyConstraint.class::cast)
+                    .collect(ImmutableSet.toImmutableSet());
+        } finally {
+            readUnlock();
+        }
+    }
+
+    default Set<PrimaryKeyConstraint> getPrimaryKeyConstraints() {
+        readLock();
+        try {
+            return getConstraintsMap().values().stream()
+                    .filter(PrimaryKeyConstraint.class::isInstance)
+                    .map(PrimaryKeyConstraint.class::cast)
+                    .collect(ImmutableSet.toImmutableSet());
+        } finally {
+            readUnlock();
+        }
+    }
+
+    default Set<UniqueConstraint> getUniqueConstraints() {
+        readLock();
+        try {
+            return getConstraintsMap().values().stream()
+                    .filter(UniqueConstraint.class::isInstance)
+                    .map(UniqueConstraint.class::cast)
+                    .collect(ImmutableSet.toImmutableSet());
+        } finally {
+            readUnlock();
+        }
+    }
+
     // Note this function is not thread safe
     default void checkConstraintNotExistence(String name, Constraint primaryKeyConstraint,
             Map<String, Constraint> constraintMap) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/ForeignKeyConstraint.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/ForeignKeyConstraint.java
index cae63abe13b..b8097e4665c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/ForeignKeyConstraint.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/ForeignKeyConstraint.java
@@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableMap.Builder;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 
 public class ForeignKeyConstraint extends Constraint {
@@ -61,6 +62,14 @@ public class ForeignKeyConstraint extends Constraint {
         return foreignToReference.get(column);
     }
 
+    public Map<Column, Column> getForeignToPrimary(TableIf curTable) {
+        ImmutableMap.Builder<Column, Column> columnBuilder = new ImmutableMap.Builder<>();
+        TableIf refTable = referencedTable.toTableIf();
+        foreignToReference.forEach((k, v) ->
+                columnBuilder.put(curTable.getColumn(k), refTable.getColumn(v)));
+        return columnBuilder.build();
+    }
+
     public Column getReferencedColumn(String column) {
         return getReferencedTable().getColumn(getReferencedColumnName(column));
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/PrimaryKeyConstraint.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/PrimaryKeyConstraint.java
index 02d59788ad5..fd894c498cb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/PrimaryKeyConstraint.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/PrimaryKeyConstraint.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.catalog.constraint;
 
+import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.TableIf;
 
 import com.google.common.base.Objects;
@@ -42,6 +43,10 @@ public class PrimaryKeyConstraint extends Constraint {
         return columns;
     }
 
+    public ImmutableSet<Column> getPrimaryKeys(TableIf table) {
+        return columns.stream().map(table::getColumn).collect(ImmutableSet.toImmutableSet());
+    }
+
     public void addForeignTable(TableIf table) {
         foreignTables.add(new TableIdentifier(table));
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/UniqueConstraint.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/UniqueConstraint.java
index 975ff0937b3..2fc7fbb2612 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/UniqueConstraint.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/constraint/UniqueConstraint.java
@@ -17,6 +17,9 @@
 
 package org.apache.doris.catalog.constraint;
 
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.TableIf;
+
 import com.google.common.base.Objects;
 import com.google.common.collect.ImmutableSet;
 
@@ -34,6 +37,10 @@ public class UniqueConstraint extends Constraint {
         return columns;
     }
 
+    public ImmutableSet<Column> getUniqueKeys(TableIf table) {
+        return columns.stream().map(table::getColumn).collect(ImmutableSet.toImmutableSet());
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 16ad9ffd82d..58aca3bef07 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -55,6 +55,7 @@ import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows;
 import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
 import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation;
 import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
+import org.apache.doris.nereids.rules.rewrite.EliminateJoinByFK;
 import org.apache.doris.nereids.rules.rewrite.EliminateJoinCondition;
 import org.apache.doris.nereids.rules.rewrite.EliminateLimit;
 import org.apache.doris.nereids.rules.rewrite.EliminateNotNull;
@@ -285,6 +286,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
                     custom(RuleType.PUSH_DOWN_DISTINCT_THROUGH_JOIN, PushDownDistinctThroughJoin::new)
             ),
 
+            // this rule should be invoked after infer predicate and push down distinct, and before push down limit
+            custom(RuleType.ELIMINATE_JOIN_BY_FOREIGN_KEY, EliminateJoinByFK::new),
             topic("Limit optimization",
                     // TODO: the logical plan should not contains any phase information,
                     //       we should refactor like AggregateStrategies, e.g. LimitStrategies,
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 3a9b7cdce5c..b21f5f02dae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -200,6 +200,7 @@ public enum RuleType {
     ELIMINATE_LIMIT_ON_EMPTY_RELATION(RuleTypeClass.REWRITE),
     ELIMINATE_FILTER(RuleTypeClass.REWRITE),
     ELIMINATE_JOIN(RuleTypeClass.REWRITE),
+    ELIMINATE_JOIN_BY_FOREIGN_KEY(RuleTypeClass.REWRITE),
     ELIMINATE_JOIN_CONDITION(RuleTypeClass.REWRITE),
     ELIMINATE_FILTER_ON_ONE_RELATION(RuleTypeClass.REWRITE),
     ELIMINATE_SEMI_JOIN(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
new file mode 100644
index 00000000000..078657827fa
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java
@@ -0,0 +1,333 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.Table;
+import org.apache.doris.catalog.TableIf;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
+import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
+import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
+
+import com.google.common.collect.BiMap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableMap.Builder;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/**
+ * Eliminate joins whose primary-key side is made redundant by a foreign key constraint.
+ */
+public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        EliminateJoinByFKHelper helper = new EliminateJoinByFKHelper();
+        return helper.rewriteRoot(plan, jobContext);
+    }
+
+    private static class EliminateJoinByFKHelper
+            extends DefaultPlanRewriter<ForeignKeyContext> implements CustomRewriter {
+
+        @Override
+        public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+            return plan.accept(this, new ForeignKeyContext());
+        }
+
+        @Override
+        public Plan visit(Plan plan, ForeignKeyContext context) {
+            Plan newPlan = visitChildren(this, plan, context);
+            // record the rewritten node: satisfyConstraint() compares joins by identity with project.child()
+            context.passThroughPlan(newPlan);
+            return newPlan;
+        }
+
+        @Override
+        public Plan visitLogicalRelation(LogicalRelation relation, ForeignKeyContext context) {
+            if (!(relation instanceof LogicalCatalogRelation)
+                    || !(((LogicalCatalogRelation) relation).getTable() instanceof Table)) {
+                return relation;
+            }
+            context.putAllForeignKeys(((LogicalCatalogRelation) relation).getTable());
+            relation.getOutput().stream()
+                    .filter(SlotReference.class::isInstance)
+                    .map(SlotReference.class::cast)
+                    .forEach(context::putSlot);
+            return relation;
+        }
+
+        private boolean canEliminate(LogicalJoin<?, ?> join, BiMap<Slot, Slot> equalSlots, ForeignKeyContext context) {
+            if (!join.getOtherJoinConjuncts().isEmpty()) {
+                return false;
+            }
+            // getEqualSlots() ignores hash conjuncts that are not slot-to-slot and accepts any EqualPredicate,
+            // so only allow plain `slot = slot` (null-rejecting) conjuncts, or elimination could lose a condition
+            if (!join.getHashJoinConjuncts().stream().allMatch(e ->
+                    e instanceof EqualTo && e.child(0) instanceof Slot && e.child(1) instanceof Slot)) {
+                return false;
+            }
+            if (!join.getJoinType().isInnerJoin() && !join.getJoinType().isSemiJoin()) {
+                return false;
+            }
+            return context.satisfyConstraint(equalSlots, join);
+        }
+
+        private boolean isForeignKeyAndUnique(Plan plan,
+                Set<Slot> keySet, ForeignKeyContext context) {
+            boolean unique = keySet.stream()
+                    .allMatch(s -> plan.getLogicalProperties().getFunctionalDependencies().isUnique(s));
+            return unique && context.isForeignKey(keySet);
+        }
+
+        private @Nullable Map<Expression, Expression> tryGetOutputToChildMap(Plan child,
+                Set<Slot> output, BiMap<Slot, Slot> equalSlots) {
+            Set<Slot> residual = Sets.difference(output, child.getOutputSet());
+            if (equalSlots.keySet().containsAll(residual)) {
+                return residual.stream()
+                        .collect(ImmutableMap.toImmutableMap(e -> e, equalSlots::get));
+            }
+            if (equalSlots.values().containsAll(residual)) {
+                return residual.stream()
+                        .collect(ImmutableMap.toImmutableMap(e -> e, e -> equalSlots.inverse().get(e)));
+            }
+            return null;
+        }
+
+        private Plan applyNullCompensationFilter(Plan child, Set<Slot> childSlots) {
+            Set<Expression> predicates = childSlots.stream()
+                    .filter(ExpressionTrait::nullable)
+                    .map(s -> new Not(new IsNull(s)))
+                    .collect(ImmutableSet.toImmutableSet());
+            if (predicates.isEmpty()) {
+                return child;
+            }
+            return new LogicalFilter<>(predicates, child);
+        }
+
+        private @Nullable Plan tryConstructPlanWithJoinChild(LogicalProject<LogicalJoin<?, ?>> project, Plan child,
+                BiMap<Slot, Slot> equalSlots, ForeignKeyContext context) {
+            Set<Slot> output = project.getInputSlots();
+            Set<Slot> keySet = child.getOutputSet().containsAll(equalSlots.keySet())
+                    ? equalSlots.keySet()
+                    : equalSlots.values();
+            Map<Expression, Expression> outputToChild = tryGetOutputToChildMap(child, output, equalSlots);
+            if (outputToChild != null && isForeignKeyAndUnique(child, keySet, context)) {
+                List<NamedExpression> newProjects = project.getProjects().stream()
+                        .map(e -> outputToChild.containsKey(e)
+                                ? new Alias(e.getExprId(), outputToChild.get(e), e.toSql())
+                                : (NamedExpression) e.rewriteUp(s -> outputToChild.getOrDefault(s, s)))
+                        .collect(ImmutableList.toImmutableList());
+                return project.withProjects(newProjects)
+                        .withChildren(applyNullCompensationFilter(child, keySet));
+            }
+            return null;
+        }
+
+        // Right now we only eliminate inner and semi joins, which must meet the following conditions:
+        // 1. only null-rejecting `slot = slot` conditions, which together match a declared fk-pk constraint
+        // 2. the project only outputs slots of the kept child, or slots convertible to them via the conditions
+        // 3. the kept child's join keys are foreign keys and unique
+        // 4. nullable join keys of the kept child get an IS NOT NULL filter, emulating the join's null rejection
+        private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) {
+            LogicalJoin<?, ?> join = project.child();
+            ImmutableEquivalenceSet<Slot> equalSet = join.getEqualSlots();
+            BiMap<Slot, Slot> equalSlots = equalSet.tryToMap();
+            if (equalSlots == null) {
+                return project;
+            }
+            if (!canEliminate(join, equalSlots, context)) {
+                return project;
+            }
+            Plan keepLeft = tryConstructPlanWithJoinChild(project, join.left(), equalSlots, context);
+            if (keepLeft != null) {
+                return keepLeft;
+            }
+            Plan keepRight = tryConstructPlanWithJoinChild(project, join.right(), equalSlots, context);
+            if (keepRight != null) {
+                return keepRight;
+            }
+            return project;
+        }
+
+        @Override
+        public Plan visitLogicalProject(LogicalProject<?> project, ForeignKeyContext context) {
+            project = visitChildren(this, project, context);
+            for (NamedExpression expression : project.getProjects()) {
+                if (expression instanceof Alias && expression.child(0) instanceof Slot) {
+                    context.putAlias(expression.toSlot(), (Slot) expression.child(0));
+                }
+            }
+            if (project.child() instanceof LogicalJoin<?, ?>) {
+                return eliminateJoin((LogicalProject<LogicalJoin<?, ?>>) project, context);
+            }
+            return project;
+        }
+    }
+
+    private static class ForeignKeyContext {
+        static Set<Class<?>> propagatePrimaryKeyOperator = ImmutableSet.of(
+                LogicalProject.class, LogicalSort.class, LogicalJoin.class);
+        static Set<Class<?>> propagateForeignKeyOperator = ImmutableSet.of(
+                LogicalProject.class, LogicalSort.class, LogicalJoin.class, LogicalFilter.class,
+                LogicalTopN.class, LogicalLimit.class, LogicalAggregate.class, LogicalWindow.class);
+        // fk -> pk column mapping of every foreign key constraint declared on a scanned table
+        Set<Map<Column, Column>> constraints = new HashSet<>();
+        // candidate key columns; passThroughPlan() drops them after a non-propagating operator
+        Set<Column> foreignKeys = new HashSet<>();
+        Set<Column> primaryKeys = new HashSet<>();
+        // catalog column behind each scan output slot, including slot aliases
+        Map<Slot, Column> slotToColumn = new HashMap<>();
+        // joins / filter conjuncts recorded on each column by passThroughPlan()
+        Map<Column, Set<LogicalJoin<?, ?>>> columnWithJoin = new HashMap<>();
+        Map<Column, Set<Expression>> columnWithPredicates = new HashMap<>();
+
+        public void putAllForeignKeys(TableIf table) {
+            table.getForeignKeyConstraints().forEach(c -> {
+                Map<Column, Column> constraint = c.getForeignToPrimary(table);
+                constraints.add(constraint);
+                foreignKeys.addAll(constraint.keySet());
+                primaryKeys.addAll(constraint.values());
+            });
+        }
+
+        public boolean isForeignKey(Set<Slot> key) {
+            return foreignKeys.containsAll(
+                    key.stream().map(s -> slotToColumn.get(s)).collect(Collectors.toSet()));
+        }
+
+        public void putSlot(SlotReference slot) {
+            slot.getColumn().ifPresent(c -> slotToColumn.put(slot, c));
+        }
+
+        public void putAlias(Slot newSlot, Slot originSlot) {
+            if (slotToColumn.containsKey(originSlot)) {
+                slotToColumn.put(newSlot, slotToColumn.get(originSlot));
+            }
+        }
+
+        // Records the joins and filters each known column flows through; any other operator outside the
+        // propagate sets stops its output columns from being treated as primary / foreign keys.
+        public void passThroughPlan(Plan plan) {
+            Set<Column> output = plan.getOutput().stream()
+                    .filter(slotToColumn::containsKey)
+                    .map(s -> slotToColumn.get(s))
+                    .collect(ImmutableSet.toImmutableSet());
+            if (plan instanceof LogicalJoin) {
+                // accumulate every join the column flows through; satisfyConstraint() requires exactly one
+                output.forEach(c -> columnWithJoin.computeIfAbsent(c, v -> new HashSet<>())
+                        .add((LogicalJoin<?, ?>) plan));
+                return;
+            }
+            if (plan instanceof LogicalFilter) {
+                output.forEach(c -> columnWithPredicates.computeIfAbsent(c, v -> new HashSet<>())
+                        .addAll(((LogicalFilter<?>) plan).getConjuncts()));
+                return;
+            }
+            if (!propagatePrimaryKeyOperator.contains(plan.getClass())) {
+                output.forEach(primaryKeys::remove);
+            }
+            if (!propagateForeignKeyOperator.contains(plan.getClass())) {
+                output.forEach(foreignKeys::remove);
+            }
+        }
+
+        // Maps the equal slots to an fk -> pk column mapping and accepts it only if it is a declared constraint,
+        // its primary columns passed through no other join, and the foreign side is filtered at least as much.
+        public boolean satisfyConstraint(BiMap<Slot, Slot> equalSlots, LogicalJoin<?, ?> join) {
+            ImmutableMap.Builder<Column, Column> foreignToPrimaryBuilder = new Builder<>();
+            for (Entry<Slot, Slot> entry : equalSlots.entrySet()) {
+                Slot left = entry.getKey();
+                Slot right = entry.getValue();
+                if (!slotToColumn.containsKey(left) || !slotToColumn.containsKey(right)) {
+                    return false;
+                }
+                Column leftColumn = slotToColumn.get(left);
+                Column rightColumn = slotToColumn.get(right);
+                if (foreignKeys.contains(leftColumn)) {
+                    foreignToPrimaryBuilder.put(leftColumn, rightColumn);
+                } else if (foreignKeys.contains(rightColumn)) {
+                    foreignToPrimaryBuilder.put(rightColumn, leftColumn);
+                } else {
+                    return false;
+                }
+            }
+            Map<Column, Column> foreignToPrimary = foreignToPrimaryBuilder.build();
+            // Each primary key column must have passed through exactly one join: the one being eliminated.
+            // Columns with no recorded join (absent from every join's output) are rejected rather than throwing.
+            if (!foreignToPrimary.values().stream().map(p -> columnWithJoin.getOrDefault(p, ImmutableSet.of()))
+                    .allMatch(joins -> joins.size() == 1 && joins.iterator().next() == join)) {
+                return false;
+            }
+            // Every filter applied to the primary side must also have been applied to the foreign side
+            if (!isPredicateCompatible(equalSlots, foreignToPrimary)) {
+                return false;
+            }
+            return constraints.contains(foreignToPrimary);
+        }
+
+        // For each fk -> pk pair, every filter recorded on the primary column (after rewriting primary slots to
+        // their foreign counterparts via the equal slots) must also be recorded on the foreign column.
+        private boolean isPredicateCompatible(BiMap<Slot, Slot> equalSlots, Map<Column, Column> foreignToPrimary) {
+            BiMap<Slot, Slot> primarySlotToForeign = equalSlots.keySet().stream()
+                    .map(slotToColumn::get)
+                    .anyMatch(primaryKeys::contains) ? equalSlots : equalSlots.inverse();
+            return foreignToPrimary.entrySet().stream().allMatch(fp -> {
+                if (!columnWithPredicates.containsKey(fp.getValue())) {
+                    return true;
+                }
+                Set<Expression> primaryPredicates = columnWithPredicates.get(fp.getValue()).stream()
+                        .map(e -> e.rewriteUp(
+                                s -> s instanceof Slot ? primarySlotToForeign.getOrDefault(s, (Slot) s) : s))
+                        .collect(Collectors.toSet());
+                // a foreign column with no recorded filters has an empty filter set
+                Set<Expression> foreignPredicates = columnWithPredicates.getOrDefault(fp.getKey(), ImmutableSet.of());
+                return foreignPredicates.containsAll(primaryPredicates);
+            });
+        }
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java
index ed1e6a588be..e50a049e0f4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.trees.plans.logical;
 
+import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.DatabaseIf;
 import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.OlapTable;
@@ -42,6 +43,7 @@ import org.apache.commons.lang3.StringUtils;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import java.util.function.Supplier;
 
 /**
@@ -122,8 +124,9 @@ public abstract class LogicalCatalogRelation extends LogicalRelation implements
     @Override
     public FunctionalDependencies computeFuncDeps(Supplier<List<Slot>> outputSupplier) {
         Builder fdBuilder = new Builder();
+        Set<Slot> output = ImmutableSet.copyOf(outputSupplier.get());
         if (table instanceof OlapTable && ((OlapTable) table).getKeysType().isAggregationFamily()) {
-            ImmutableSet<Slot> slotSet = computeOutput().stream()
+            ImmutableSet<Slot> slotSet = output.stream()
                     .filter(SlotReference.class::isInstance)
                     .map(SlotReference.class::cast)
                     .filter(s -> s.getColumn().isPresent()
@@ -131,6 +134,26 @@ public abstract class LogicalCatalogRelation extends LogicalRelation implements
                     .collect(ImmutableSet.toImmutableSet());
             fdBuilder.addUniqueSlot(slotSet);
         }
+        table.getPrimaryKeyConstraints().forEach(c -> {
+            Set<Column> columns = c.getPrimaryKeys(this.getTable());
+            ImmutableSet<Slot> slotSet = output.stream()
+                    .filter(SlotReference.class::isInstance).map(SlotReference.class::cast)
+                    .filter(s -> s.getColumn().isPresent() && columns.contains(s.getColumn().get()))
+                    .collect(ImmutableSet.toImmutableSet());
+            if (!slotSet.isEmpty() && slotSet.size() == columns.size()) { // a partial key is not unique
+                fdBuilder.addUniqueSlot(slotSet);
+            }
+        });
+        table.getUniqueConstraints().forEach(c -> {
+            Set<Column> columns = c.getUniqueKeys(this.getTable());
+            ImmutableSet<Slot> slotSet = output.stream()
+                    .filter(SlotReference.class::isInstance).map(SlotReference.class::cast)
+                    .filter(s -> s.getColumn().isPresent() && columns.contains(s.getColumn().get()))
+                    .collect(ImmutableSet.toImmutableSet());
+            if (!slotSet.isEmpty() && slotSet.size() == columns.size()) { // a partial key is not unique
+                fdBuilder.addUniqueSlot(slotSet);
+            }
+        });
         return fdBuilder.build();
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index 70c28b41e8a..c771f0372e1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -24,6 +24,7 @@ import org.apache.doris.nereids.properties.FunctionalDependencies;
 import org.apache.doris.nereids.properties.FunctionalDependencies.Builder;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -36,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.PlanType;
 import org.apache.doris.nereids.trees.plans.algebra.Join;
 import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.ImmutableEquivalenceSet;
 import org.apache.doris.nereids.util.JoinUtils;
 import org.apache.doris.nereids.util.Utils;
 
@@ -454,6 +456,24 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
         return fdBuilder.build();
     }
 
+    /**
+     * Returns the slots equated by slot-to-slot EqualPredicate hash conjuncts; empty unless inner or semi join.
+     */
+    public ImmutableEquivalenceSet<Slot> getEqualSlots() {
+        // TODO: Use fd in the future
+        if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) {
+            return ImmutableEquivalenceSet.of();
+        }
+        ImmutableEquivalenceSet.Builder<Slot> builder = new ImmutableEquivalenceSet.Builder<>();
+        hashJoinConjuncts.stream()
+                .filter(e -> e instanceof EqualPredicate
+                        && e.child(0) instanceof Slot
+                        && e.child(1) instanceof Slot)
+                .forEach(e ->
+                        builder.addEqualPair((Slot) e.child(0), (Slot) e.child(1)));
+        return builder.build();
+    }
+
     @Override
     public JSONObject toJson() {
         JSONObject logicalJoin = super.toJson();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java
new file mode 100644
index 00000000000..e54ecb7c9e7
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java
@@ -0,0 +1,101 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import com.google.common.collect.BiMap;
+import com.google.common.collect.ImmutableBiMap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
+
+/**
+ * An immutable partition of elements into equivalence classes.
+ *
+ * <p>Instances are produced by {@link Builder}, which merges classes with a union-find forest
+ * as equal pairs are added. In the built set every recorded element maps directly to the
+ * representative (root) of its class.
+ */
+public class ImmutableEquivalenceSet<T> {
+    // element -> representative of its equivalence class (a representative maps to itself)
+    final Map<T, T> root;
+
+    ImmutableEquivalenceSet(Map<T, T> root) {
+        this.root = ImmutableMap.copyOf(root);
+    }
+
+    /**
+     * Returns an empty equivalence set.
+     */
+    public static <T> ImmutableEquivalenceSet<T> of() {
+        return new ImmutableEquivalenceSet<>(ImmutableMap.of());
+    }
+
+    /**
+     * Builder of ImmutableEquivalenceSet, backed by a union-find forest with path compression.
+     */
+    public static class Builder<T> {
+        // element -> parent in the union-find forest; a root is its own parent
+        final Map<T, T> parent = new HashMap<>();
+
+        /**
+         * Records that {@code a} and {@code b} are equal, merging their equivalence classes.
+         */
+        public void addEqualPair(T a, T b) {
+            parent.putIfAbsent(b, b);
+            parent.putIfAbsent(a, a);
+            union(a, b);
+        }
+
+        private void union(T a, T b) {
+            T rootA = findRoot(a);
+            T rootB = findRoot(b);
+            if (!rootA.equals(rootB)) {
+                // Re-parent b's root, not b itself: moving only b would detach every other
+                // member of b's class (e.g. (p,q) then (r,q) would lose p = q).
+                parent.put(rootB, rootA);
+            }
+        }
+
+        /**
+         * Returns the root of {@code a}'s class and points every node on the path directly at it.
+         */
+        private T findRoot(T a) {
+            T rootOfA = a;
+            while (!parent.get(rootOfA).equals(rootOfA)) {
+                rootOfA = parent.get(rootOfA);
+            }
+            T current = a;
+            while (!current.equals(rootOfA)) {
+                T next = parent.get(current);
+                parent.put(current, rootOfA);
+                current = next;
+            }
+            return rootOfA;
+        }
+
+        /**
+         * Builds an immutable snapshot in which every element maps directly to its root.
+         */
+        public ImmutableEquivalenceSet<T> build() {
+            // only overwrites values of existing keys, so iterating keySet stays safe
+            parent.keySet().forEach(this::findRoot);
+            return new ImmutableEquivalenceSet<>(parent);
+        }
+    }
+
+    /**
+     * Returns every element in the same equivalence class as {@code a}, including {@code a}.
+     * An element that was never added yields an empty set.
+     */
+    public Set<T> calEqualSet(T a) {
+        T ra = root.get(a);
+        return root.entrySet().stream()
+                .filter(e -> e.getValue().equals(ra))
+                .map(Map.Entry::getKey)
+                .collect(ImmutableSet.toImmutableSet());
+    }
+
+    /**
+     * Try to convert the set to a one-to-one map, such as a = b, c = d (yielding b -> a, d -> c).
+     * Returns null unless every equivalence class has exactly two members, e.g. for a = b, a = c.
+     */
+    public @Nullable BiMap<T, T> tryToMap() {
+        Map<T, Integer> classSizes = new HashMap<>();
+        for (T representative : root.values()) {
+            classSizes.merge(representative, 1, Integer::sum);
+        }
+        // Checking each class (not just the total) avoids duplicate BiMap values for sizes like 1 + 3.
+        if (classSizes.values().stream().anyMatch(size -> size != 2)) {
+            return null;
+        }
+        return root.entrySet().stream()
+                .filter(e -> !e.getValue().equals(e.getKey()))
+                .collect(ImmutableBiMap.toImmutableBiMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
new file mode 100644
index 00000000000..1faaca4f3f5
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFkTest.java
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Plan tests for {@link EliminateJoinByFK}: inner joins between {@code pri} and tables whose
+ * column references {@code pri(id1)} through a foreign key should disappear after rewrite.
+ * Each test installs a temporary unique constraint and always drops it, even when the plan
+ * check fails, so later tests start from the same schema state.
+ */
+class EliminateJoinByFkTest extends TestWithFeService implements MemoPatternMatchSupported {
+    @Override
+    protected void runBeforeAll() throws Exception {
+        createDatabase("test");
+        connectContext.setDatabase("default_cluster:test");
+        createTables(
+                "CREATE TABLE IF NOT EXISTS pri (\n"
+                        + "    id1 int not null\n"
+                        + ")\n"
+                        + "DUPLICATE KEY(id1)\n"
+                        + "DISTRIBUTED BY HASH(id1) BUCKETS 10\n"
+                        + "PROPERTIES (\"replication_num\" = \"1\")\n",
+                "CREATE TABLE IF NOT EXISTS foreign_not_null (\n"
+                        + "    id2 int not null\n"
+                        + ")\n"
+                        + "DUPLICATE KEY(id2)\n"
+                        + "DISTRIBUTED BY HASH(id2) BUCKETS 10\n"
+                        + "PROPERTIES (\"replication_num\" = \"1\")\n",
+                "CREATE TABLE IF NOT EXISTS foreign_null (\n"
+                        + "    id3 int\n"
+                        + ")\n"
+                        + "DUPLICATE KEY(id3)\n"
+                        + "DISTRIBUTED BY HASH(id3) BUCKETS 10\n"
+                        + "PROPERTIES (\"replication_num\" = \"1\")\n"
+        );
+        addConstraint("Alter table pri add constraint pk primary key (id1)");
+        addConstraint("Alter table foreign_not_null add constraint f_not_null foreign key (id2)\n"
+                + "references pri(id1)");
+        // foreign_null.id3 is nullable, so its foreign key is named accordingly
+        addConstraint("Alter table foreign_null add constraint f_null foreign key (id3)\n"
+                + "references pri(id1)");
+    }
+
+    @Test
+    void testNotNull() throws Exception {
+        addConstraint("Alter table foreign_not_null add constraint uk1 unique (id2)\n");
+        try {
+            String sql = "select pri.id1 from pri inner join foreign_not_null"
+                    + " on pri.id1 = foreign_not_null.id2";
+            PlanChecker.from(connectContext)
+                    .analyze(sql)
+                    .rewrite()
+                    .nonMatch(logicalJoin())
+                    .printlnTree();
+        } finally {
+            dropConstraint("Alter table foreign_not_null drop constraint uk1\n");
+        }
+    }
+
+    @Test
+    void testNotNullWithPredicate() throws Exception {
+        addConstraint("Alter table foreign_not_null add constraint uk2 unique (id2)\n");
+        try {
+            String sql = "select pri.id1 from pri inner join foreign_not_null"
+                    + " on pri.id1 = foreign_not_null.id2\n"
+                    + "where pri.id1 = 1";
+            PlanChecker.from(connectContext)
+                    .analyze(sql)
+                    .rewrite()
+                    .nonMatch(logicalJoin())
+                    .printlnTree();
+        } finally {
+            dropConstraint("Alter table foreign_not_null drop constraint uk2\n");
+        }
+    }
+
+    @Test
+    void testNull() throws Exception {
+        addConstraint("Alter table foreign_null add constraint uk unique (id3)\n");
+        try {
+            String sql = "select pri.id1 from pri inner join foreign_null on pri.id1 = foreign_null.id3";
+            PlanChecker.from(connectContext)
+                    .analyze(sql)
+                    .rewrite()
+                    .nonMatch(logicalJoin())
+                    .printlnTree();
+        } finally {
+            dropConstraint("Alter table foreign_null drop constraint uk\n");
+        }
+    }
+
+    @Test
+    void testNullWithPredicate() throws Exception {
+        addConstraint("Alter table foreign_null add constraint uk unique (id3)\n");
+        try {
+            String sql = "select pri.id1 from pri inner join foreign_null on pri.id1 = foreign_null.id3\n"
+                    + "where pri.id1 = 1";
+            PlanChecker.from(connectContext)
+                    .analyze(sql)
+                    .rewrite()
+                    .nonMatch(logicalJoin())
+                    .printlnTree();
+        } finally {
+            dropConstraint("Alter table foreign_null drop constraint uk\n");
+        }
+    }
+
+    @Test
+    void testMultiJoin() throws Exception {
+        addConstraint("Alter table foreign_null add constraint uk unique (id3)\n");
+        try {
+            String sql = "select id1 from "
+                    + "foreign_null inner join foreign_not_null on id2 = id3\n"
+                    + "inner join pri on id1 = id3";
+            PlanChecker.from(connectContext)
+                    .analyze(sql)
+                    .rewrite()
+                    .nonMatch(logicalOlapScan().when(scan -> scan.getTable().getName().equals("pri")))
+                    .printlnTree();
+        } finally {
+            dropConstraint("Alter table foreign_null drop constraint uk\n");
+        }
+    }
+}
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java 
b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
index e8491091797..85f6e22e6f4 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/utframe/TestWithFeService.java
@@ -61,7 +61,9 @@ import org.apache.doris.nereids.glue.LogicalPlanAdapter;
 import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.plans.commands.AddConstraintCommand;
 import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand;
+import org.apache.doris.nereids.trees.plans.commands.DropConstraintCommand;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.planner.Planner;
@@ -686,6 +688,22 @@ public abstract class TestWithFeService {
         Env.getCurrentEnv().createFunction(createFunctionStmt);
     }
 
+    /**
+     * Parse {@code sql} with the Nereids parser and run it as an ADD CONSTRAINT command on the
+     * test connection.
+     *
+     * @param sql an {@code ALTER TABLE ... ADD CONSTRAINT ...} statement
+     * @throws IllegalArgumentException if {@code sql} does not parse to an AddConstraintCommand;
+     *         failing loudly keeps a malformed statement from turning a test into a silent no-op
+     */
+    public void addConstraint(String sql) throws Exception {
+        LogicalPlan parsed = new NereidsParser().parseSingle(sql);
+        if (!(parsed instanceof AddConstraintCommand)) {
+            throw new IllegalArgumentException("not an add constraint statement: " + sql);
+        }
+        StmtExecutor stmtExecutor = new StmtExecutor(connectContext, sql);
+        ((AddConstraintCommand) parsed).run(connectContext, stmtExecutor);
+    }
+
+    /**
+     * Parse {@code sql} with the Nereids parser and run it as a DROP CONSTRAINT command on the
+     * test connection.
+     *
+     * @param sql an {@code ALTER TABLE ... DROP CONSTRAINT ...} statement
+     * @throws IllegalArgumentException if {@code sql} does not parse to a DropConstraintCommand;
+     *         a silently skipped drop would leak the constraint into later tests
+     */
+    public void dropConstraint(String sql) throws Exception {
+        LogicalPlan parsed = new NereidsParser().parseSingle(sql);
+        if (!(parsed instanceof DropConstraintCommand)) {
+            throw new IllegalArgumentException("not a drop constraint statement: " + sql);
+        }
+        StmtExecutor stmtExecutor = new StmtExecutor(connectContext, sql);
+        ((DropConstraintCommand) parsed).run(connectContext, stmtExecutor);
+    }
+
     protected void dropPolicy(String sql) throws Exception {
         DropPolicyStmt stmt = (DropPolicyStmt) parseAndAnalyzeStmt(sql);
         Env.getCurrentEnv().getPolicyMgr().dropPolicy(stmt);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to