This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 19f2f8104e9 Fix bugs in accumulator of UDAF
19f2f8104e9 is described below
commit 19f2f8104e97f2de7dcfb2e32988d5b8bab2c225
Author: Weihao Li <[email protected]>
AuthorDate: Fri May 16 14:00:38 2025 +0800
Fix bugs in accumulator of UDAF
---
.../udf/IoTDBUserDefinedAggregateFunctionIT.java | 6 +--
...BUserDefinedAggregationFunctionNonStreamIT.java | 57 ++++++++++++++++++++++
.../GroupedUserDefinedAggregateAccumulator.java | 29 ++++++++---
.../aggregation/grouped/array/ObjectBigArray.java | 4 ++
4 files changed, 86 insertions(+), 10 deletions(-)
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregateFunctionIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregateFunctionIT.java
index 858f2d310e2..7ef7fd3b590 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregateFunctionIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregateFunctionIT.java
@@ -42,7 +42,7 @@ import static org.junit.Assert.fail;
@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
public class IoTDBUserDefinedAggregateFunctionIT {
private static final String DATABASE_NAME = "test";
- private static final String[] sqls =
+ protected static final String[] sqls =
new String[] {
"CREATE DATABASE " + DATABASE_NAME,
"USE " + DATABASE_NAME,
@@ -679,12 +679,12 @@ public class IoTDBUserDefinedAggregateFunctionIT {
};
tableResultSetEqualTest(
- "select device_id, first_two_sum(s1, s2, time) as sum from table1
group by device_id",
+ "select device_id, first_two_sum(s1, s2, time) as sum from table1
group by device_id order by device_id",
expectedHeader,
retArray,
DATABASE_NAME);
tableResultSetEqualTest(
- "select device_id, first_two_sum(s1, s2, s9) as sum from table1 group
by device_id",
+ "select device_id, first_two_sum(s1, s2, s9) as sum from table1 group
by device_id order by device_id",
expectedHeader,
retArray,
DATABASE_NAME);
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregationFunctionNonStreamIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregationFunctionNonStreamIT.java
new file mode 100644
index 00000000000..c74e748732a
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBUserDefinedAggregationFunctionNonStreamIT.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.relational.it.db.it.udf;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.TableClusterIT;
+import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
+public class IoTDBUserDefinedAggregationFunctionNonStreamIT
+ extends IoTDBUserDefinedAggregateFunctionIT {
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().getConfig().getCommonConfig().setSortBufferSize(128 *
1024);
+
EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 *
1024);
+ EnvFactory.getEnv().initClusterEnvironment();
+ String original = sqls[2];
+ // make 'province', 'city', 'region' be FIELD to cover cases using
GroupedAccumulator
+ sqls[2] =
+ "CREATE TABLE table1(province STRING FIELD, city STRING FIELD, region
STRING FIELD, device_id STRING TAG, color STRING ATTRIBUTE, type STRING
ATTRIBUTE, s1 INT32 FIELD, s2 INT64 FIELD, s3 FLOAT FIELD, s4 DOUBLE FIELD, s5
BOOLEAN FIELD, s6 TEXT FIELD, s7 STRING FIELD, s8 BLOB FIELD, s9 TIMESTAMP
FIELD, s10 DATE FIELD)";
+ prepareTableData(sqls);
+ // rollback original content
+ sqls[2] = original;
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
index dbf74b711ca..b3d24e923ae 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedUserDefinedAggregateAccumulator.java
@@ -82,13 +82,23 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
? new RecordIterator(
Arrays.asList(arguments), inputDataTypes,
arguments[0].getPositionCount())
: new MaskedRecordIterator(Arrays.asList(arguments),
inputDataTypes, mask);
- int[] selectedPositions = mask.getSelectedPositions();
+
int index = 0;
- while (iterator.hasNext()) {
- int groupId = groupIds[selectedPositions[index]];
- index++;
- State state = getOrCreateState(groupId);
- aggregateFunction.addInput(state, iterator.next());
+ if (mask.isSelectAll()) {
+ while (iterator.hasNext()) {
+ int groupId = groupIds[index];
+ index++;
+ State state = getOrCreateState(groupId);
+ aggregateFunction.addInput(state, iterator.next());
+ }
+ } else {
+ int[] selectedPositions = mask.getSelectedPositions();
+ while (iterator.hasNext()) {
+ int groupId = groupIds[selectedPositions[index]];
+ index++;
+ State state = getOrCreateState(groupId);
+ aggregateFunction.addInput(state, iterator.next());
+ }
}
}
@@ -141,6 +151,11 @@ public class GroupedUserDefinedAggregateAccumulator implements GroupedAccumulato
@Override
public void close() {
aggregateFunction.beforeDestroy();
- stateArray.forEach(State::destroyState);
+ stateArray.forEach( // destroy every State held in stateArray
+ state -> {
+ if (state != null) { // null-safe: ObjectBigArray#forEach may yield null slots, which State::destroyState would NPE on
+ state.destroyState();
+ }
+ });
}
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/ObjectBigArray.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/ObjectBigArray.java
index 51a7401c811..b5585824e43 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/ObjectBigArray.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/ObjectBigArray.java
@@ -117,6 +117,10 @@ public final class ObjectBigArray<T> {
fill(null);
}
+ /**
+ * Note: elements of the array may be null, so {@code action} must handle null elements, as
+ * the action passed in {@link MapBigArray#reset()} does.
+ */
public void forEach(Consumer<T> action) {
for (Object[] segment : array) {
if (segment == null) {