JackieTien97 commented on code in PR #14331:
URL: https://github.com/apache/iotdb/pull/14331#discussion_r1887986137


##########
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java:
##########
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.commons.udf.access.RecordIterator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.ObjectBigArray;
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.read.common.type.Type;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class GroupedUserDefinedAggregateAccumulator implements 
GroupedAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      
RamUsageEstimator.shallowSizeOfInstance(GroupedUserDefinedAggregateAccumulator.class);
+  private final AggregateFunction aggregateFunction;
+  private final ObjectBigArray<State> stateArray;
+  private final List<Type> inputDataTypes;
+
+  public GroupedUserDefinedAggregateAccumulator(
+      AggregateFunction aggregateFunction, List<Type> inputDataTypes) {
+    this.aggregateFunction = aggregateFunction;
+    this.stateArray = new ObjectBigArray<>();
+    this.inputDataTypes = inputDataTypes;
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public void setGroupCount(long groupCount) {
+    stateArray.ensureCapacity(groupCount);
+  }
+
+  private State getOrCreateState(int groupId) {
+    State state = stateArray.get(groupId);
+    if (state == null) {
+      state = aggregateFunction.createState();
+      stateArray.set(groupId, state);
+    }
+    return state;
+  }
+
+  @Override
+  public void addInput(int[] groupIds, Column[] arguments) {
+    RecordIterator iterator =
+        new RecordIterator(
+            Arrays.asList(arguments), inputDataTypes, 
arguments[0].getPositionCount());
+    int index = 0;
+    while (iterator.hasNext()) {
+      int groupId = groupIds[index++];
+      State state = getOrCreateState(groupId);
+      if (state == null) {
+        state = aggregateFunction.createState();
+        stateArray.set(groupId, state);
+      }

Review Comment:
   ```suggestion
   ```
   This null check is redundant, since getOrCreateState already creates and stores the state when it is null.



##########
iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java:
##########
@@ -19,4 +19,83 @@
 
 package org.apache.iotdb.udf.api.relational;
 
-public interface AggregateFunction extends SQLFunction {}
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+public interface AggregateFunction extends SQLFunction {
+
+  /**
+   * This method is used to validate {@linkplain FunctionParameters}.
+   *
+   * @param parameters parameters used to validate
+   * @throws UDFException if any parameter is not valid
+   */
+  void validate(FunctionParameters parameters) throws UDFException;
+
+  /**
+   * This method is mainly used to initialize {@linkplain AggregateFunction} 
and set the output data
+   * type. In this method, the user need to do the following things:
+   *
+   * <ul>
+   *   <li>Use {@linkplain FunctionParameters} to get input data types and 
infer output data type.
+   *   <li>Use {@linkplain FunctionParameters} to get necessary attributes.
+   *   <li>Set the output data type in {@linkplain AggregateFunctionConfig}.
+   * </ul>
+   *
+   * <p>This method is called after the AggregateFunction is instantiated and 
before the beginning
+   * of the transformation process.
+   *
+   * @param parameters used to parse the input parameters entered by the user
+   * @param configurations used to set the required properties in the 
ScalarFunction
+   */
+  void beforeStart(FunctionParameters parameters, AggregateFunctionConfig 
configurations);
+
+  /** Create and initialize state. You may bind some resource in this method. 
*/
+  State createState();
+
+  /**
+   * Batch update state with data columns. You shall iterate columns and 
update state with raw
+   * values
+   *
+   * @param state state to be updated
+   * @param input original input data row
+   */
+  void addInput(State state, Record input);
+
+  /**
+   * Merge two state in execution engine.
+   *
+   * @param state current state
+   * @param rhs right-hand-side state to be merged
+   */
+  void combineState(State state, State rhs);
+
+  /**
+   * Remove input data from state. This method is used to remove the data 
points that have been
+   * added to the state. Once it is implemented, {@linkplain 
AggregateFunctionConfig#setRemovable}
+   * should be set to true.
+   *
+   * @param state state to be updated
+   * @param input row to be removed
+   */
+  default void remove(State state, Record input) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Calculate output value from final state
+   *
+   * @param state final state
+   * @param resultValue used to collect output data points
+   */
+  void outputFinal(State state, ResultValue resultValue);

Review Comment:
   ```suggestion
   
     /**
      * Calculate output value from final state
      *
      * @param state final state
      * @param resultValue used to collect output data points
      */
     void outputFinal(State state, ResultValue resultValue);
   
     /**
      * Remove input data from state. This method is used to remove the data 
points that have been
      * added to the state. Once it is implemented, {@linkplain 
AggregateFunctionConfig#setRemovable}
      * should be set to true.
      *
      * @param state state to be updated
      * @param input row to be removed
      */
     default void remove(State state, Record input) {
       throw new UnsupportedOperationException();
     }
   ```



##########
iotdb-api/udf-api/src/main/java/org/apache/iotdb/udf/api/relational/AggregateFunction.java:
##########
@@ -19,4 +19,83 @@
 
 package org.apache.iotdb.udf.api.relational;
 
-public interface AggregateFunction extends SQLFunction {}
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.customizer.config.AggregateFunctionConfig;
+import org.apache.iotdb.udf.api.customizer.parameter.FunctionParameters;
+import org.apache.iotdb.udf.api.exception.UDFException;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+public interface AggregateFunction extends SQLFunction {
+
+  /**
+   * This method is used to validate {@linkplain FunctionParameters}.
+   *
+   * @param parameters parameters used to validate
+   * @throws UDFException if any parameter is not valid
+   */
+  void validate(FunctionParameters parameters) throws UDFException;
+
+  /**
+   * This method is mainly used to initialize {@linkplain AggregateFunction} 
and set the output data
+   * type. In this method, the user need to do the following things:
+   *
+   * <ul>
+   *   <li>Use {@linkplain FunctionParameters} to get input data types and 
infer output data type.
+   *   <li>Use {@linkplain FunctionParameters} to get necessary attributes.
+   *   <li>Set the output data type in {@linkplain AggregateFunctionConfig}.
+   * </ul>
+   *
+   * <p>This method is called after the AggregateFunction is instantiated and 
before the beginning
+   * of the transformation process.
+   *
+   * @param parameters used to parse the input parameters entered by the user
+   * @param configurations used to set the required properties in the 
ScalarFunction
+   */
+  void beforeStart(FunctionParameters parameters, AggregateFunctionConfig 
configurations);
+
+  /** Create and initialize state. You may bind some resource in this method. 
*/
+  State createState();
+
+  /**
+   * Batch update state with data columns. You shall iterate columns and 
update state with raw
+   * values
+   *
+   * @param state state to be updated
+   * @param input original input data row
+   */
+  void addInput(State state, Record input);
+
+  /**
+   * Merge two state in execution engine.
+   *
+   * @param state current state
+   * @param rhs right-hand-side state to be merged
+   */
+  void combineState(State state, State rhs);
+
+  /**
+   * Remove input data from state. This method is used to remove the data 
points that have been
+   * added to the state. Once it is implemented, {@linkplain 
AggregateFunctionConfig#setRemovable}
+   * should be set to true.
+   *
+   * @param state state to be updated
+   * @param input row to be removed
+   */
+  default void remove(State state, Record input) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
+   * Calculate output value from final state
+   *
+   * @param state final state
+   * @param resultValue used to collect output data points
+   */
+  void outputFinal(State state, ResultValue resultValue);

Review Comment:
   Change the order: put the most commonly used interface methods first.



##########
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java:
##########
@@ -103,7 +104,8 @@ public static Type getIntermediateType(String name, 
List<Type> originalArgumentT
       case "min":
         return originalArgumentTypes.get(0);
       default:
-        throw new IllegalArgumentException("Invalid Aggregation function: " + 
name);
+        // default is UDAF
+        return BLOB;

Review Comment:
   Better to add a double check here: confirm by the function name that this is really a UDAF; otherwise, throw the 
previous exception.



##########
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/UserDefinedAggregateFunctionAccumulator.java:
##########
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.iotdb.commons.udf.access.RecordIterator;
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.read.common.type.Type;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class UserDefinedAggregateFunctionAccumulator implements 
TableAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      
RamUsageEstimator.shallowSizeOfInstance(UserDefinedAggregateFunctionAccumulator.class);
+  private final AggregateFunction aggregateFunction;
+  private final List<Type> inputDataTypes;
+  private final State state;
+
+  public UserDefinedAggregateFunctionAccumulator(
+      AggregateFunction aggregateFunction, List<Type> inputDataTypes) {
+    this.aggregateFunction = aggregateFunction;
+    this.inputDataTypes = inputDataTypes;
+    this.state = aggregateFunction.createState();
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public TableAccumulator copy() {
+    return new UserDefinedAggregateFunctionAccumulator(aggregateFunction, 
inputDataTypes);
+  }
+
+  @Override
+  public void addInput(Column[] arguments) {
+    RecordIterator iterator =
+        new RecordIterator(
+            Arrays.asList(arguments), inputDataTypes, 
arguments[0].getPositionCount());
+    while (iterator.hasNext()) {
+      aggregateFunction.addInput(state, iterator.next());
+    }
+  }
+
+  @Override
+  public void addIntermediate(Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof 
BinaryColumn),
+        "intermediate input and output of UDAF should be BinaryColumn");
+    State otherState = aggregateFunction.createState();
+    for (int i = 0; i < argument.getPositionCount(); i++) {
+      otherState.reset();
+      Binary otherStateBinary = argument.getBinary(i);
+      otherState.deserialize(otherStateBinary.getValues());
+      aggregateFunction.combineState(state, otherState);
+    }
+  }
+
+  @Override
+  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output of UDAF should be BinaryColumn");
+    byte[] bytes = state.serialize();
+    columnBuilder.writeBinary(new Binary(bytes));
+  }
+
+  @Override
+  public void evaluateFinal(ColumnBuilder columnBuilder) {
+    ResultValue resultValue = new ResultValue(columnBuilder);
+    aggregateFunction.outputFinal(state, resultValue);
+  }
+
+  @Override
+  public boolean hasFinalResult() {
+    return false;
+  }
+
+  @Override
+  public void addStatistics(Statistics[] statistics) {
+    // UDAF not support calculate from statistics now

Review Comment:
   throw new UnsupportedOperationException("UDAF does not support calculating from statistics 
yet");



##########
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java:
##########
@@ -122,19 +126,54 @@ public static org.apache.tsfile.read.common.type.Type 
transformUDFDataTypeToRead
       case BOOLEAN:
         return BooleanType.BOOLEAN;
       case INT32:
-      case DATE:
         return IntType.INT32;
+      case DATE:
+        return DateType.DATE;
       case INT64:
-      case TIMESTAMP:
         return LongType.INT64;
+      case TIMESTAMP:
+        return TimestampType.TIMESTAMP;
       case FLOAT:
         return FloatType.FLOAT;
       case DOUBLE:
         return DoubleType.DOUBLE;
       case TEXT:
+        return BinaryType.TEXT;
       case BLOB:
+        return BlobType.BLOB;
       case STRING:
+        return StringType.STRING;
+      default:
+        throw new IllegalArgumentException("Invalid input: " + type);
+    }
+  }
+
+  public static org.apache.tsfile.read.common.type.Type 
transformTSDataTypeToReadType(

Review Comment:
   delete this method, you can use TypeFactory.getType instead



##########
iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java:
##########
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.commons.udf.access.RecordIterator;
+import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.ObjectBigArray;
+import org.apache.iotdb.udf.api.State;
+import org.apache.iotdb.udf.api.relational.AggregateFunction;
+import org.apache.iotdb.udf.api.utils.ResultValue;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.read.common.block.column.BinaryColumn;
+import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.read.common.type.Type;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class GroupedUserDefinedAggregateAccumulator implements 
GroupedAccumulator {
+
+  private static final long INSTANCE_SIZE =
+      
RamUsageEstimator.shallowSizeOfInstance(GroupedUserDefinedAggregateAccumulator.class);
+  private final AggregateFunction aggregateFunction;
+  private final ObjectBigArray<State> stateArray;
+  private final List<Type> inputDataTypes;
+
+  public GroupedUserDefinedAggregateAccumulator(
+      AggregateFunction aggregateFunction, List<Type> inputDataTypes) {
+    this.aggregateFunction = aggregateFunction;
+    this.stateArray = new ObjectBigArray<>();
+    this.inputDataTypes = inputDataTypes;
+  }
+
+  @Override
+  public long getEstimatedSize() {
+    return INSTANCE_SIZE;
+  }
+
+  @Override
+  public void setGroupCount(long groupCount) {
+    stateArray.ensureCapacity(groupCount);
+  }
+
+  private State getOrCreateState(int groupId) {
+    State state = stateArray.get(groupId);
+    if (state == null) {
+      state = aggregateFunction.createState();
+      stateArray.set(groupId, state);
+    }
+    return state;
+  }
+
+  @Override
+  public void addInput(int[] groupIds, Column[] arguments) {
+    RecordIterator iterator =
+        new RecordIterator(
+            Arrays.asList(arguments), inputDataTypes, 
arguments[0].getPositionCount());
+    int index = 0;
+    while (iterator.hasNext()) {
+      int groupId = groupIds[index++];
+      State state = getOrCreateState(groupId);
+      if (state == null) {
+        state = aggregateFunction.createState();
+        stateArray.set(groupId, state);
+      }
+      aggregateFunction.addInput(state, iterator.next());
+    }
+  }
+
+  @Override
+  public void addIntermediate(int[] groupIds, Column argument) {
+    checkArgument(
+        argument instanceof BinaryColumn
+            || (argument instanceof RunLengthEncodedColumn
+                && ((RunLengthEncodedColumn) argument).getValue() instanceof 
BinaryColumn),
+        "intermediate input and output of UDAF should be BinaryColumn");
+
+    for (int i = 0; i < groupIds.length; i++) {
+      if (!argument.isNull(i)) {
+        State otherState = aggregateFunction.createState();
+        Binary otherStateBinary = argument.getBinary(i);
+        otherState.deserialize(otherStateBinary.getValues());
+        aggregateFunction.combineState(getOrCreateState(groupIds[i]), 
otherState);
+      }
+    }
+  }
+
+  @Override
+  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
+    checkArgument(
+        columnBuilder instanceof BinaryColumnBuilder,
+        "intermediate input and output of UDAF should be BinaryColumn");
+    if (stateArray.get(groupId) == null) {
+      columnBuilder.writeBinary(new Binary(new byte[0]));

Review Comment:
   When can this `if` branch be taken? It seems that each group will always have a 
state. If so, it would be better to throw an IllegalStateException here.



##########
iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/UDFDataTypeTransformer.java:
##########
@@ -122,19 +126,54 @@ public static org.apache.tsfile.read.common.type.Type 
transformUDFDataTypeToRead
       case BOOLEAN:
         return BooleanType.BOOLEAN;
       case INT32:
-      case DATE:
         return IntType.INT32;
+      case DATE:
+        return DateType.DATE;
       case INT64:
-      case TIMESTAMP:
         return LongType.INT64;
+      case TIMESTAMP:
+        return TimestampType.TIMESTAMP;
       case FLOAT:
         return FloatType.FLOAT;
       case DOUBLE:
         return DoubleType.DOUBLE;
       case TEXT:
+        return BinaryType.TEXT;
       case BLOB:
+        return BlobType.BLOB;
       case STRING:
+        return StringType.STRING;
+      default:
+        throw new IllegalArgumentException("Invalid input: " + type);
+    }
+  }
+
+  public static org.apache.tsfile.read.common.type.Type 
transformTSDataTypeToReadType(

Review Comment:
   Also delete transformReadTypeToTSDataType in this class: it is unused, and 
we already have a replacement method, InternalTypeManager.getTSDataType.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to