This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch beyyes/agg_template_alignbydevice
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/beyyes/agg_template_alignbydevice by this push:
     new 304f68f1a5b add select max_time(s1), last_value(s1), last_value(s2) impl
304f68f1a5b is described below

commit 304f68f1a5baef1caffa041452ad260c0e019e5c
Author: Beyyes <[email protected]>
AuthorDate: Sat May 11 19:27:43 2024 +0800

    add select max_time(s1), last_value(s1), last_value(s2) impl
---
 .../db/queryengine/plan/analyze/Analysis.java      |   3 +-
 .../plan/analyze/ExpressionTypeAnalyzer.java       |  14 +-
 .../plan/analyze/TemplatedAggregationAnalyze.java  |  91 ++++++++++---
 .../db/queryengine/plan/analyze/TemplatedInfo.java |  22 +++-
 .../plan/optimization/AggregationPushDown.java     | 132 +++++++++++++++++--
 .../plan/planner/OperatorTreeGenerator.java        |  66 +++++++---
 .../plan/planner/TemplatedLogicalPlan.java         | 141 +++++++++++++++------
 .../plan/planner/TemplatedLogicalPlanBuilder.java  |  15 +--
 .../plan/parameter/AggregationDescriptor.java      |   2 +-
 9 files changed, 390 insertions(+), 96 deletions(-)

diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
index d7e2163eca1..3a58adc130e 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/Analysis.java
@@ -409,7 +409,8 @@ public class Analysis implements IAnalysis {
     }
 
     if (isAllDevicesInOneTemplate()
-        && (isOnlyQueryTemplateMeasurements() || expression instanceof 
TimeSeriesOperand)) {
+        && isOnlyQueryTemplateMeasurements()
+        && expression instanceof TimeSeriesOperand) {
       TimeSeriesOperand seriesOperand = (TimeSeriesOperand) expression;
       return 
deviceTemplate.getSchemaMap().get(seriesOperand.getPath().getMeasurement()).getType();
     }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
index 49904d6532c..1cd7fbcf836 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java
@@ -19,6 +19,7 @@
 
 package org.apache.iotdb.db.queryengine.plan.analyze;
 
+import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.queryengine.common.NodeRef;
 import org.apache.iotdb.db.queryengine.plan.expression.Expression;
@@ -65,7 +66,10 @@ public class ExpressionTypeAnalyzer {
   public static TSDataType analyzeExpression(Analysis analysis, Expression 
expression) {
     if (!analysis.getExpressionTypes().containsKey(NodeRef.of(expression))) {
       ExpressionTypeAnalyzer analyzer = new ExpressionTypeAnalyzer();
-      analyzer.analyze(expression, null);
+
+      Map<String, IMeasurementSchema> context =
+          analysis.isAllDevicesInOneTemplate() ? 
analysis.getDeviceTemplate().getSchemaMap() : null;
+      analyzer.analyze(expression, context);
 
       addExpressionTypes(analysis, analyzer);
     }
@@ -346,6 +350,14 @@ public class ExpressionTypeAnalyzer {
         return setExpressionType(
             timeSeriesOperand, 
context.get(timeSeriesOperand.getOutputSymbol()).getType());
       }
+
+      if (context != null
+          && !(timeSeriesOperand.getPath() instanceof MeasurementPath)
+          && context.containsKey(timeSeriesOperand.getPath().getFullPath())) {
+        return setExpressionType(
+            timeSeriesOperand, 
context.get(timeSeriesOperand.getPath().getFullPath()).getType());
+      }
+
       return setExpressionType(timeSeriesOperand, 
timeSeriesOperand.getPath().getSeriesType());
     }
 
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedAggregationAnalyze.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedAggregationAnalyze.java
index ae2afa9c8f5..60211412341 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedAggregationAnalyze.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedAggregationAnalyze.java
@@ -1,3 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
 package org.apache.iotdb.db.queryengine.plan.analyze;
 
 import org.apache.iotdb.commons.path.PartialPath;
@@ -10,13 +28,17 @@ import 
org.apache.iotdb.db.queryengine.plan.statement.crud.QueryStatement;
 import org.apache.iotdb.db.schemaengine.template.Template;
 
 import org.apache.tsfile.utils.Pair;
+import org.apache.tsfile.write.schema.IMeasurementSchema;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Set;
 
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.DEVICE_EXPRESSION;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.END_TIME_EXPRESSION;
+import static 
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.analyzeExpressionType;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.AnalyzeVisitor.analyzeOutput;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.TemplatedAnalyze.analyzeDataPartition;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.TemplatedAnalyze.analyzeDeviceToWhere;
@@ -26,10 +48,9 @@ import static 
org.apache.iotdb.db.queryengine.plan.analyze.TemplatedAnalyze.anal
 import static 
org.apache.iotdb.db.queryengine.plan.optimization.LimitOffsetPushDown.canPushDownLimitOffsetInGroupByTimeForDevice;
 import static 
org.apache.iotdb.db.queryengine.plan.optimization.LimitOffsetPushDown.pushDownLimitOffsetInGroupByTimeForDevice;
 
+/** Methods in this class are used for aggregation, templated with align by 
device situation. */
 public class TemplatedAggregationAnalyze {
 
-  // ----------- Methods below are used for aggregation, templated with align 
by device --------
-
   static boolean analyzeAggregation(
       Analysis analysis,
       QueryStatement queryStatement,
@@ -45,6 +66,9 @@ public class TemplatedAggregationAnalyze {
       deviceList = pushDownLimitOffsetInGroupByTimeForDevice(deviceList, 
queryStatement);
     }
 
+    List<Pair<Expression, String>> outputExpressions = new ArrayList<>();
+    analyzeSelect(queryStatement, analysis, outputExpressions, template);
+
     analyzeDeviceToWhere(analysis, queryStatement);
     if (deviceList.isEmpty()) {
       analysis.setFinishQueryAfterAnalyze(true);
@@ -52,15 +76,8 @@ public class TemplatedAggregationAnalyze {
     }
     analysis.setDeviceList(deviceList);
 
-    List<Pair<Expression, String>> outputExpressions = new ArrayList<>();
-    ColumnPaginationController paginationController =
-        new ColumnPaginationController(
-            queryStatement.getSeriesLimit(), queryStatement.getSeriesOffset());
-    for (ResultColumn resultColumn : 
queryStatement.getSelectComponent().getResultColumns()) {}
-
-    analyzeSelect(queryStatement, analysis, outputExpressions, template);
     if (analysis.getWhereExpression() != null
-        && analysis.getWhereExpression().equals(ConstantOperand.FALSE)) {
+        && ConstantOperand.FALSE.equals(analysis.getWhereExpression())) {
       analyzeOutput(analysis, queryStatement, outputExpressions);
       analysis.setFinishQueryAfterAnalyze(true);
       return true;
@@ -87,21 +104,61 @@ public class TemplatedAggregationAnalyze {
       Analysis analysis,
       List<Pair<Expression, String>> outputExpressions,
       Template template) {
+
     LinkedHashSet<Expression> selectExpressions = new LinkedHashSet<>();
     selectExpressions.add(DEVICE_EXPRESSION);
     if (queryStatement.isOutputEndTime()) {
       selectExpressions.add(END_TIME_EXPRESSION);
     }
-    for (Pair<Expression, String> pair : outputExpressions) {
-      Expression selectExpression = pair.left;
-      selectExpressions.add(selectExpression);
+
+    ColumnPaginationController paginationController =
+        new ColumnPaginationController(
+            queryStatement.getSeriesLimit(), queryStatement.getSeriesOffset());
+
+    Set<Expression> aggregationExpressions = new LinkedHashSet<>();
+    for (ResultColumn resultColumn : 
queryStatement.getSelectComponent().getResultColumns()) {
+      if (paginationController.hasCurOffset()) {
+        paginationController.consumeOffset();
+      } else if (paginationController.hasCurLimit()) {
+        Expression selectExpression = resultColumn.getExpression();
+        outputExpressions.add(new Pair<>(selectExpression, 
resultColumn.getAlias()));
+        selectExpressions.add(selectExpression);
+        aggregationExpressions.add(selectExpression);
+      } else {
+        break;
+      }
     }
+
+    analysis.setDeviceTemplate(template);
+    List<String> measurementList = new ArrayList<>();
+    List<IMeasurementSchema> measurementSchemaList = new ArrayList<>();
+    Set<String> measurementSet = new HashSet<>();
+    for (Expression selectExpression : selectExpressions) {
+      if ("device".equalsIgnoreCase(selectExpression.getOutputSymbol())) {
+        continue;
+      }
+
+      String measurement = 
selectExpression.getExpressions().get(0).getOutputSymbol();
+      if (!template.getSchemaMap().containsKey(measurement)) {
+        throw new IllegalArgumentException(
+            "Measurement " + measurement + " is not found in template");
+      }
+
+      // for agg1(s1) + agg2(s1), only record s1 for one time
+      if (!measurementSet.contains(measurement)) {
+        measurementSet.add(measurement);
+        measurementList.add(measurement);
+        measurementSchemaList.add(template.getSchemaMap().get(measurement));
+      }
+
+      analyzeExpressionType(analysis, selectExpression);
+    }
+
+    analysis.setMeasurementList(measurementList);
+    analysis.setMeasurementSchemaList(measurementSchemaList);
+    analysis.setAggregationExpressions(aggregationExpressions);
     analysis.setOutputExpressions(outputExpressions);
     analysis.setSelectExpressions(selectExpressions);
-    analysis.setDeviceTemplate(template);
-    // TODO only add measurement and schema occured in selectExpressions
-    analysis.setMeasurementList(new 
ArrayList<>(template.getSchemaMap().keySet()));
-    analysis.setMeasurementSchemaList(new 
ArrayList<>(template.getSchemaMap().values()));
   }
 
   private static void analyzeDeviceToSourceTransform(Analysis analysis) {
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedInfo.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedInfo.java
index bf5d54e5cd6..f042c36a07d 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedInfo.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TemplatedInfo.java
@@ -23,6 +23,8 @@ import org.apache.iotdb.commons.path.MeasurementPath;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.queryengine.plan.expression.Expression;
 import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.AggregationDescriptor;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.GroupByTimeParameter;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
 import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering;
 
@@ -76,10 +78,15 @@ public class TemplatedInfo {
   private int maxTsBlockLineNum = -1;
 
   // variables related to predicate push down
+  // TODO when to init pushDownPredicate in agg situation?
   private Expression pushDownPredicate;
 
-  private Set<Expression> aggSelectExpressions;
+  // variables related to aggregation
+  public List<AggregationDescriptor> aggregationDescriptorList;
+  public GroupByTimeParameter groupByTimeParameter;
+  public boolean outputEndTime = false;
 
+  private Set<Expression> aggSelectExpressions;
   private Expression havingExpression;
 
   public TemplatedInfo(
@@ -96,7 +103,9 @@ public class TemplatedInfo {
       boolean keepNull,
       Map<String, IMeasurementSchema> schemaMap,
       Map<String, List<InputLocation>> layoutMap,
-      Expression pushDownPredicate) {
+      Expression pushDownPredicate,
+      GroupByTimeParameter groupByTimeParameter,
+      boolean outputEndTime) {
     this.measurementList = measurementList;
     this.schemaList = schemaList;
     this.dataTypes = dataTypes;
@@ -113,6 +122,9 @@ public class TemplatedInfo {
       this.layoutMap = layoutMap;
     }
     this.pushDownPredicate = pushDownPredicate;
+
+    this.groupByTimeParameter = groupByTimeParameter;
+    this.outputEndTime = outputEndTime;
   }
 
   public List<String> getMeasurementList() {
@@ -355,6 +367,8 @@ public class TemplatedInfo {
       pushDownPredicate = Expression.deserialize(byteBuffer);
     }
 
+    // TODO add groupByTimeParameter, outputEndTime serialization and 
deserialization
+
     return new TemplatedInfo(
         measurementList,
         measurementSchemaList,
@@ -369,6 +383,8 @@ public class TemplatedInfo {
         keepNull,
         currentSchemaMap,
         layoutMap,
-        pushDownPredicate);
+        pushDownPredicate,
+        null,
+        false);
   }
 }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
index 344572be2b3..ec1d2fdb09a 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/optimization/AggregationPushDown.java
@@ -63,11 +63,13 @@ import org.apache.tsfile.utils.Pair;
 import org.apache.tsfile.write.schema.IMeasurementSchema;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkState;
 import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT_TIME;
@@ -92,6 +94,12 @@ public class AggregationPushDown implements PlanOptimizer {
   private boolean cannotUseStatistics(QueryStatement queryStatement, Analysis 
analysis) {
     boolean isAlignByDevice = queryStatement.isAlignByDevice();
     if (isAlignByDevice) {
+      if (analysis.isAllDevicesInOneTemplate()) {
+        // TODO agg+template situation, how about the 
SourceTransformExpressions
+        return cannotUseStatistics(
+            analysis.getAggregationExpressions(), 
analysis.getAggregationExpressions());
+      }
+
       // check any of the devices
       String device = analysis.getDeviceList().get(0).toString();
       return cannotUseStatistics(
@@ -172,6 +180,10 @@ public class AggregationPushDown implements PlanOptimizer {
       List<PlanNode> rewrittenChildren = new ArrayList<>();
       for (int i = 0; i < node.getDevices().size(); i++) {
         context.setCurDevice(node.getDevices().get(i));
+        if (context.analysis.isAllDevicesInOneTemplate()) {
+          context.setCurDevicePath(context.analysis.getDeviceList().get(i));
+        }
+
         rewrittenChildren.add(node.getChildren().get(i).accept(this, context));
       }
       node.setChildren(rewrittenChildren);
@@ -257,14 +269,26 @@ public class AggregationPushDown implements PlanOptimizer 
{
               sourceToCountTimeAggregationsMap);
         }
 
-        List<PlanNode> sourceNodeList =
-            constructSourceNodeFromAggregationDescriptors(
-                sourceToAscendingAggregationsMap,
-                sourceToDescendingAggregationsMap,
-                sourceToCountTimeAggregationsMap,
-                node.getScanOrder(),
-                node.getGroupByTimeParameter(),
-                context);
+        List<PlanNode> sourceNodeList;
+        if (context.analysis.isAllDevicesInOneTemplate()) {
+          sourceNodeList =
+              constructSourceNodeFromTemplateAggregationDescriptors(
+                  sourceToAscendingAggregationsMap,
+                  sourceToDescendingAggregationsMap,
+                  sourceToCountTimeAggregationsMap,
+                  node.getScanOrder(),
+                  node.getGroupByTimeParameter(),
+                  context);
+        } else {
+          sourceNodeList =
+              constructSourceNodeFromAggregationDescriptors(
+                  sourceToAscendingAggregationsMap,
+                  sourceToDescendingAggregationsMap,
+                  sourceToCountTimeAggregationsMap,
+                  node.getScanOrder(),
+                  node.getGroupByTimeParameter(),
+                  context);
+        }
 
         if (isSingleSource && ((SeriesScanSourceNode) 
child).getPushDownPredicate() != null) {
           Expression pushDownPredicate = ((SeriesScanSourceNode) 
child).getPushDownPredicate();
@@ -352,7 +376,6 @@ public class AggregationPushDown implements PlanOptimizer {
         GroupByTimeParameter groupByTimeParameter,
         RewriterContext context) {
       List<PlanNode> sourceNodeList = new ArrayList<>();
-      boolean needCheckAscending = groupByTimeParameter == null;
       Map<PartialPath, List<AggregationDescriptor>> 
groupedAscendingAggregations = null;
       if (!countTimeAggregations.isEmpty()) {
         groupedAscendingAggregations = countTimeAggregations;
@@ -371,6 +394,7 @@ public class AggregationPushDown implements PlanOptimizer {
                 context));
       }
 
+      boolean needCheckAscending = groupByTimeParameter == null;
       if (needCheckAscending) {
         Map<PartialPath, List<AggregationDescriptor>> 
groupedDescendingAggregations =
             MetaUtils.groupAlignedAggregations(descendingAggregations);
@@ -388,6 +412,85 @@ public class AggregationPushDown implements PlanOptimizer {
       return sourceNodeList;
     }
 
+    /**
+     * Builds the aggregation source nodes for the current device of a template-based
+     * align-by-device query.
+     *
+     * <p>The device is read from {@code context.curDevicePath}; the measurements and their
+     * schemas come from the analysis. For a directly aligned template all measurements are
+     * scanned through one {@code AlignedPath}: one node for the ascending descriptors and, when
+     * there is no group-by-time parameter, one node for the descending descriptors. Otherwise one
+     * {@code MeasurementPath} is built per measurement and one node is created per descriptor
+     * list.
+     *
+     * @param ascendingAggregations descriptors of the ascending group, keyed by measurement
+     * @param descendingAggregations descriptors of the descending group; only consulted when
+     *     {@code groupByTimeParameter} is null
+     * @param countTimeAggregations currently ignored, count_time is not supported here yet
+     * @param scanOrder scan order handed to every created node
+     * @param groupByTimeParameter group-by-time parameter, may be null
+     * @param context rewriter context providing the analysis and the current device path
+     * @return the created source nodes, possibly empty
+     */
+    private List<PlanNode> constructSourceNodeFromTemplateAggregationDescriptors(
+        Map<PartialPath, List<AggregationDescriptor>> ascendingAggregations,
+        Map<PartialPath, List<AggregationDescriptor>> descendingAggregations,
+        Map<PartialPath, List<AggregationDescriptor>> countTimeAggregations,
+        Ordering scanOrder,
+        GroupByTimeParameter groupByTimeParameter,
+        RewriterContext context) {
+
+      // keySet of ascendingAggregations is measurement,
+      // valueSet of ascendingAggregations is aggDescriptors such as count(s1), avg(s1)
+
+      List<PlanNode> sourceNodeList = new ArrayList<>();
+      PartialPath devicePath = context.curDevicePath;
+      List<String> measurementList = context.analysis.getMeasurementList();
+      List<IMeasurementSchema> measurementSchemaList = context.analysis.getMeasurementSchemaList();
+      // the descending group is only scanned separately when there is no group by time
+      boolean needCheckAscending = groupByTimeParameter == null;
+
+      if (context.analysis.getDeviceTemplate().isDirectAligned()) {
+        AlignedPath alignedPath = new AlignedPath(devicePath);
+        alignedPath.setMeasurementList(measurementList);
+        alignedPath.addSchemas(measurementSchemaList);
+
+        List<AggregationDescriptor> ascendingDescriptors =
+            ascendingAggregations.values().stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toList());
+        if (!ascendingDescriptors.isEmpty()) {
+          sourceNodeList.add(
+              createAggregationScanNode(
+                  alignedPath, ascendingDescriptors, scanOrder, groupByTimeParameter, context));
+        }
+
+        if (needCheckAscending) {
+          List<AggregationDescriptor> descendingDescriptors =
+              descendingAggregations.values().stream()
+                  .flatMap(Collection::stream)
+                  .collect(Collectors.toList());
+          // same guard as the ascending group: never create a scan node without descriptors
+          if (!descendingDescriptors.isEmpty()) {
+            sourceNodeList.add(
+                createAggregationScanNode(
+                    alignedPath, descendingDescriptors, scanOrder, groupByTimeParameter, context));
+          }
+        }
+      } else {
+        // TODO verify the correctness of non-aligned series
+        for (int i = 0; i < measurementList.size(); i++) {
+          MeasurementPath measurementPath =
+              new MeasurementPath(
+                  devicePath.concatNode(measurementList.get(i)), measurementSchemaList.get(i));
+          // ascending group first (this loop previously iterated the descending map by mistake)
+          for (List<AggregationDescriptor> aggregationDescriptorList :
+              ascendingAggregations.values()) {
+            sourceNodeList.add(
+                createAggregationScanNode(
+                    measurementPath,
+                    aggregationDescriptorList,
+                    scanOrder,
+                    groupByTimeParameter,
+                    context));
+          }
+
+          if (needCheckAscending) {
+            for (List<AggregationDescriptor> aggregationDescriptorList :
+                descendingAggregations.values()) {
+              sourceNodeList.add(
+                  createAggregationScanNode(
+                      measurementPath,
+                      aggregationDescriptorList,
+                      scanOrder,
+                      groupByTimeParameter,
+                      context));
+            }
+          }
+        }
+      }
+
+      // TODO count(s1+s2) is not supported
+      // TODO count_time is not supported
+
+      return sourceNodeList;
+    }
+
     private SeriesAggregationSourceNode createAggregationScanNode(
         PartialPath selectPath,
         List<AggregationDescriptor> aggregationDescriptorList,
@@ -442,6 +545,7 @@ public class AggregationPushDown implements PlanOptimizer {
     private final boolean isAlignByDevice;
 
     private String curDevice;
+    private PartialPath curDevicePath;
 
     public RewriterContext(Analysis analysis, MPPQueryContext context, boolean 
isAlignByDevice) {
       this.analysis = analysis;
@@ -461,9 +565,17 @@ public class AggregationPushDown implements PlanOptimizer {
       this.curDevice = curDevice;
     }
 
+    public void setCurDevicePath(PartialPath devicePath) {
+      this.curDevicePath = devicePath;
+    }
+
     public Set<Expression> getAggregationExpressions() {
       if (isAlignByDevice) {
-        return analysis.getDeviceToAggregationExpressions().get(curDevice);
+        if (analysis.isAllDevicesInOneTemplate()) {
+          return analysis.getAggregationExpressions();
+        } else {
+          return analysis.getDeviceToAggregationExpressions().get(curDevice);
+        }
       }
       return analysis.getAggregationExpressions();
     }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
index c271e8ec2da..1783480abb8 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/OperatorTreeGenerator.java
@@ -591,28 +591,62 @@ public class OperatorTreeGenerator extends 
PlanVisitor<Operator, LocalExecutionP
   @Override
   public Operator visitAlignedSeriesAggregationScan(
       AlignedSeriesAggregationScanNode node, LocalExecutionPlanContext 
context) {
-    AlignedPath seriesPath = node.getAlignedPath();
-    boolean ascending = node.getScanOrder() == Ordering.ASC;
+    if (context.isBuildPlanUseTemplate()) {
+      return constructAlignedSeriesAggregationScanOperator(
+          node.getPlanNodeId(),
+          node.getAlignedPath(),
+          context.getTemplatedInfo().aggregationDescriptorList,
+          context.getTemplatedInfo().getPushDownPredicate(),
+          context.getTemplatedInfo().getScanOrder(),
+          context.getTemplatedInfo().groupByTimeParameter,
+          context.getTemplatedInfo().outputEndTime,
+          context);
+    }
+
+    return constructAlignedSeriesAggregationScanOperator(
+        node.getPlanNodeId(),
+        node.getAlignedPath(),
+        node.getAggregationDescriptorList(),
+        node.getPushDownPredicate(),
+        node.getScanOrder(),
+        node.getGroupByTimeParameter(),
+        node.isOutputEndTime(),
+        context);
+  }
+
+  private Operator constructAlignedSeriesAggregationScanOperator(
+      PlanNodeId planNodeId,
+      AlignedPath alignedPath,
+      List<AggregationDescriptor> aggregationDescriptorList,
+      Expression pushDownPredicate,
+      Ordering scanOrder,
+      GroupByTimeParameter groupByTimeParameter,
+      boolean outputEndTime,
+      LocalExecutionPlanContext context) {
+    boolean ascending = scanOrder == Ordering.ASC;
     List<Aggregator> aggregators = new ArrayList<>();
-    for (AggregationDescriptor descriptor : 
node.getAggregationDescriptorList()) {
+    for (AggregationDescriptor descriptor : aggregationDescriptorList) {
       checkArgument(
           descriptor.getInputExpressions().size() == 1,
           "descriptor's input expression size is not 1");
+
       Expression expression = descriptor.getInputExpressions().get(0);
       if (expression instanceof TimeSeriesOperand) {
+        // TODO for template_agg, no need use getPath.getMeasurement
         String inputSeries =
             ((TimeSeriesOperand) (descriptor.getInputExpressions().get(0)))
                 .getPath()
                 .getMeasurement();
-        int seriesIndex = seriesPath.getMeasurementList().indexOf(inputSeries);
+        int seriesIndex = 
alignedPath.getMeasurementList().indexOf(inputSeries);
         TSDataType seriesDataType =
-            
seriesPath.getMeasurementSchema().getSubMeasurementsTSDataTypeList().get(seriesIndex);
+            
alignedPath.getMeasurementSchema().getSubMeasurementsTSDataTypeList().get(seriesIndex);
         aggregators.add(
             new Aggregator(
                 AccumulatorFactory.createAccumulator(
                     descriptor.getAggregationFuncName(),
                     descriptor.getAggregationType(),
                     Collections.singletonList(seriesDataType),
+                    // TODO inputExpression must be devicePath+measurement
                     descriptor.getInputExpressions(),
                     descriptor.getInputAttributes(),
                     ascending,
@@ -627,6 +661,7 @@ public class OperatorTreeGenerator extends 
PlanVisitor<Operator, LocalExecutionP
                     descriptor.getAggregationFuncName(),
                     descriptor.getAggregationType(),
                     Collections.singletonList(TSDataType.INT64),
+                    // TODO inputExpression must be devicePath+measurement
                     descriptor.getInputExpressions(),
                     descriptor.getInputAttributes(),
                     ascending,
@@ -640,23 +675,22 @@ public class OperatorTreeGenerator extends 
PlanVisitor<Operator, LocalExecutionP
       }
     }
 
-    GroupByTimeParameter groupByTimeParameter = node.getGroupByTimeParameter();
     ITimeRangeIterator timeRangeIterator =
         initTimeRangeIterator(groupByTimeParameter, ascending, true);
     long maxReturnSize =
         AggregationUtil.calculateMaxAggregationResultSize(
-            node.getAggregationDescriptorList(), timeRangeIterator, 
context.getTypeProvider());
+            aggregationDescriptorList, timeRangeIterator, 
context.getTypeProvider());
 
     SeriesScanOptions.Builder scanOptionsBuilder = 
getSeriesScanOptionsBuilder(context);
-    scanOptionsBuilder.withAllSensors(new 
HashSet<>(seriesPath.getMeasurementList()));
+    scanOptionsBuilder.withAllSensors(new 
HashSet<>(alignedPath.getMeasurementList()));
 
-    Expression pushDownPredicate = node.getPushDownPredicate();
     if (pushDownPredicate != null) {
       
checkArgument(PredicateUtils.predicateCanPushIntoScan(pushDownPredicate));
       scanOptionsBuilder.withPushDownFilter(
           convertPredicateToFilter(
               pushDownPredicate,
-              node.getAlignedPath().getMeasurementList(),
+              alignedPath.getMeasurementList(),
+              // TODO what's the meaning of isBuildPlanUseTemplate
               context.getTypeProvider().getTemplatedInfo() != null,
               context.getTypeProvider()));
     }
@@ -666,14 +700,14 @@ public class OperatorTreeGenerator extends 
PlanVisitor<Operator, LocalExecutionP
             .getDriverContext()
             .addOperatorContext(
                 context.getNextOperatorId(),
-                node.getPlanNodeId(),
+                planNodeId,
                 AlignedSeriesAggregationScanOperator.class.getSimpleName());
     AlignedSeriesAggregationScanOperator seriesAggregationScanOperator =
         new AlignedSeriesAggregationScanOperator(
-            node.getPlanNodeId(),
-            seriesPath,
-            node.getScanOrder(),
-            node.isOutputEndTime(),
+            planNodeId,
+            alignedPath,
+            scanOrder,
+            outputEndTime,
             scanOptionsBuilder.build(),
             operatorContext,
             aggregators,
@@ -683,7 +717,7 @@ public class OperatorTreeGenerator extends 
PlanVisitor<Operator, LocalExecutionP
 
     ((DataDriverContext) context.getDriverContext())
         .addSourceOperator(seriesAggregationScanOperator);
-    ((DataDriverContext) context.getDriverContext()).addPath(seriesPath);
+    ((DataDriverContext) context.getDriverContext()).addPath(alignedPath);
     context.getDriverContext().setInputDriver(true);
     return seriesAggregationScanOperator;
   }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlan.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlan.java
index 7bb9ebd579c..6117e8795d9 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlan.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlan.java
@@ -25,14 +25,19 @@ import org.apache.iotdb.db.queryengine.plan.analyze.Analysis;
 import org.apache.iotdb.db.queryengine.plan.analyze.TemplatedInfo;
 import org.apache.iotdb.db.queryengine.plan.expression.Expression;
 import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand;
+import 
org.apache.iotdb.db.queryengine.plan.expression.multi.FunctionExpression;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.AggregationDescriptor;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.AggregationStep;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
 import org.apache.iotdb.db.queryengine.plan.statement.crud.QueryStatement;
 
+import org.apache.commons.lang3.Validate;
+import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.write.schema.IMeasurementSchema;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -40,8 +45,11 @@ import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
+import static 
org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.DEVICE;
+import static 
org.apache.iotdb.db.queryengine.common.header.ColumnHeaderConstant.ENDTIME;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.ExpressionAnalyzer.searchSourceExpressions;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.TemplatedInfo.makeLayout;
+import static 
org.apache.iotdb.db.queryengine.plan.planner.LogicalPlanBuilder.updateTypeProviderByPartialAggregation;
 import static 
org.apache.iotdb.db.queryengine.plan.planner.LogicalPlanVisitor.pushDownLimitToScanNode;
 
 /**
@@ -71,6 +79,8 @@ public class TemplatedLogicalPlan {
 
   private Map<String, List<InputLocation>> filterLayoutMap;
 
+  List<AggregationDescriptor> aggregationDescriptorList;
+
   public TemplatedLogicalPlan(
       Analysis analysis, QueryStatement queryStatement, MPPQueryContext 
context) {
     this.analysis = analysis;
@@ -144,7 +154,9 @@ public class TemplatedLogicalPlan {
                 queryStatement.isGroupByTime(),
                 analysis.getDeviceTemplate().getSchemaMap(),
                 filterLayoutMap,
-                null));
+                null,
+                analysis.getGroupByTimeParameter(),
+                queryStatement.isOutputEndTime()));
   }
 
   private void initNonAggQueryCommonVariables() {
@@ -199,7 +211,9 @@ public class TemplatedLogicalPlan {
                 queryStatement.isGroupByTime(),
                 analysis.getDeviceTemplate().getSchemaMap(),
                 filterLayoutMap,
-                null));
+                null,
+                analysis.getGroupByTimeParameter(),
+                queryStatement.isOutputEndTime()));
   }
 
   public PlanNode visitQuery() {
@@ -255,14 +269,62 @@ public class TemplatedLogicalPlan {
     return planBuilder.getRoot();
   }
 
+  public PlanNode visitQueryBody(PartialPath devicePath) {
+
+    TemplatedLogicalPlanBuilder planBuilder =
+        new TemplatedLogicalPlanBuilder(analysis, context, newMeasurementList, 
newSchemaList);
+
+    planBuilder =
+        planBuilder
+            .planRawDataSource(
+                devicePath,
+                queryStatement.getResultTimeOrder(),
+                OFFSET_VALUE,
+                limitValue,
+                analysis.isLastLevelUseWildcard())
+            .planFilter(
+                whereExpression,
+                queryStatement.isGroupByTime(),
+                queryStatement.getResultTimeOrder());
+
+    return planBuilder.getRoot();
+  }
+
+  // ============== Methods below are used for templated aggregation ======================
+
   private PlanNode visitAggregation() {
+    boolean outputPartial =
+        queryStatement.isGroupByLevel()
+            || queryStatement.isGroupByTag()
+            || (queryStatement.isGroupByTime() && 
analysis.getGroupByTimeParameter().hasOverlap());
+    AggregationStep curStep = outputPartial ? AggregationStep.PARTIAL : 
AggregationStep.SINGLE;
+
+    if (queryStatement.isGroupByTime() && 
analysis.getGroupByTimeParameter().hasOverlap()) {
+      curStep =
+          (queryStatement.isGroupByLevel() || queryStatement.isGroupByTag())
+              ? AggregationStep.INTERMEDIATE
+              : AggregationStep.FINAL;
+    }
+
+    aggregationDescriptorList =
+        
constructAggregationDescriptorList(analysis.getAggregationExpressions(), 
curStep);
+    updateTypeProvider(analysis.getAggregationExpressions());
+    if (curStep.isOutputPartial()) {
+      aggregationDescriptorList.forEach(
+          aggregationDescriptor ->
+              updateTypeProviderByPartialAggregation(
+                  aggregationDescriptor, context.getTypeProvider()));
+    }
+
+    context.getTypeProvider().getTemplatedInfo().aggregationDescriptorList =
+        aggregationDescriptorList;
+
     LogicalPlanBuilder planBuilder =
         new TemplatedLogicalPlanBuilder(analysis, context, measurementList, 
schemaList);
-
     Map<String, PlanNode> deviceToSubPlanMap = new LinkedHashMap<>();
     for (PartialPath devicePath : analysis.getDeviceList()) {
       String deviceName = devicePath.getFullPath();
-      PlanNode rootNode = visitDeviceAggregationBody(devicePath);
+      PlanNode rootNode = visitDeviceAggregationBody(devicePath, curStep);
 
       LogicalPlanBuilder subPlanBuilder =
           new TemplatedLogicalPlanBuilder(analysis, context, measurementList, 
schemaList)
@@ -305,28 +367,7 @@ public class TemplatedLogicalPlan {
     return planBuilder.getRoot();
   }
 
-  public PlanNode visitQueryBody(PartialPath devicePath) {
-
-    TemplatedLogicalPlanBuilder planBuilder =
-        new TemplatedLogicalPlanBuilder(analysis, context, newMeasurementList, 
newSchemaList);
-
-    planBuilder =
-        planBuilder
-            .planRawDataSource(
-                devicePath,
-                queryStatement.getResultTimeOrder(),
-                OFFSET_VALUE,
-                limitValue,
-                analysis.isLastLevelUseWildcard())
-            .planFilter(
-                whereExpression,
-                queryStatement.isGroupByTime(),
-                queryStatement.getResultTimeOrder());
-
-    return planBuilder.getRoot();
-  }
-
-  private PlanNode visitDeviceAggregationBody(PartialPath devicePath) {
+  private PlanNode visitDeviceAggregationBody(PartialPath devicePath, 
AggregationStep curStep) {
     TemplatedLogicalPlanBuilder planBuilder =
         new TemplatedLogicalPlanBuilder(analysis, context, newMeasurementList, 
newSchemaList);
 
@@ -343,26 +384,18 @@ public class TemplatedLogicalPlan {
                 queryStatement.isGroupByTime(),
                 queryStatement.getResultTimeOrder());
 
-    boolean outputPartial =
-        queryStatement.isGroupByLevel()
-            || queryStatement.isGroupByTag()
-            || (queryStatement.isGroupByTime() && 
analysis.getGroupByTimeParameter().hasOverlap());
-    AggregationStep curStep = outputPartial ? AggregationStep.PARTIAL : 
AggregationStep.SINGLE;
     planBuilder =
         planBuilder.planRawDataAggregation(
-            analysis.getSelectExpressions(),
+            analysis.getAggregationExpressions(),
             null,
             analysis.getGroupByTimeParameter(),
             analysis.getGroupByParameter(),
             queryStatement.isOutputEndTime(),
             curStep,
-            queryStatement.getResultTimeOrder());
+            queryStatement.getResultTimeOrder(),
+            aggregationDescriptorList);
 
     if (queryStatement.isGroupByTime() && 
analysis.getGroupByTimeParameter().hasOverlap()) {
-      curStep =
-          (queryStatement.isGroupByLevel() || queryStatement.isGroupByTag())
-              ? AggregationStep.INTERMEDIATE
-              : AggregationStep.FINAL;
       planBuilder =
           planBuilder.planSlidingWindowAggregation(
               analysis.getSelectExpressions(),
@@ -374,4 +407,38 @@ public class TemplatedLogicalPlan {
     // no group by level and group by tag
     return planBuilder.getRoot();
   }
+
+  private List<AggregationDescriptor> constructAggregationDescriptorList(
+      Set<Expression> aggregationExpressions, AggregationStep curStep) {
+    return aggregationExpressions.stream()
+        .map(
+            expression -> {
+              Validate.isTrue(expression instanceof FunctionExpression);
+              return new AggregationDescriptor(
+                  ((FunctionExpression) expression).getFunctionName(),
+                  curStep,
+                  expression.getExpressions(),
+                  ((FunctionExpression) expression).getFunctionAttributes());
+            })
+        .collect(Collectors.toList());
+  }
+
+  void updateTypeProvider(Collection<Expression> expressions) {
+    if (expressions == null) {
+      return;
+    }
+    expressions.forEach(
+        expression -> {
+          if (!expression.getExpressionString().equals(DEVICE)
+              && !expression.getExpressionString().equals(ENDTIME)) {
+            context
+                .getTypeProvider()
+                .setType(expression.getExpressionString(), 
getPreAnalyzedType(expression));
+          }
+        });
+  }
+
+  private TSDataType getPreAnalyzedType(Expression expression) {
+    return analysis.getType(expression);
+  }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlanBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlanBuilder.java
index 4942c76e308..6ef95f0e119 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlanBuilder.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TemplatedLogicalPlanBuilder.java
@@ -48,6 +48,7 @@ import java.util.Set;
  * unnecessary judgements.
  */
 public class TemplatedLogicalPlanBuilder extends LogicalPlanBuilder {
+
   private final MPPQueryContext context;
 
   private final Analysis analysis;
@@ -132,6 +133,8 @@ public class TemplatedLogicalPlanBuilder extends LogicalPlanBuilder {
     return this;
   }
 
+  // ===================== Methods below are used for aggregation =============================
+
   public TemplatedLogicalPlanBuilder planRawDataAggregation(
       Set<Expression> aggregationExpressions,
       Expression groupByExpression,
@@ -139,20 +142,12 @@ public class TemplatedLogicalPlanBuilder extends LogicalPlanBuilder {
       GroupByParameter groupByParameter,
       boolean outputEndTime,
       AggregationStep curStep,
-      Ordering scanOrder) {
+      Ordering scanOrder,
+      List<AggregationDescriptor> aggregationDescriptorList) {
     if (aggregationExpressions == null) {
       return this;
     }
 
-    List<AggregationDescriptor> aggregationDescriptorList =
-        constructAggregationDescriptorList(aggregationExpressions, curStep);
-    updateTypeProvider(aggregationExpressions);
-    if (curStep.isOutputPartial()) {
-      aggregationDescriptorList.forEach(
-          aggregationDescriptor ->
-              updateTypeProviderByPartialAggregation(
-                  aggregationDescriptor, context.getTypeProvider()));
-    }
     this.root =
         new RawDataAggregationNode(
             context.getQueryId().genPlanNodeId(),
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java
index d266de50aaa..c7758007d36 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java
@@ -265,7 +265,7 @@ public class AggregationDescriptor {
 
   protected String getInputString(List<Expression> expressions) {
     StringBuilder builder = new StringBuilder();
-    if (!(expressions.size() == 0)) {
+    if (!(expressions.isEmpty())) {
       builder.append(expressions.get(0).getExpressionString());
       for (int i = 1; i < expressions.size(); ++i) {
         builder.append(", ").append(expressions.get(i).getExpressionString());


Reply via email to