This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new bf0666e096c Support global-aggregation in TableModel
bf0666e096c is described below

commit bf0666e096c70ad1a39ba2fa9733fa89e2e40cc1
Author: Weihao Li <[email protected]>
AuthorDate: Sun Sep 29 09:54:42 2024 +0800

    Support global-aggregation in TableModel
---
 .../source/relational/aggregation/Accumulator.java |  33 ++++
 .../relational/aggregation/AccumulatorFactory.java | 202 +++++++++++++++++++++
 .../aggregation/AggregationOperator.java           | 158 ++++++++++++++++
 .../source/relational/aggregation/Aggregator.java  |  81 +++++++++
 .../relational/aggregation/AvgAccumulator.java     | 191 +++++++++++++++++++
 .../relational/aggregation/CountAccumulator.java   |  80 ++++++++
 .../plan/planner/TableOperatorGenerator.java       |  89 ++++++++-
 .../metadata/TableBuiltinAggregationFunction.java  |  11 ++
 8 files changed, 844 insertions(+), 1 deletion(-)

diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Accumulator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Accumulator.java
new file mode 100644
index 00000000000..734fa21c00c
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Accumulator.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+
+/**
+ * Holds the running state of a single aggregation function.
+ *
+ * <p>Raw argument columns are consumed through {@link #addInput}; partial states emitted by
+ * {@link #evaluateIntermediate} can be merged back in through {@link #addIntermediate}. The
+ * current state is written out with {@link #evaluateIntermediate} or {@link #evaluateFinal}.
+ */
+public interface Accumulator {
+
+  /** Returns the estimated memory footprint of this accumulator, in bytes. */
+  long getEstimatedSize();
+
+  /**
+   * Returns a new accumulator configured like this one. The implementations in this package
+   * return a fresh instance and do not carry over accumulated state.
+   */
+  Accumulator copy();
+
+  /** Clears the accumulated state so the instance can be reused. */
+  void reset();
+
+  /**
+   * Accumulates raw input rows.
+   *
+   * @param arguments one column per aggregation argument, all covering the same rows
+   */
+  void addInput(Column[] arguments);
+
+  /**
+   * Merges partial state previously produced by {@link #evaluateIntermediate}.
+   *
+   * @param argument column of partial states
+   */
+  void addIntermediate(Column argument);
+
+  /** Appends the current partial (intermediate) state to {@code columnBuilder}. */
+  void evaluateIntermediate(ColumnBuilder columnBuilder);
+
+  /** Appends the final aggregation result to {@code columnBuilder}. */
+  void evaluateFinal(ColumnBuilder columnBuilder);
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
new file mode 100644
index 00000000000..f5157f19543
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+import org.apache.iotdb.db.queryengine.plan.expression.Expression;
+import 
org.apache.iotdb.db.queryengine.plan.expression.binary.CompareBinaryExpression;
+import org.apache.iotdb.db.queryengine.plan.expression.leaf.ConstantOperand;
+
+import org.apache.tsfile.enums.TSDataType;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+
+public class AccumulatorFactory {
+
+  public static Accumulator createAccumulator(
+      String functionName,
+      TAggregationType aggregationType,
+      List<TSDataType> inputDataTypes,
+      List<Expression> inputExpressions,
+      Map<String, String> inputAttributes,
+      boolean ascending) {
+    if (aggregationType == TAggregationType.UDAF) {
+      // If UDAF accumulator receives raw input, it needs to check input's 
attribute
+      throw new UnsupportedOperationException();
+    } else {
+      return createBuiltinAccumulator(
+          aggregationType, inputDataTypes, inputExpressions, inputAttributes, 
ascending);
+    }
+  }
+
+  public static Accumulator createBuiltinAccumulator(
+      TAggregationType aggregationType,
+      List<TSDataType> inputDataTypes,
+      List<Expression> inputExpressions,
+      Map<String, String> inputAttributes,
+      boolean ascending) {
+    return isMultiInputAggregation(aggregationType)
+        ? createBuiltinMultiInputAccumulator(aggregationType, inputDataTypes)
+        : createBuiltinSingleInputAccumulator(
+            aggregationType, inputDataTypes.get(0), inputExpressions, 
inputAttributes, ascending);
+  }
+
+  public static boolean isMultiInputAggregation(TAggregationType 
aggregationType) {
+    switch (aggregationType) {
+      case MAX_BY:
+      case MIN_BY:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  public static Accumulator createBuiltinMultiInputAccumulator(
+      TAggregationType aggregationType, List<TSDataType> inputDataTypes) {
+    switch (aggregationType) {
+      case MAX_BY:
+        checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
+        // return new MaxByAccumulator(inputDataTypes.get(0), 
inputDataTypes.get(1));
+      case MIN_BY:
+        checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
+        // return new MinByAccumulator(inputDataTypes.get(0), 
inputDataTypes.get(1));
+      default:
+        throw new IllegalArgumentException("Invalid Aggregation function: " + 
aggregationType);
+    }
+  }
+
+  private static Accumulator createBuiltinSingleInputAccumulator(
+      TAggregationType aggregationType,
+      TSDataType tsDataType,
+      List<Expression> inputExpressions,
+      Map<String, String> inputAttributes,
+      boolean ascending) {
+    switch (aggregationType) {
+      case COUNT:
+        return new CountAccumulator();
+      case AVG:
+        return new AvgAccumulator(tsDataType);
+        /*case SUM:
+          return new SumAccumulator(tsDataType);
+        case EXTREME:
+          return new ExtremeAccumulator(tsDataType);
+        case MAX_TIME:
+          return ascending ? new MaxTimeAccumulator() : new 
MaxTimeDescAccumulator();
+        case MIN_TIME:
+          return ascending ? new MinTimeAccumulator() : new 
MinTimeDescAccumulator();
+        case MAX_VALUE:
+          return new MaxValueAccumulator(tsDataType);
+        case MIN_VALUE:
+          return new MinValueAccumulator(tsDataType);
+        case LAST_VALUE:
+          return ascending
+              ? new LastValueAccumulator(tsDataType)
+              : new LastValueDescAccumulator(tsDataType);
+        case FIRST_VALUE:
+          return ascending
+              ? new FirstValueAccumulator(tsDataType)
+              : new FirstValueDescAccumulator(tsDataType);
+        case COUNT_IF:
+          return new CountIfAccumulator(
+              initKeepEvaluator(inputExpressions.get(1)),
+              Boolean.parseBoolean(inputAttributes.getOrDefault("ignoreNull", 
"true")));
+        case TIME_DURATION:
+          return new TimeDurationAccumulator();
+        case MODE:
+          return createModeAccumulator(tsDataType);
+        case COUNT_TIME:
+          return new CountTimeAccumulator();
+        case STDDEV:
+        case STDDEV_SAMP:
+          return new VarianceAccumulator(tsDataType, 
VarianceAccumulator.VarianceType.STDDEV_SAMP);
+        case STDDEV_POP:
+          return new VarianceAccumulator(tsDataType, 
VarianceAccumulator.VarianceType.STDDEV_POP);
+        case VARIANCE:
+        case VAR_SAMP:
+          return new VarianceAccumulator(tsDataType, 
VarianceAccumulator.VarianceType.VAR_SAMP);
+        case VAR_POP:
+          return new VarianceAccumulator(tsDataType, 
VarianceAccumulator.VarianceType.VAR_POP);*/
+      default:
+        throw new IllegalArgumentException("Invalid Aggregation function: " + 
aggregationType);
+    }
+  }
+
+  /*private Accumulator createModeAccumulator(TSDataType tsDataType) {
+    switch (tsDataType) {
+      case BOOLEAN:
+        return new BooleanModeAccumulator();
+      case TEXT:
+        return new BinaryModeAccumulator();
+      case INT32:
+        return new IntModeAccumulator();
+      case INT64:
+        return new LongModeAccumulator();
+      case FLOAT:
+        return new FloatModeAccumulator();
+      case DOUBLE:
+        return new DoubleModeAccumulator();
+      case BLOB:
+      case STRING:
+      case TIMESTAMP:
+      case DATE:
+      default:
+        throw new IllegalArgumentException("Unknown data type: " + tsDataType);
+    }
+  }*/
+
+  @FunctionalInterface
+  public interface KeepEvaluator {
+    boolean apply(long keep);
+  }
+
+  public static KeepEvaluator initKeepEvaluator(Expression keepExpression) {
+    // We have checked semantic in FE,
+    // keep expression must be ConstantOperand or CompareBinaryExpression here
+    if (keepExpression instanceof ConstantOperand) {
+      return keep -> keep >= 
Long.parseLong(keepExpression.getExpressionString());
+    } else {
+      long constant =
+          Long.parseLong(
+              ((CompareBinaryExpression) keepExpression)
+                  .getRightExpression()
+                  .getExpressionString());
+      switch (keepExpression.getExpressionType()) {
+        case LESS_THAN:
+          return keep -> keep < constant;
+        case LESS_EQUAL:
+          return keep -> keep <= constant;
+        case GREATER_THAN:
+          return keep -> keep > constant;
+        case GREATER_EQUAL:
+          return keep -> keep >= constant;
+        case EQUAL_TO:
+          return keep -> keep == constant;
+        case NON_EQUAL:
+          return keep -> keep != constant;
+        default:
+          throw new IllegalArgumentException(
+              "unsupported expression type: " + 
keepExpression.getExpressionType());
+      }
+    }
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AggregationOperator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AggregationOperator.java
new file mode 100644
index 00000000000..425f05f7867
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AggregationOperator.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
+import org.apache.iotdb.db.queryengine.execution.operator.Operator;
+import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
+import 
org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator;
+import 
org.apache.iotdb.db.queryengine.plan.planner.memory.MemoryReservationManager;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.common.conf.TSFileDescriptor;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.List;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+
+/**
+ * Aggregation operator without grouping keys (global aggregation).
+ *
+ * <p>Every block from the child is fed to all {@link Aggregator}s. Once the child is exhausted, a
+ * single output row with one column per aggregator is produced and the operator finishes.
+ */
+public class AggregationOperator implements ProcessOperator {
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(AggregationOperator.class);
+
+  private final OperatorContext operatorContext;
+
+  private final Operator child;
+
+  private final List<Aggregator> aggregators;
+
+  private final TsBlockBuilder resultBuilder;
+
+  private final ColumnBuilder[] resultColumnsBuilder;
+
+  private final long maxReturnSize =
+      TSFileDescriptor.getInstance().getConfig().getMaxTsBlockSizeInBytes();
+
+  protected MemoryReservationManager memoryReservationManager;
+
+  private boolean finished = false;
+
+  public AggregationOperator(
+      OperatorContext operatorContext, Operator child, List<Aggregator> aggregators) {
+    this.operatorContext = operatorContext;
+    this.child = child;
+    this.aggregators = aggregators;
+    this.resultBuilder =
+        new TsBlockBuilder(
+            aggregators.stream().map(Aggregator::getType).collect(toImmutableList()));
+    this.resultColumnsBuilder = resultBuilder.getValueColumnBuilders();
+    this.memoryReservationManager =
+        operatorContext
+            .getDriverContext()
+            .getFragmentInstanceContext()
+            .getMemoryReservationContext();
+  }
+
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    return child.isBlocked();
+  }
+
+  @Override
+  public boolean hasNext() throws Exception {
+    return !finished;
+  }
+
+  /**
+   * Consumes one child block, or emits the result once the child is exhausted.
+   *
+   * @return {@code null} while input is still being consumed, otherwise the single-row result
+   */
+  @Override
+  public TsBlock next() throws Exception {
+    // Each call processes at most one child block, so there is no need to check the time slice.
+    if (child.hasNextWithTimer()) {
+      TsBlock block = child.nextWithTimer();
+      if (block != null) {
+        for (Aggregator aggregator : aggregators) {
+          aggregator.processBlock(block);
+        }
+      }
+      // Nothing to emit until the child is exhausted.
+      return null;
+    }
+
+    // Child exhausted: evaluate every aggregator into the single result row.
+    Column[] valueColumns = new Column[resultColumnsBuilder.length];
+    for (int i = 0; i < aggregators.size(); i++) {
+      aggregators.get(i).evaluate(resultColumnsBuilder[i]);
+      valueColumns[i] = resultColumnsBuilder[i].build();
+    }
+
+    finished = true;
+    return TsBlock.wrapBlocksWithoutCopy(
+        1, new RunLengthEncodedColumn(TableScanOperator.TIME_COLUMN_TEMPLATE, 1), valueColumns);
+  }
+
+  @Override
+  public boolean isFinished() throws Exception {
+    return finished;
+  }
+
+  @Override
+  public void close() throws Exception {
+    child.close();
+  }
+
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+
+  @Override
+  public long calculateMaxPeekMemory() {
+    return Math.max(
+        child.calculateMaxPeekMemoryWithCounter(),
+        calculateRetainedSizeAfterCallingNext() + calculateMaxReturnSize());
+  }
+
+  @Override
+  public long calculateMaxReturnSize() {
+    return maxReturnSize;
+  }
+
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return child.calculateMaxReturnSize() + child.calculateRetainedSizeAfterCallingNext();
+  }
+
+  @Override
+  public long ramBytesUsed() {
+    return INSTANCE_SIZE
+        + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(child)
+        // sum() of the accumulator sizes; count() would only return the number of aggregators.
+        + aggregators.stream().mapToLong(Aggregator::getEstimatedSize).sum()
+        + MemoryEstimationHelper.getEstimatedSizeOfAccountableObject(operatorContext)
+        + resultBuilder.getRetainedSizeInBytes();
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
new file mode 100644
index 00000000000..cbd465efd8e
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Aggregator.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+
+import com.google.common.primitives.Ints;
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+
+import java.util.List;
+import java.util.OptionalInt;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Binds an {@link Accumulator} to its input channels and {@link AggregationNode.Step}.
+ *
+ * <p>The step decides whether input blocks hold raw rows or partial states, and whether
+ * {@link #evaluate} emits the intermediate or the final value.
+ */
+public class Aggregator {
+  private final Accumulator accumulator;
+  private final AggregationNode.Step step;
+  private final TSDataType outputType;
+  private final int[] inputChannels;
+  // NOTE(review): stored but not applied in processBlock yet, so a masked aggregation (presumably
+  // FILTER/DISTINCT handling) would see every row -- confirm the planner never sets it for now.
+  private final OptionalInt maskChannel;
+
+  /**
+   * @param accumulator accumulator holding the aggregation state
+   * @param step whether input is raw/intermediate and output is partial/final
+   * @param outputType type of the column this aggregator writes
+   * @param inputChannels input column indexes; exactly one when the input is intermediate
+   * @param maskChannel optional mask column index (currently unused)
+   * @throws IllegalArgumentException if intermediate input does not have exactly one channel
+   */
+  public Aggregator(
+      Accumulator accumulator,
+      AggregationNode.Step step,
+      TSDataType outputType,
+      List<Integer> inputChannels,
+      OptionalInt maskChannel) {
+    this.accumulator = requireNonNull(accumulator, "accumulator is null");
+    this.step = requireNonNull(step, "step is null");
+    this.outputType = requireNonNull(outputType, "outputType is null");
+    this.inputChannels = Ints.toArray(requireNonNull(inputChannels, "inputChannels is null"));
+    this.maskChannel = requireNonNull(maskChannel, "maskChannel is null");
+    checkArgument(
+        step.isInputRaw() || inputChannels.size() == 1,
+        "expected 1 input channel for intermediate aggregation");
+  }
+
+  /** Returns the type of the column written by {@link #evaluate}. */
+  public TSDataType getType() {
+    return outputType;
+  }
+
+  /** Feeds {@code block} to the accumulator as raw input or as partial state, per the step. */
+  public void processBlock(TsBlock block) {
+    if (step.isInputRaw()) {
+      Column[] arguments = block.getColumns(inputChannels);
+      accumulator.addInput(arguments);
+    } else {
+      accumulator.addIntermediate(block.getColumn(inputChannels[0]));
+    }
+  }
+
+  /** Writes the intermediate (partial-output step) or final result into {@code columnBuilder}. */
+  public void evaluate(ColumnBuilder columnBuilder) {
+    if (step.isOutputPartial()) {
+      accumulator.evaluateIntermediate(columnBuilder);
+    } else {
+      accumulator.evaluateFinal(columnBuilder);
+    }
+  }
+
+  /** Clears the accumulator state. */
+  public void reset() {
+    accumulator.reset();
+  }
+
+  /** Returns the accumulator's estimated memory footprint in bytes. */
+  public long getEstimatedSize() {
+    return accumulator.getEstimatedSize();
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AvgAccumulator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AvgAccumulator.java
new file mode 100644
index 00000000000..d5c9eff3630
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AvgAccumulator.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.BytesUtils;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+/**
+ * Accumulator for {@code AVG}: tracks the count of non-null inputs and their sum as a double.
+ *
+ * <p>The intermediate state is one binary value written by {@link #serializeState()}. It holds
+ * the count as a {@code long} followed by the sum as a {@code double}.
+ */
+public class AvgAccumulator implements Accumulator {
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(AvgAccumulator.class);
+  private final TSDataType argumentDataType;
+  private long countValue;
+  private double sumValue;
+  // Whether any value has been accumulated; when false the result is NULL.
+  private boolean initResult = false;
+
+  public AvgAccumulator(TSDataType argumentDataType) {
+    this.argumentDataType = argumentDataType;
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public Accumulator copy() {
+    return new AvgAccumulator(argumentDataType);
+  }
+
+  /**
+   * Accumulates the non-null values of the single numeric argument column.
+   *
+   * @throws UnSupportedDataTypeException for non-numeric argument types
+   */
+  @Override
+  public void addInput(Column[] arguments) {
+    checkArgument(arguments.length == 1, "argument of Avg should be one column");
+    switch (argumentDataType) {
+      case INT32:
+        addIntInput(arguments[0]);
+        return;
+      case INT64:
+        addLongInput(arguments[0]);
+        return;
+      case FLOAT:
+        addFloatInput(arguments[0]);
+        return;
+      case DOUBLE:
+        addDoubleInput(arguments[0]);
+        return;
+      case TEXT:
+      case BLOB:
+      case STRING:
+      case BOOLEAN:
+      case DATE:
+      case TIMESTAMP:
+      default:
+        throw new UnSupportedDataTypeException(
+            String.format("Unsupported data type in aggregation AVG : %s", argumentDataType));
+    }
+  }
+
+  /**
+   * Merges every non-null partial state in {@code argument} into this accumulator.
+   *
+   * <p>Partial counts and sums are added to the running totals rather than replacing them. This
+   * lets partial results from several upstream accumulators combine correctly, whether they
+   * arrive in one block or across several blocks.
+   */
+  @Override
+  public void addIntermediate(Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn,
+        "intermediate input and output of Avg should be BinaryColumn");
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      if (argument.isNull(i)) {
+        continue;
+      }
+      initResult = true;
+      mergeSerializedState(argument.getBinary(i).getValues());
+    }
+    if (countValue == 0) {
+      // Partial states that covered zero rows carry no value; keep the result NULL.
+      initResult = false;
+    }
+  }
+
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output of Avg should be BinaryColumn");
+    if (!initResult) {
+      columnBuilder.appendNull();
+    } else {
+      columnBuilder.writeBinary(new Binary(serializeState()));
+    }
+  }
+
+  /** Adds a partial state in the layout written by {@link #serializeState()} to the totals. */
+  private void mergeSerializedState(byte[] bytes) {
+    countValue += BytesUtils.bytesToLong(bytes, Long.BYTES);
+    sumValue += BytesUtils.bytesToDouble(bytes, Long.BYTES);
+  }
+
+  /** Serializes the state as the count ({@code writeLong}) followed by the sum. */
+  private byte[] serializeState() {
+    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+    DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
+    try {
+      dataOutputStream.writeLong(countValue);
+      dataOutputStream.writeDouble(sumValue);
+    } catch (IOException e) {
+      throw new UnsupportedOperationException(
+          "Failed to serialize intermediate result for AvgAccumulator.", e);
+    }
+    return byteArrayOutputStream.toByteArray();
+  }
+
+  private void addIntInput(Column column) {
+    int count = column.getPositionCount();
+    for (int i = 0; i < count; i++) {
+      if (!column.isNull(i)) {
+        initResult = true;
+        countValue++;
+        sumValue += column.getInt(i);
+      }
+    }
+  }
+
+  private void addLongInput(Column column) {
+    int count = column.getPositionCount();
+    for (int i = 0; i < count; i++) {
+      if (!column.isNull(i)) {
+        initResult = true;
+        countValue++;
+        sumValue += column.getLong(i);
+      }
+    }
+  }
+
+  private void addFloatInput(Column column) {
+    int count = column.getPositionCount();
+    for (int i = 0; i < count; i++) {
+      if (!column.isNull(i)) {
+        initResult = true;
+        countValue++;
+        sumValue += column.getFloat(i);
+      }
+    }
+  }
+
+  private void addDoubleInput(Column column) {
+    int count = column.getPositionCount();
+    for (int i = 0; i < count; i++) {
+      if (!column.isNull(i)) {
+        initResult = true;
+        countValue++;
+        sumValue += column.getDouble(i);
+      }
+    }
+  }
+
+  /** Writes {@code sum / count}, or NULL if nothing was accumulated. */
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    if (!initResult) {
+      columnBuilder.appendNull();
+    } else {
+      columnBuilder.writeDouble(sumValue / countValue);
+    }
+  }
+
+  @Override
+  public void reset() {
+    initResult = false;
+    this.countValue = 0;
+    this.sumValue = 0.0;
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAccumulator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAccumulator.java
new file mode 100644
index 00000000000..3f876564376
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAccumulator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class CountAccumulator implements Accumulator {
+  private static final long INSTANCE_SIZE =
+      RamUsageEstimator.shallowSizeOfInstance(CountAccumulator.class);
+  private long countState = 0;
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public Accumulator copy() {
+    return new CountAccumulator();
+  }
+
+  @Override
+  public void addInput(Column[] arguments) {
+    checkArgument(arguments.length == 1, "argument of Count should be one 
column");
+    int count = arguments[0].getPositionCount();
+    if (!arguments[0].mayHaveNull()) {
+      countState += count;
+    } else {
+      for (int i = 0; i < count; i++) {
+        if (!arguments[0].isNull(i)) {
+          countState++;
+        }
+      }
+    }
+  }
+
+  @Override
+  public void addIntermediate(Column argument) {
+    checkArgument(argument.getPositionCount() == 1, "partialResult should 
always be one line");
+    if (argument.isNull(0)) {
+      return;
+    }
+    countState += argument.getLong(0);
+  }
+
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    columnBuilder.writeLong(countState);
+  }
+
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    columnBuilder.writeLong(countState);
+  }
+
+  @Override
+  public void reset() {
+    countState = 0;
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index b0b520c92bb..7a5cc0db045 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -52,6 +52,9 @@ import 
org.apache.iotdb.db.queryengine.execution.operator.source.ExchangeOperato
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableFullOuterJoinOperator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableInnerJoinOperator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Accumulator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationOperator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Aggregator;
 import 
org.apache.iotdb.db.queryengine.execution.relational.ColumnTransformerBuilder;
 import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
@@ -93,6 +96,8 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.tsfile.common.conf.TSFileDescriptor;
 import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.type.RowType;
+import org.apache.tsfile.read.common.type.Type;
 import org.apache.tsfile.read.filter.basic.Filter;
 import org.apache.tsfile.write.schema.IMeasurementSchema;
 import org.apache.tsfile.write.schema.MeasurementSchema;
@@ -109,6 +114,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -118,8 +124,10 @@ import static 
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory
 import static 
org.apache.iotdb.db.queryengine.common.DataNodeEndPoints.isSameNode;
 import static 
org.apache.iotdb.db.queryengine.execution.operator.process.join.merge.MergeSortComparator.getComparatorForTable;
 import static 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.constructAlignedPath;
+import static 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AccumulatorFactory.createAccumulator;
 import static 
org.apache.iotdb.db.queryengine.plan.analyze.PredicateUtils.convertPredicateToFilter;
 import static 
org.apache.iotdb.db.queryengine.plan.planner.OperatorTreeGenerator.ASC_TIME_COMPARATOR;
+import static 
org.apache.iotdb.db.queryengine.plan.relational.metadata.TableBuiltinAggregationFunction.getAggregationTypeByFuncName;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager.getTSDataType;
 
 /** This Visitor is responsible for transferring Table PlanNode Tree to Table 
Operator Tree. */
@@ -914,7 +922,86 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
 
   @Override
   public Operator visitAggregation(AggregationNode node, LocalExecutionPlanContext context) {
-    throw new UnsupportedOperationException("Agg-BE not supported");
+    // Only global aggregation (no grouping keys) is implemented here. Reject grouped aggregation
+    // up front, before an operator context is registered or the child subtree is planned.
+    if (!node.getGroupingKeys().isEmpty()) {
+      throw new UnsupportedOperationException(
+          "Aggregation with grouping keys is not supported yet in table model");
+    }
+
+    OperatorContext operatorContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                AggregationNode.class.getSimpleName());
+    Operator child = node.getChild().accept(this, context);
+    return planGlobalAggregation(node, child, context.getTypeProvider(), operatorContext);
+  }
+
+  /**
+   * Builds an {@link AggregationOperator} for an aggregation node without grouping keys.
+   *
+   * <p>One {@link Aggregator} is created per output symbol of {@code node}, in output-symbol
+   * order. Argument channels are resolved against the output layout of the node's child.
+   *
+   * @param node the aggregation plan node; every output symbol must map to an aggregation
+   * @param child the already-planned child operator
+   * @param typeProvider resolves the table-model types of aggregation arguments
+   * @param operatorContext the operator context registered for this aggregation
+   * @return the aggregation operator
+   * @throws NullPointerException if an output symbol has no aggregation in {@code node}
+   */
+  private Operator planGlobalAggregation(
+      AggregationNode node,
+      Operator child,
+      TypeProvider typeProvider,
+      OperatorContext operatorContext) {
+    Map<Symbol, AggregationNode.Aggregation> aggregationMap = node.getAggregations();
+    Map<Symbol, Integer> childLayout =
+        makeLayoutFromOutputSymbols(node.getChild().getOutputSymbols());
+
+    ImmutableList.Builder<Aggregator> aggregators = ImmutableList.builder();
+    for (Symbol symbol : node.getOutputSymbols()) {
+      // Report a missing aggregation here, with the offending symbol, instead of letting an
+      // anonymous NPE escape from inside buildAggregator.
+      AggregationNode.Aggregation aggregation =
+          Objects.requireNonNull(
+              aggregationMap.get(symbol), () -> "No aggregation found for output symbol " + symbol);
+      aggregators.add(buildAggregator(childLayout, aggregation, node.getStep(), typeProvider));
+    }
+    return new AggregationOperator(operatorContext, child, aggregators.build());
+  }
+
+  /**
+   * Maps every symbol to its zero-based position (channel) in {@code outputSymbols}.
+   *
+   * @param outputSymbols the output symbols of an operator, in channel order
+   * @return an immutable symbol-to-channel map
+   * @throws IllegalArgumentException if the same symbol occurs more than once
+   */
+  private ImmutableMap<Symbol, Integer> makeLayoutFromOutputSymbols(List<Symbol> outputSymbols) {
+    ImmutableMap.Builder<Symbol, Integer> layout = ImmutableMap.builder();
+    int nextChannel = 0;
+    for (Symbol outputSymbol : outputSymbols) {
+      layout.put(outputSymbol, nextChannel++);
+    }
+    return layout.buildOrThrow();
+  }
+
+  /**
+   * Creates the {@link Aggregator} for a single aggregation call.
+   *
+   * @param childLayout channel index of each symbol in the child operator's output
+   * @param aggregation the aggregation call; each argument is converted with {@code Symbol.from}
+   * @param step the aggregation step recorded on the plan node
+   * @param typeProvider resolves the table-model type of each argument symbol
+   * @return an aggregator reading its inputs from the channels of its argument symbols
+   * @throws NullPointerException if an argument symbol is not an output of the child
+   */
+  private Aggregator buildAggregator(
+      Map<Symbol, Integer> childLayout,
+      AggregationNode.Aggregation aggregation,
+      AggregationNode.Step step,
+      TypeProvider typeProvider) {
+    List<Integer> argumentChannels = new ArrayList<>();
+    List<TSDataType> argumentTypes = new ArrayList<>();
+    for (Expression argument : aggregation.getArguments()) {
+      Symbol argumentSymbol = Symbol.from(argument);
+      // Fail fast on an unresolved argument instead of storing a null channel that would only
+      // surface later as an unboxing NPE.
+      Integer channel = childLayout.get(argumentSymbol);
+      Objects.requireNonNull(
+          channel, () -> "Argument " + argumentSymbol + " is not an output symbol of the child");
+      argumentChannels.add(channel);
+
+      // A RowType argument contributes one data type per type parameter; any other type
+      // contributes exactly one.
+      Type type = typeProvider.getTableModelType(argumentSymbol);
+      if (type instanceof RowType) {
+        type.getTypeParameters().forEach(subType -> argumentTypes.add(getTSDataType(subType)));
+      } else {
+        argumentTypes.add(getTSDataType(type));
+      }
+    }
+
+    String functionName = aggregation.getResolvedFunction().getSignature().getName();
+    Accumulator accumulator =
+        createAccumulator(
+            functionName,
+            getAggregationTypeByFuncName(functionName),
+            argumentTypes,
+            Collections.emptyList(),
+            Collections.emptyMap(),
+            // NOTE(review): meaning of this flag is defined by AccumulatorFactory; confirm there.
+            true);
+
+    // NOTE(review): the function's declared return type is passed regardless of `step`; confirm
+    // this is correct when the step produces an intermediate (partial) result.
+    return new Aggregator(
+        accumulator,
+        step,
+        getTSDataType(aggregation.getResolvedFunction().getSignature().getReturnType()),
+        argumentChannels,
+        OptionalInt.empty());
   }
 
   @Override
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
index e14df7ee408..72f0cdadd02 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
@@ -19,6 +19,8 @@
 
 package org.apache.iotdb.db.queryengine.plan.relational.metadata;
 
+import org.apache.iotdb.common.rpc.thrift.TAggregationType;
+
 import com.google.common.collect.ImmutableList;
 import org.apache.tsfile.read.common.type.Type;
 
@@ -113,4 +115,13 @@ public enum TableBuiltinAggregationFunction {
       return ImmutableList.copyOf(originalArgumentTypes);
     }
   }
+
+  /**
+   * Maps an aggregation function name to its {@link TAggregationType}.
+   *
+   * @param funcName the aggregation function name
+   * @return the native aggregation type whose constant name is {@code funcName} upper-cased, or
+   *     {@link TAggregationType#UDAF} when {@code funcName} is not in {@code
+   *     NATIVE_FUNCTION_NAMES}
+   * @throws IllegalArgumentException if {@code funcName} is a native name without a matching
+   *     {@link TAggregationType} constant
+   */
+  public static TAggregationType getAggregationTypeByFuncName(String funcName) {
+    if (!NATIVE_FUNCTION_NAMES.contains(funcName)) {
+      // Not a built-in aggregation: fall back to a user-defined aggregation function.
+      return TAggregationType.UDAF;
+    }
+    // Upper-case with a fixed locale. With the default locale, a Turkish locale maps 'i' to
+    // 'İ' (U+0130), and valueOf would then fail to find the enum constant.
+    return TAggregationType.valueOf(funcName.toUpperCase(java.util.Locale.ROOT));
+  }
 }


Reply via email to