This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 49b81b54542288c9e0db57d4267fe7efc9c9f85a
Author: 谢健 <[email protected]>
AuthorDate: Tue May 28 12:17:49 2024 +0800

    [fix](Nereids) Optimize BFS Memory Usage to Mitigate Exponential Data 
Growth (#35440)
    
    The original PR is #34948, and the temporary solution is #35408.
    
    In our effort to streamline and optimize dependency handling, we
    implement the following steps:
    
    - Detect Circular Dependencies: Identify any circular references within
      functional dependencies. If any are found, we remove the specific
      dependencies responsible for creating these cycles.
    - Clean Up Group By Dependencies: Remove from the 'group by' clause every
      key that is functionally determined by another 'group by' key, which
      simplifies the aggregation and improves query performance.
---
 .../apache/doris/nereids/properties/FuncDeps.java  | 66 ++++++++++++++--------
 .../doris/nereids/properties/FuncDepsTest.java     |  4 +-
 .../rules/rewrite/EliminateGroupByKeyTest.java     | 54 +++++++++++++++++-
 3 files changed, 96 insertions(+), 28 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
index c77b5ed03b6..6c1b302d7dc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java
@@ -22,9 +22,9 @@ import org.apache.doris.nereids.trees.expressions.Slot;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
-import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
@@ -61,13 +61,17 @@ public class FuncDeps {
     }
 
     private final Set<FuncDepsItem> items;
+    private final Map<Set<Slot>, Set<Set<Slot>>> edges;
 
-    FuncDeps() {
+    public FuncDeps() {
         items = new HashSet<>();
+        edges = new HashMap<>();
     }
 
     public void addFuncItems(Set<Slot> determinants, Set<Slot> dependencies) {
         items.add(new FuncDepsItem(determinants, dependencies));
+        edges.computeIfAbsent(determinants, k -> new HashSet<>());
+        edges.get(determinants).add(dependencies);
     }
 
     public int size() {
@@ -78,6 +82,32 @@ public class FuncDeps {
         return items.isEmpty();
     }
 
+    private void dfs(Set<Slot> parent, Set<Set<Slot>> visited, 
Set<FuncDepsItem> circleItem) {
+        visited.add(parent);
+        if (!edges.containsKey(parent)) {
+            return;
+        }
+        for (Set<Slot> child : edges.get(parent)) {
+            if (visited.contains(child)) {
+                circleItem.add(new FuncDepsItem(parent, child));
+                continue;
+            }
+            dfs(child, visited, circleItem);
+        }
+    }
+
+    // find the items that are not part of a cycle
+    private Set<FuncDepsItem> findValidItems() {
+        Set<FuncDepsItem> circleItem = new HashSet<>();
+        Set<Set<Slot>> visited = new HashSet<>();
+        for (Set<Slot> parent : edges.keySet()) {
+            if (!visited.contains(parent)) {
+                dfs(parent, visited, circleItem);
+            }
+        }
+        return Sets.difference(items, circleItem);
+    }
+
     /**
      * Reduces a given set of slot sets by eliminating dependencies using a 
breadth-first search (BFS) approach.
      * <p>
@@ -97,30 +127,16 @@ public class FuncDeps {
      * @return the minimal set of slot sets after applying all possible 
reductions
      */
     public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots) {
-        Set<Set<Slot>> minSlotSet = slots;
-        List<Set<Set<Slot>>> reduceSlotSets = new ArrayList<>();
-        reduceSlotSets.add(slots);
-        // To avoid memory usage due to multiple iterations,
-        // we set a maximum number of loop iterations.
-        int count = 0;
-        while (!reduceSlotSets.isEmpty() && count < 100) {
-            count += 1;
-            List<Set<Set<Slot>>> newReduceSlotSets = new ArrayList<>();
-            for (Set<Set<Slot>> slotSet : reduceSlotSets) {
-                for (FuncDepsItem funcDepsItem : items) {
-                    if (slotSet.contains(funcDepsItem.dependencies)
-                            && slotSet.contains(funcDepsItem.determinants)) {
-                        Set<Set<Slot>> newSet = Sets.newHashSet(slotSet);
-                        newSet.remove(funcDepsItem.dependencies);
-                        if (minSlotSet.size() > newSet.size()) {
-                            minSlotSet = newSet;
-                        }
-                        newReduceSlotSets.add(newSet);
-                    }
-                }
+        Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots);
+        Set<Set<Slot>> eliminatedSlots = new HashSet<>();
+        Set<FuncDepsItem> validItems = findValidItems();
+        for (FuncDepsItem funcDepsItem : validItems) {
+            if (minSlotSet.contains(funcDepsItem.dependencies)
+                    && minSlotSet.contains(funcDepsItem.determinants)) {
+                eliminatedSlots.add(funcDepsItem.dependencies);
             }
-            reduceSlotSets = newReduceSlotSets;
         }
+        minSlotSet.removeAll(eliminatedSlots);
         return minSlotSet;
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
index a9496392fc6..64df33acd60 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/FuncDepsTest.java
@@ -85,7 +85,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s2), Sets.newHashSet(s1));
         Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
         Set<Set<Slot>> expected = new HashSet<>();
-        expected.add(set2);
+        expected.add(set1);
         expected.add(set3);
         expected.add(set4);
         Assertions.assertEquals(expected, slots);
@@ -101,7 +101,7 @@ class FuncDepsTest {
         funcDeps.addFuncItems(Sets.newHashSet(s4), Sets.newHashSet(s1));
         Set<Set<Slot>> slots = funcDeps.eliminateDeps(slotSet);
         Set<Set<Slot>> expected = new HashSet<>();
-        expected.add(set3);
+        expected.add(set1);
         Assertions.assertEquals(expected, slots);
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
index 907999a1f2b..203e902b3eb 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java
@@ -17,13 +17,31 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.properties.FuncDeps;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.utframe.TestWithFeService;
 
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.Set;
+
 class EliminateGroupByKeyTest extends TestWithFeService implements 
MemoPatternMatchSupported {
+    Slot s1 = new SlotReference("1", IntegerType.INSTANCE, false);
+    Slot s2 = new SlotReference("2", IntegerType.INSTANCE, false);
+    Slot s3 = new SlotReference("3", IntegerType.INSTANCE, false);
+    Slot s4 = new SlotReference("4", IntegerType.INSTANCE, false);
+    Set<Slot> set1 = Sets.newHashSet(s1);
+    Set<Slot> set2 = Sets.newHashSet(s2);
+    Set<Slot> set3 = Sets.newHashSet(s3);
+    Set<Slot> set4 = Sets.newHashSet(s4);
+
     @Override
     protected void runBeforeAll() throws Exception {
         createDatabase("test");
@@ -42,6 +60,40 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
         
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
     }
 
+    @Test
+    void testEliminateChain() {
+        FuncDeps funcDeps = new FuncDeps();
+        funcDeps.addFuncItems(set1, set2);
+        funcDeps.addFuncItems(set2, set3);
+        funcDeps.addFuncItems(set3, set4);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Assertions.assertEquals(1, slots.size());
+        Assertions.assertEquals(set1, slots.iterator().next());
+    }
+
+    @Test
+    void testEliminateCircle() {
+        FuncDeps funcDeps = new FuncDeps();
+        funcDeps.addFuncItems(set1, set2);
+        funcDeps.addFuncItems(set2, set3);
+        funcDeps.addFuncItems(set3, set4);
+        funcDeps.addFuncItems(set4, set1);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Assertions.assertEquals(1, slots.size());
+        Assertions.assertEquals(set1, slots.iterator().next());
+    }
+
+    @Test
+    void testEliminateTree() {
+        FuncDeps funcDeps = new FuncDeps();
+        funcDeps.addFuncItems(set1, set2);
+        funcDeps.addFuncItems(set1, set3);
+        funcDeps.addFuncItems(set1, set4);
+        Set<Set<Slot>> slots = funcDeps.eliminateDeps(ImmutableSet.of(set1, 
set2, set3, set4));
+        Assertions.assertEquals(1, slots.size());
+        Assertions.assertEquals(set1, slots.iterator().next());
+    }
+
     @Test
     void testEliminateByUniform() {
         PlanChecker.from(connectContext)
@@ -67,7 +119,7 @@ class EliminateGroupByKeyTest extends TestWithFeService 
implements MemoPatternMa
                 .matches(logicalAggregate().when(agg ->
                         agg.getGroupByExpressions().size() == 2));
         PlanChecker.from(connectContext)
-                .analyze("select id from t1 where id = 1 and name = \"\" group 
by name, id")
+                .analyze("select name as n, count(id) as c from t1 where name 
= \"\" group by name, id having c = 2")
                 .rewrite()
                 .printlnTree()
                 .matches(logicalAggregate().when(agg ->


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to