This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/calcInputLocationListDebug
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit c9942424a391a72c9facc302b2a1321d2f8f57e8
Author: Minghui Liu <[email protected]>
AuthorDate: Thu Jun 2 15:40:23 2022 +0800

    refactor calcInputLocationList to build input locations per input expression
---
 .../db/mpp/plan/planner/LocalExecutionPlanner.java | 35 +++++++++---------
 .../plan/parameter/AggregationDescriptor.java      | 13 +++----
 .../plan/analyze/AggregationDescriptorTest.java    | 41 ++++++++++++----------
 3 files changed, 44 insertions(+), 45 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
index 2a248a7f1f..dc0cb9ea17 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
@@ -827,12 +827,7 @@ public class LocalExecutionPlanner {
       List<Aggregator> aggregators = new ArrayList<>();
       Map<String, List<InputLocation>> layout = makeLayout(node);
       for (GroupByLevelDescriptor descriptor : 
node.getGroupByLevelDescriptors()) {
-        List<String> inputColumnNames = descriptor.getInputColumnNames();
-        List<InputLocation[]> inputLocationList = new 
ArrayList<>(inputColumnNames.size());
-        inputColumnNames.forEach(
-            inputColumnName ->
-                inputLocationList.add(layout.get(inputColumnName).toArray(new 
InputLocation[0])));
-
+        List<InputLocation[]> inputLocationList = 
calcInputLocationList(descriptor, layout);
         aggregators.add(
             new Aggregator(
                 AccumulatorFactory.createAccumulator(
@@ -964,20 +959,22 @@ public class LocalExecutionPlanner {
 
     private List<InputLocation[]> calcInputLocationList(
         AggregationDescriptor descriptor, Map<String, List<InputLocation>> 
layout) {
-      List<String> inputColumnNames = descriptor.getInputColumnNames();
-      // it may include double parts
-      List<List<InputLocation>> inputLocationParts = new 
ArrayList<>(inputColumnNames.size());
-      inputColumnNames.forEach(o -> inputLocationParts.add(layout.get(o)));
-
+      List<List<String>> inputColumnNames = 
descriptor.getInputColumnNamesList();
       List<InputLocation[]> inputLocationList = new ArrayList<>();
-      for (int i = 0; i < inputLocationParts.get(0).size(); i++) {
-        if (inputColumnNames.size() == 1) {
-          inputLocationList.add(new InputLocation[] 
{inputLocationParts.get(0).get(i)});
-        } else {
-          inputLocationList.add(
-              new InputLocation[] {
-                inputLocationParts.get(0).get(i), 
inputLocationParts.get(1).get(i)
-              });
+
+      for (List<String> inputColumnNamesOfOneInput : inputColumnNames) {
+        // it may include double parts
+        List<List<InputLocation>> inputLocationParts = new ArrayList<>();
+        inputColumnNamesOfOneInput.forEach(o -> 
inputLocationParts.add(layout.get(o)));
+        for (int i = 0; i < inputLocationParts.get(0).size(); i++) {
+          if (inputColumnNamesOfOneInput.size() == 1) {
+            inputLocationList.add(new InputLocation[] 
{inputLocationParts.get(0).get(i)});
+          } else {
+            inputLocationList.add(
+                new InputLocation[] {
+                  inputLocationParts.get(0).get(i), 
inputLocationParts.get(1).get(i)
+                });
+          }
         }
       }
       return inputLocationList;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
index 942562bee7..f97add0369 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/parameter/AggregationDescriptor.java
@@ -25,6 +25,7 @@ import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -72,20 +73,16 @@ public class AggregationDescriptor {
     return outputColumnNames;
   }
 
-  public List<String> getInputColumnNames() {
+  public List<List<String>> getInputColumnNamesList() {
     if (step.isInputRaw()) {
       return inputExpressions.stream()
-          .map(Expression::getExpressionString)
+          .map(expression -> 
Collections.singletonList(expression.getExpressionString()))
           .collect(Collectors.toList());
     }
 
-    List<AggregationType> inputAggregationTypes = 
getActualAggregationTypes(step.isInputPartial());
-    List<String> inputColumnNames = new ArrayList<>();
+    List<List<String>> inputColumnNames = new ArrayList<>();
     for (Expression expression : inputExpressions) {
-      for (AggregationType funcName : inputAggregationTypes) {
-        inputColumnNames.add(
-            funcName.toString().toLowerCase() + "(" + 
expression.getExpressionString() + ")");
-      }
+      inputColumnNames.add(getInputColumnNames(expression));
     }
     return inputColumnNames;
   }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
index fb7b72515b..4b8bbf6b00 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/analyze/AggregationDescriptorTest.java
@@ -145,20 +145,19 @@ public class AggregationDescriptorTest {
 
   @Test
   public void testInputColumnNames() {
-    List<String> expectedInputColumnNames =
+    List<List<List<String>>> expectedInputColumnNames =
         Arrays.asList(
-            "root.sg.d1.s1",
-            "count(root.sg.d1.s1)",
-            "sum(root.sg.d1.s1)",
-            "last_value(root.sg.d1.s1)",
-            "max_time(root.sg.d1.s1)",
-            "max_value(root.sg.d1.s1)");
+            
Collections.singletonList(Collections.singletonList("root.sg.d1.s1")),
+            
Collections.singletonList(Collections.singletonList("root.sg.d1.s1")),
+            Collections.singletonList(Arrays.asList("count(root.sg.d1.s1)", 
"sum(root.sg.d1.s1)")),
+            Collections.singletonList(
+                Arrays.asList("last_value(root.sg.d1.s1)", 
"max_time(root.sg.d1.s1)")),
+            
Collections.singletonList(Collections.singletonList("max_value(root.sg.d1.s1)")),
+            
Collections.singletonList(Collections.singletonList("count(root.sg.d1.s1)")));
     Assert.assertEquals(
         expectedInputColumnNames,
         aggregationDescriptorList.stream()
-            .map(AggregationDescriptor::getInputColumnNames)
-            .flatMap(List::stream)
-            .distinct()
+            .map(AggregationDescriptor::getInputColumnNamesList)
             .collect(Collectors.toList()));
   }
 
@@ -177,18 +176,24 @@ public class AggregationDescriptorTest {
 
   @Test
   public void testInputColumnNamesInGroupByLevel() {
-    List<String> expectedInputColumnNames =
+    List<List<List<String>>> expectedInputColumnNames =
         Arrays.asList(
-            "count(root.sg.d2.s1)",
-            "count(root.sg.d1.s1)",
-            "sum(root.sg.d1.s1)",
-            "sum(root.sg.d2.s1)");
+            Arrays.asList(
+                Collections.singletonList("count(root.sg.d2.s1)"),
+                Collections.singletonList("count(root.sg.d1.s1)")),
+            Arrays.asList(
+                Arrays.asList("count(root.sg.d1.s1)", "sum(root.sg.d1.s1)"),
+                Arrays.asList("count(root.sg.d2.s1)", "sum(root.sg.d2.s1)")),
+            Arrays.asList(
+                Collections.singletonList("count(root.sg.d2.s1)"),
+                Collections.singletonList("count(root.sg.d1.s1)")),
+            Arrays.asList(
+                Arrays.asList("count(root.sg.d1.s1)", "sum(root.sg.d1.s1)"),
+                Arrays.asList("count(root.sg.d2.s1)", "sum(root.sg.d2.s1)")));
     Assert.assertEquals(
         expectedInputColumnNames,
         groupByLevelDescriptorList.stream()
-            .map(GroupByLevelDescriptor::getInputColumnNames)
-            .flatMap(List::stream)
-            .distinct()
+            .map(GroupByLevelDescriptor::getInputColumnNamesList)
             .collect(Collectors.toList()));
   }
 

Reply via email to