This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch beyyes/topk_streamsort
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit e3a1e7a2b7a61e32d18059e5a2f56b60e94bfa11
Author: Beyyes <[email protected]>
AuthorDate: Mon Jul 29 11:50:47 2024 +0800

    perfect streamsort and topk
---
 .../it/query/old/orderBy/IoTDBOrderByTableIT.java  |  4 +-
 .../rule/MergeLimitOverProjectWithMergeSort.java   | 38 +---------
 .../iterative/rule/MergeLimitWithMergeSort.java    | 83 +++++++++++++++++++---
 .../plan/relational/analyzer/SortTest.java         | 19 +++++
 4 files changed, 97 insertions(+), 47 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/old/orderBy/IoTDBOrderByTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/old/orderBy/IoTDBOrderByTableIT.java
index c266895b6b3..85e5c4e39f7 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/old/orderBy/IoTDBOrderByTableIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/old/orderBy/IoTDBOrderByTableIT.java
@@ -66,7 +66,7 @@ public class IoTDBOrderByTableIT {
         "USE db",
         "CREATE TABLE table0 (device string id, attr1 string attribute, num 
int32 measurement, bigNum int64 measurement, "
             + "floatNum double measurement, str TEXT measurement, bool BOOLEAN 
measurement)",
-        "insert into table0(device, time,num,bigNum,floatNum,str,bool) 
values('d1', 0,3,2947483648,231.2121,'coconut',FALSE)",
+        "insert into table0(device, attr1, time,num,bigNum,floatNum,str,bool) 
values('d1', 'high', 0,3,2947483648,231.2121,'coconut',FALSE)",
         "insert into table0(device, time,num,bigNum,floatNum,str,bool) 
values('d1', 20,2,2147483648,434.12,'pineapple',TRUE)",
         "insert into table0(device, time,num,bigNum,floatNum,str,bool) 
values('d1', 40,1,2247483648,12.123,'apricot',TRUE)",
         "insert into table0(device, time,num,bigNum,floatNum,str,bool) 
values('d1', 80,9,2147483646,43.12,'apple',FALSE)",
@@ -85,7 +85,7 @@ public class IoTDBOrderByTableIT {
 
   private static final String[] sql2 =
       new String[] {
-        "insert into table0(device,time,num,bigNum,floatNum,str,bool) 
values('d2',0,3,2947483648,231.2121,'coconut',FALSE)",
+        "insert into table0(device,attr1,time,num,bigNum,floatNum,str,bool) 
values('d2','high',0,3,2947483648,231.2121,'coconut',FALSE)",
         "insert into table0(device,time,num,bigNum,floatNum,str,bool) 
values('d2',20,2,2147483648,434.12,'pineapple',TRUE)",
         "insert into table0(device,time,num,bigNum,floatNum,str,bool) 
values('d2',40,1,2247483648,12.123,'apricot',TRUE)",
         "insert into table0(device,time,num,bigNum,floatNum,str,bool) 
values('d2',80,9,2147483646,43.12,'apple',FALSE)",
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitOverProjectWithMergeSort.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitOverProjectWithMergeSort.java
index 5b456e8227c..6e720e2def4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitOverProjectWithMergeSort.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitOverProjectWithMergeSort.java
@@ -24,7 +24,6 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.MergeSortNode;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
-import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.StreamSortNode;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
 import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
 import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
@@ -32,8 +31,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
 
 import com.google.common.collect.ImmutableList;
 
-import java.util.Optional;
-
+import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.MergeLimitWithMergeSort.transformByMergeSortNode;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.limit;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.mergeSort;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
@@ -53,7 +51,7 @@ import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Cap
  *
  * <pre>
  * - Project (identity, narrowing)
- *    - TopN (limit = x, order by a, b)
+ *    - TopK (limit = x, order by a, b)
  * </pre>
  *
  * Applies to LimitNode without ties only.
@@ -82,38 +80,8 @@ public class MergeLimitOverProjectWithMergeSort implements Rule<LimitNode> {
   public Result apply(LimitNode parent, Captures captures, Context context) {
     ProjectNode project = captures.get(PROJECT);
     MergeSortNode mergeSortNode = captures.get(MERGE_SORT);
-
-    TopKNode topKNode;
     PlanNode childOfMergeSort = 
context.getLookup().resolve(mergeSortNode.getChildren().get(0));
-    if (childOfMergeSort instanceof StreamSortNode) {
-      topKNode =
-          new TopKNode(
-              parent.getPlanNodeId(),
-              mergeSortNode.getOrderingScheme(),
-              parent.getCount(),
-              parent.getOutputSymbols(),
-              true);
-      for (PlanNode child : mergeSortNode.getChildren()) {
-        LimitNode limitNode =
-            new LimitNode(
-                context.getIdAllocator().genPlanNodeId(),
-                child,
-                parent.getCount(),
-                Optional.empty());
-        topKNode.addChild(limitNode);
-      }
-
-    } else {
-      topKNode =
-          new TopKNode(
-              parent.getPlanNodeId(),
-              mergeSortNode.getChildren(),
-              mergeSortNode.getOrderingScheme(),
-              parent.getCount(),
-              parent.getOutputSymbols(),
-              false);
-    }
-
+    TopKNode topKNode = transformByMergeSortNode(parent, mergeSortNode, 
childOfMergeSort, context);
     return 
Result.ofPlanNode(project.replaceChildren(ImmutableList.of(topKNode)));
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitWithMergeSort.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitWithMergeSort.java
index cf8006258a2..71bf6bca346 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitWithMergeSort.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/MergeLimitWithMergeSort.java
@@ -13,19 +13,53 @@
  */
 package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
 
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.node.LimitNode;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.MergeSortNode;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.StreamSortNode;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
 import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
 import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
 import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
 
+import java.util.Optional;
+
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.limit;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.mergeSort;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
 
+/**
+ * Transforms:
+ *
+ * <pre>
+ * - Limit (limit = x)
+ *       - MergeSort (order by a, b)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - TopK (limit = x, order by a, b)
+ * </pre>
+ *
+ * <pre>
+ * - Limit (limit = x)
+ *       - MergeSort (order by a, b)
+ *          - StreamSort (order by a, b)
+ * </pre>
+ *
+ * Into:
+ *
+ * <pre>
+ * - TopK (limit = x, order by a, b)
+ *    - Limit (limit = x)
+ *      - StreamSort (order by a, b)
+ * </pre>
+ *
+ * Applies to LimitNode without ties only.
+ */
 public class MergeLimitWithMergeSort implements Rule<LimitNode> {
   private static final Capture<MergeSortNode> CHILD = newCapture();
 
@@ -41,15 +75,44 @@ public class MergeLimitWithMergeSort implements Rule<LimitNode> {
 
   @Override
   public Result apply(LimitNode parent, Captures captures, Context context) {
-    MergeSortNode mergeSort = captures.get(CHILD);
-
-    return Result.ofPlanNode(
-        new TopKNode(
-            parent.getPlanNodeId(),
-            mergeSort.getChildren(),
-            mergeSort.getOrderingScheme(),
-            parent.getCount(),
-            parent.getOutputSymbols(),
-            false));
+    MergeSortNode mergeSortNode = captures.get(CHILD);
+    PlanNode childOfMergeSort = 
context.getLookup().resolve(mergeSortNode.getChildren().get(0));
+    TopKNode topKNode = transformByMergeSortNode(parent, mergeSortNode, 
childOfMergeSort, context);
+    return Result.ofPlanNode(topKNode);
+  }
+
+  static TopKNode transformByMergeSortNode(
+      LimitNode parent, MergeSortNode mergeSortNode, PlanNode 
childOfMergeSort, Context context) {
+    TopKNode topKNode;
+    if (childOfMergeSort instanceof StreamSortNode) {
+      topKNode =
+          new TopKNode(
+              parent.getPlanNodeId(),
+              mergeSortNode.getOrderingScheme(),
+              parent.getCount(),
+              parent.getOutputSymbols(),
+              true);
+      for (PlanNode child : mergeSortNode.getChildren()) {
+        LimitNode limitNode =
+            new LimitNode(
+                context.getIdAllocator().genPlanNodeId(),
+                child,
+                parent.getCount(),
+                Optional.empty());
+        topKNode.addChild(limitNode);
+      }
+
+    } else {
+      topKNode =
+          new TopKNode(
+              parent.getPlanNodeId(),
+              mergeSortNode.getChildren(),
+              mergeSortNode.getOrderingScheme(),
+              parent.getCount(),
+              parent.getOutputSymbols(),
+              true);
+    }
+
+    return topKNode;
   }
 }
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/SortTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/SortTest.java
index b34b46c1b36..6f44ddcc89f 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/SortTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/SortTest.java
@@ -173,6 +173,25 @@ public class SortTest {
         0,
         0,
         false);
+
+    sql = "SELECT * FROM table1 order by tag2 desc, tag3 asc offset 5 limit 
10";
+    context = new MPPQueryContext(sql, queryId, sessionInfo, null, null);
+    actualAnalysis = analyzeSQL(sql, metadata);
+    logicalQueryPlan =
+        new LogicalPlanner(context, metadata, sessionInfo, 
warningCollector).plan(actualAnalysis);
+    rootNode = logicalQueryPlan.getRootNode();
+    // LogicalPlan: `Output-Offset-Limit-StreamSort-TableScan`
+    assertTrue(getChildrenNode(rootNode, 3) instanceof StreamSortNode);
+    distributionPlanner = new TableDistributedPlanner(actualAnalysis, 
logicalQueryPlan, context);
+    distributedQueryPlan = distributionPlanner.plan();
+    assertEquals(3, distributedQueryPlan.getFragments().size());
+    // DistributedPlan: `Output-Offset-TopK-Limit-StreamSort-TableScan`
+    identitySinkNode =
+        (IdentitySinkNode) 
distributedQueryPlan.getFragments().get(0).getPlanNodeTree();
+    assertTrue(getChildrenNode(identitySinkNode, 3) instanceof TopKNode);
+    topKNode = (TopKNode) getChildrenNode(identitySinkNode, 3);
+    assertTrue(topKNode.getChildren().get(1) instanceof LimitNode);
+    assertTrue(getChildrenNode(topKNode.getChildren().get(1), 1) instanceof 
StreamSortNode);
   }
 
   // order by all_ids, time, others

Reply via email to