This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/GroupByLevelDebug
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit d39cb928a16de68b4e0c34903801e45500cc29f5
Author: liuminghui233 <[email protected]>
AuthorDate: Wed Jun 1 21:01:37 2022 +0800

    refactor LocalExecutionPlanner
---
 .../apache/iotdb/db/mpp/plan/analyze/Analyzer.java |  3 +-
 .../mpp/plan/analyze/GroupByLevelController.java   |  6 +++-
 .../db/mpp/plan/planner/LocalExecutionPlanner.java | 29 +++++----------
 .../plan/parameter/AggregationDescriptor.java      | 13 +++----
 .../plan/analyze/AggregationDescriptorTest.java    | 41 ++++++++++++----------
 5 files changed, 43 insertions(+), 49 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analyzer.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analyzer.java
index 8418241072..2e8ac3a942 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analyzer.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/Analyzer.java
@@ -655,7 +655,8 @@ public class Analyzer {
         Set<Expression> transformExpressions,
         Map<Expression, Expression> rawPathToGroupedPathMap) {
       GroupByLevelController groupByLevelController =
-          new 
GroupByLevelController(queryStatement.getGroupByLevelComponent().getLevels());
+          new GroupByLevelController(
+              queryStatement.getGroupByLevelComponent().getLevels(), 
typeProvider);
       for (Pair<Expression, String> measurementWithAlias : outputExpressions) {
         groupByLevelController.control(measurementWithAlias.left, 
measurementWithAlias.right);
       }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/GroupByLevelController.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/GroupByLevelController.java
index 6055bc930c..a13353518e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/GroupByLevelController.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/analyze/GroupByLevelController.java
@@ -61,12 +61,15 @@ public class GroupByLevelController {
    */
   private final Map<String, String> aliasToColumnMap;
 
-  public GroupByLevelController(int[] levels) {
+  private final TypeProvider typeProvider;
+
+  public GroupByLevelController(int[] levels, TypeProvider typeProvider) {
     this.levels = levels;
     this.groupedPathMap = new LinkedHashMap<>();
     this.rawPathToGroupedPathMap = new HashMap<>();
     this.columnToAliasMap = new HashMap<>();
     this.aliasToColumnMap = new HashMap<>();
+    this.typeProvider = typeProvider;
   }
 
   public void control(Expression expression, String alias) {
@@ -77,6 +80,7 @@ public class GroupByLevelController {
 
     PartialPath rawPath = ((TimeSeriesOperand) 
expression.getExpressions().get(0)).getPath();
     PartialPath groupedPath = generatePartialPathByLevel(rawPath.getNodes(), 
levels);
+    typeProvider.setType(groupedPath.getFullPath(), rawPath.getSeriesType());
 
     Expression rawPathExpression = new TimeSeriesOperand(rawPath);
     Expression groupedPathExpression = new TimeSeriesOperand(groupedPath);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
index 96e9dfdc1c..fc1f737abd 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
@@ -825,20 +825,14 @@ public class LocalExecutionPlanner {
       List<Aggregator> aggregators = new ArrayList<>();
       Map<String, List<InputLocation>> layout = makeLayout(node);
       for (GroupByLevelDescriptor descriptor : 
node.getGroupByLevelDescriptors()) {
-        List<String> inputColumnNames = descriptor.getInputColumnNames();
-        List<InputLocation[]> inputLocationList = new 
ArrayList<>(inputColumnNames.size());
-        inputColumnNames.forEach(
-            inputColumnName ->
-                inputLocationList.add(layout.get(inputColumnName).toArray(new 
InputLocation[0])));
-
+        List<InputLocation[]> inputLocationList = 
calcInputLocationList(descriptor, layout);
         aggregators.add(
             new Aggregator(
                 AccumulatorFactory.createAccumulator(
                     descriptor.getAggregationType(),
                     context
                         .getTypeProvider()
-                        // get the type of first inputExpression
-                        
.getType(descriptor.getInputExpressions().get(0).toString()),
+                        
.getType(descriptor.getInputExpressions().get(0).getExpressionString()),
                     ascending),
                 descriptor.getStep(),
                 inputLocationList));
@@ -962,20 +956,13 @@ public class LocalExecutionPlanner {
 
     private List<InputLocation[]> calcInputLocationList(
         AggregationDescriptor descriptor, Map<String, List<InputLocation>> 
layout) {
-      List<String> inputColumnNames = descriptor.getInputColumnNames();
-      // it may include double parts
-      List<List<InputLocation>> inputLocationParts = new 
ArrayList<>(inputColumnNames.size());
-      inputColumnNames.forEach(o -> inputLocationParts.add(layout.get(o)));
-
+      List<List<String>> inputColumnNames = 
descriptor.getInputColumnNamesList();
       List<InputLocation[]> inputLocationList = new ArrayList<>();
-      for (int i = 0; i < inputLocationParts.get(0).size(); i++) {
-        if (inputColumnNames.size() == 1) {
-          inputLocationList.add(new InputLocation[] 
{inputLocationParts.get(0).get(i)});
-        } else {
-          inputLocationList.add(
-              new InputLocation[] {
-                inputLocationParts.get(0).get(i), 
inputLocationParts.get(1).get(i)
-              });
+      for (List<String> inputColumnNamesOfOneInput : inputColumnNames) {
+        checkArgument(
+            inputColumnNamesOfOneInput.size() == 1 || 
inputColumnNamesOfOneInput.size() == 2);
+        for (String inputColumnName : inputColumnNamesOfOneInput) {
+          inputLocationList.add(layout.get(inputColumnName).toArray(new 
InputLocation[0]));
         }
       }
       return inputLocationList;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
index 942562bee7..f97add0369 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
@@ -25,6 +25,7 @@ import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -72,20 +73,16 @@ public class AggregationDescriptor {
     return outputColumnNames;
   }
 
-  public List<String> getInputColumnNames() {
+  public List<List<String>> getInputColumnNamesList() {
     if (step.isInputRaw()) {
       return inputExpressions.stream()
-          .map(Expression::getExpressionString)
+          .map(expression -> 
Collections.singletonList(expression.getExpressionString()))
           .collect(Collectors.toList());
     }
 
-    List<AggregationType> inputAggregationTypes = 
getActualAggregationTypes(step.isInputPartial());
-    List<String> inputColumnNames = new ArrayList<>();
+    List<List<String>> inputColumnNames = new ArrayList<>();
     for (Expression expression : inputExpressions) {
-      for (AggregationType funcName : inputAggregationTypes) {
-        inputColumnNames.add(
-            funcName.toString().toLowerCase() + "(" + 
expression.getExpressionString() + ")");
-      }
+      inputColumnNames.add(getInputColumnNames(expression));
     }
     return inputColumnNames;
   }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
index fb7b72515b..4b8bbf6b00 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
@@ -145,20 +145,19 @@ public class AggregationDescriptorTest {
 
   @Test
   public void testInputColumnNames() {
-    List<String> expectedInputColumnNames =
+    List<List<List<String>>> expectedInputColumnNames =
         Arrays.asList(
-            "root.sg.d1.s1",
-            "count(root.sg.d1.s1)",
-            "sum(root.sg.d1.s1)",
-            "last_value(root.sg.d1.s1)",
-            "max_time(root.sg.d1.s1)",
-            "max_value(root.sg.d1.s1)");
+            
Collections.singletonList(Collections.singletonList("root.sg.d1.s1")),
+            
Collections.singletonList(Collections.singletonList("root.sg.d1.s1")),
+            Collections.singletonList(Arrays.asList("count(root.sg.d1.s1)", 
"sum(root.sg.d1.s1)")),
+            Collections.singletonList(
+                Arrays.asList("last_value(root.sg.d1.s1)", 
"max_time(root.sg.d1.s1)")),
+            
Collections.singletonList(Collections.singletonList("max_value(root.sg.d1.s1)")),
+            
Collections.singletonList(Collections.singletonList("count(root.sg.d1.s1)")));
     Assert.assertEquals(
         expectedInputColumnNames,
         aggregationDescriptorList.stream()
-            .map(AggregationDescriptor::getInputColumnNames)
-            .flatMap(List::stream)
-            .distinct()
+            .map(AggregationDescriptor::getInputColumnNamesList)
             .collect(Collectors.toList()));
   }
 
@@ -177,18 +176,24 @@ public class AggregationDescriptorTest {
 
   @Test
   public void testInputColumnNamesInGroupByLevel() {
-    List<String> expectedInputColumnNames =
+    List<List<List<String>>> expectedInputColumnNames =
         Arrays.asList(
-            "count(root.sg.d2.s1)",
-            "count(root.sg.d1.s1)",
-            "sum(root.sg.d1.s1)",
-            "sum(root.sg.d2.s1)");
+            Arrays.asList(
+                Collections.singletonList("count(root.sg.d2.s1)"),
+                Collections.singletonList("count(root.sg.d1.s1)")),
+            Arrays.asList(
+                Arrays.asList("count(root.sg.d1.s1)", "sum(root.sg.d1.s1)"),
+                Arrays.asList("count(root.sg.d2.s1)", "sum(root.sg.d2.s1)")),
+            Arrays.asList(
+                Collections.singletonList("count(root.sg.d2.s1)"),
+                Collections.singletonList("count(root.sg.d1.s1)")),
+            Arrays.asList(
+                Arrays.asList("count(root.sg.d1.s1)", "sum(root.sg.d1.s1)"),
+                Arrays.asList("count(root.sg.d2.s1)", "sum(root.sg.d2.s1)")));
     Assert.assertEquals(
         expectedInputColumnNames,
         groupByLevelDescriptorList.stream()
-            .map(GroupByLevelDescriptor::getInputColumnNames)
-            .flatMap(List::stream)
-            .distinct()
+            .map(GroupByLevelDescriptor::getInputColumnNamesList)
             .collect(Collectors.toList()));
   }
 

Reply via email to