This is an automated email from the ASF dual-hosted git repository.
caogaofei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 7ed70ecefdf Support mode, stddev and variance function for grouped
aggregation (#13911)
7ed70ecefdf is described below
commit 7ed70ecefdf78e91311875388faef86cc6c3634b
Author: Weihao Li <[email protected]>
AuthorDate: Fri Oct 25 15:40:21 2024 +0800
Support mode, stddev and variance function for grouped aggregation (#13911)
---
.../db/it/IoTDBMultiIDsWithAttributesTableIT.java | 78 +++-
.../it/query/recent/IoTDBTableAggregationIT.java | 17 +
.../relational/aggregation/AccumulatorFactory.java | 18 +
.../aggregation/TableModeAccumulator.java | 139 +++++--
.../source/relational/aggregation/Utils.java | 2 +-
.../grouped/GroupedModeAccumulator.java | 452 +++++++++++++++++++++
.../grouped/GroupedVarianceAccumulator.java | 225 ++++++++++
.../aggregation/grouped/array/MapBigArray.java | 83 ++++
.../plan/planner/TableOperatorGenerator.java | 32 +-
.../metadata/TableBuiltinAggregationFunction.java | 43 +-
.../relational/planner/optimizations/Util.java | 18 +-
11 files changed, 1017 insertions(+), 90 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
index 34b8af428a3..1b4f100bb44 100644
---
a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiIDsWithAttributesTableIT.java
@@ -47,7 +47,7 @@ public class IoTDBMultiIDsWithAttributesTableIT {
"CREATE DATABASE db",
"USE db",
"CREATE TABLE table0 (device string id, level string id, attr1 string
attribute, attr2 string attribute, num int32 measurement, bigNum int64
measurement, "
- + "floatNum double measurement, str TEXT measurement, bool BOOLEAN
measurement, date DATE measurement, blob BLOB measurement, ts TIMESTAMP
measurement, stringV STRING measurement, doubleNum DOUBLE measurement)",
+ + "floatNum FLOAT measurement, str TEXT measurement, bool BOOLEAN
measurement, date DATE measurement, blob BLOB measurement, ts TIMESTAMP
measurement, stringV STRING measurement, doubleNum DOUBLE measurement)",
"insert into table0(device, level, attr1, attr2,
time,num,bigNum,floatNum,str,bool) values('d1', 'l1', 'c', 'd',
0,3,2947483648,231.2121,'coconut',FALSE)",
"insert into table0(device, level, attr1, attr2,
time,num,bigNum,floatNum,str,bool,blob,ts,doubleNum) values('d1', 'l2', 'y',
'z',
20,2,2147483648,434.12,'pineapple',TRUE,X'108DCD62',2024-09-24T06:15:35.000+00:00,6666.8)",
"insert into table0(device, level, attr1, attr2,
time,num,bigNum,floatNum,str,bool) values('d1', 'l3', 't', 'a',
40,1,2247483648,12.123,'apricot',TRUE)",
@@ -1251,7 +1251,44 @@ public class IoTDBMultiIDsWithAttributesTableIT {
"select
mode(device),mode(level),mode(attr1),mode(attr2),mode(date),mode(bool),mode(date),mode(ts),mode(stringv),mode(doublenum)
from table0 where device='d2' and level='l1'";
retArray =
new String[] {
-
"d2,l1,d,c,null,false,null,2024-08-01T06:15:35.000Z,test-string3,null,",
+ "d2,l1,d,c,null,false,null,null,null,null,",
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ expectedHeader = buildHeaders(1);
+ sql =
+ "select mode(stringv) from table0 where device='d2' and level='l1' and
stringv is not null";
+ retArray =
+ new String[] {
+ "test-string3,",
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ // no push-down, test GroupedAccumulator
+ expectedHeader = buildHeaders(16);
+ sql =
+ "select
mode(time),mode(device),mode(level),mode(attr1),mode(attr2),mode(num),mode(bignum),mode(floatnum),mode(date),mode(str),mode(bool),mode(date),mode(ts),mode(stringv),mode(doublenum),count(num+1)
from table0 where device='d2' and level='l4' and time=80 group by device,
level";
+ retArray =
+ new String[] {
+
"1970-01-01T00:00:00.080Z,d2,l4,null,null,9,2147483646,43.12,null,apple,false,null,2024-09-20T06:15:35.000Z,test-string2,6666.7,1,",
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ expectedHeader = buildHeaders(11);
+ sql =
+ "select
mode(device),mode(level),mode(attr1),mode(attr2),mode(date),mode(bool),mode(date),mode(ts),mode(stringv),mode(doublenum),count(num+1)
from table0 where device='d2' and level='l1' group by device, level";
+ retArray =
+ new String[] {
+ "d2,l1,d,c,null,false,null,null,null,null,3,",
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ expectedHeader = buildHeaders(2);
+ sql =
+ "select mode(stringv),count(num+1) from table0 where device='d2' and
level='l1' and stringv is not null group by device, level";
+ retArray =
+ new String[] {
+ "test-string3,1,",
};
tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
}
@@ -1269,6 +1306,43 @@ public class IoTDBMultiIDsWithAttributesTableIT {
"16.0,10.7,16.0,4.0,3.3,4.0,50.0,33.3,50.0,7.1,5.8,7.1,null,0.0,null,null,0.0,null,",
};
tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ expectedHeader = buildHeaders(19);
+ sql =
+ "select \n"
+ +
"round(variance(num),1),round(var_pop(num),1),round(var_samp(num),1),round(stddev(num),1),round(stddev_pop(num),1),round(stddev_samp(num),1),\n"
+ +
"round(variance(floatnum),1),round(var_pop(floatnum),1),round(var_samp(floatnum),1),round(stddev(floatnum),1),round(stddev_pop(floatnum),1),round(stddev_samp(floatnum),1),\n"
+ +
"round(variance(doublenum),1),round(var_pop(doublenum),1),round(var_samp(doublenum),1),round(stddev(doublenum),1),round(stddev_pop(doublenum),1),round(stddev_samp(doublenum),1),
count(num+1) from table0 where device='d2' and level='l4' group by device,
level";
+ retArray =
+ new String[] {
+
"16.0,10.7,16.0,4.0,3.3,4.0,50.0,33.3,50.0,7.1,5.8,7.1,null,0.0,null,null,0.0,null,3,",
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ sql =
+ "select \n"
+ +
"round(variance(num),1),round(var_pop(num),1),round(var_samp(num),1),round(stddev(num),1),round(stddev_pop(num),1),round(stddev_samp(num),1),\n"
+ +
"round(variance(floatnum),1),round(var_pop(floatnum),1),round(var_samp(floatnum),1),round(stddev(floatnum),1),round(stddev_pop(floatnum),1),round(stddev_samp(floatnum),1),\n"
+ +
"round(variance(doublenum),1),round(var_pop(doublenum),1),round(var_samp(doublenum),1),round(stddev(doublenum),1),round(stddev_pop(doublenum),1),round(stddev_samp(doublenum),1),
count(num+1) from table0 group by device";
+ retArray =
+ new String[] {
+
"20.0,18.7,20.0,4.5,4.3,4.5,1391642.5,1298866.4,1391642.5,1179.7,1139.7,1179.7,0.1,0.1,0.1,0.4,0.2,0.4,15,",
+
"20.0,18.7,20.0,4.5,4.3,4.5,1391642.5,1298866.4,1391642.5,1179.7,1139.7,1179.7,0.0,0.0,0.0,0.2,0.2,0.2,15,"
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+
+ expectedHeader = buildHeaders(18);
+ sql =
+ "select \n"
+ +
"round(variance(num),1),round(var_pop(num),1),round(var_samp(num),1),round(stddev(num),1),round(stddev_pop(num),1),round(stddev_samp(num),1),\n"
+ +
"round(variance(floatnum),1),round(var_pop(floatnum),1),round(var_samp(floatnum),1),round(stddev(floatnum),1),round(stddev_pop(floatnum),1),round(stddev_samp(floatnum),1),\n"
+ +
"round(variance(doublenum),1),round(var_pop(doublenum),1),round(var_samp(doublenum),1),round(stddev(doublenum),1),round(stddev_pop(doublenum),1),round(stddev_samp(doublenum),1)
from table0 group by device";
+ retArray =
+ new String[] {
+
"20.0,18.7,20.0,4.5,4.3,4.5,1391642.5,1298866.4,1391642.5,1179.7,1139.7,1179.7,0.1,0.1,0.1,0.4,0.2,0.4,",
+
"20.0,18.7,20.0,4.5,4.3,4.5,1391642.5,1298866.4,1391642.5,1179.7,1139.7,1179.7,0.0,0.0,0.0,0.2,0.2,0.2,"
+ };
+ tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
}
// ==================================================================
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
index d24d50fc431..db0530bf41f 100644
---
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
@@ -32,6 +32,7 @@ import org.junit.runner.RunWith;
import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
+import static
org.apache.iotdb.relational.it.db.it.IoTDBMultiIDsWithAttributesTableIT.buildHeaders;
@RunWith(IoTDBTestRunner.class)
@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
@@ -3608,4 +3609,20 @@ public class IoTDBTableAggregationIT {
retArray,
DATABASE_NAME);
}
+
+ @Test
+ public void modeTest() {
+ // AggTableScan + Agg mixed test
+ String[] expectedHeader = buildHeaders(11);
+ String[] retArray =
+ new String[] {
+
"A,null,null,null,null,null,null,null,null,2024-09-24T06:15:40.000Z,null,",
+
"A,null,null,null,null,null,null,null,null,2024-09-24T06:15:40.000Z,null,",
+ };
+ tableResultSetEqualTest(
+ "select mode(type),
mode(s1),mode(s2),mode(s3),mode(s4),mode(s5),mode(s6),mode(s7),mode(s8),mode(s9),mode(s10)
from table1 group by city",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index 5003883eb45..f47c8121e76 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -33,7 +33,9 @@ import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggr
import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMaxByAccumulator;
import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinAccumulator;
import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator;
import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
@@ -135,6 +137,22 @@ public class AccumulatorFactory {
return new GroupedMaxByAccumulator(inputDataTypes.get(0),
inputDataTypes.get(1));
case MIN_BY:
return new GroupedMinByAccumulator(inputDataTypes.get(0),
inputDataTypes.get(1));
+ case MODE:
+ return new GroupedModeAccumulator(inputDataTypes.get(0));
+ case STDDEV:
+ case STDDEV_SAMP:
+ return new GroupedVarianceAccumulator(
+ inputDataTypes.get(0),
VarianceAccumulator.VarianceType.STDDEV_SAMP);
+ case STDDEV_POP:
+ return new GroupedVarianceAccumulator(
+ inputDataTypes.get(0),
VarianceAccumulator.VarianceType.STDDEV_POP);
+ case VARIANCE:
+ case VAR_SAMP:
+ return new GroupedVarianceAccumulator(
+ inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_SAMP);
+ case VAR_POP:
+ return new GroupedVarianceAccumulator(
+ inputDataTypes.get(0), VarianceAccumulator.VarianceType.VAR_POP);
default:
throw new IllegalArgumentException("Invalid Aggregation function: " +
aggregationType);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableModeAccumulator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableModeAccumulator.java
index 3f90db0a1bd..1ad01095fac 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableModeAccumulator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableModeAccumulator.java
@@ -31,10 +31,11 @@ import org.apache.tsfile.utils.RamUsageEstimator;
import java.util.HashMap;
import java.util.Map;
-import java.util.Optional;
import static
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Utils.UNSUPPORTED_TYPE_MESSAGE;
import static
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Utils.serializeBinaryValue;
+import static org.apache.tsfile.utils.BytesUtils.bytesToBool;
+import static org.apache.tsfile.utils.BytesUtils.bytesToLongFromOffset;
public class TableModeAccumulator implements TableAccumulator {
@@ -51,6 +52,8 @@ public class TableModeAccumulator implements TableAccumulator
{
private Map<Double, Long> doubleCountMap;
private Map<Binary, Long> binaryCountMap;
+ private long nullCount;
+
public TableModeAccumulator(TSDataType seriesDataType) {
this.seriesDataType = seriesDataType;
switch (seriesDataType) {
@@ -147,11 +150,14 @@ public class TableModeAccumulator implements
TableAccumulator {
if (booleanCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Boolean> maxKey =
- booleanCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeBoolean);
+ // must be present
+ Map.Entry<Boolean, Long> maxEntry =
+
booleanCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeBoolean(maxEntry.getKey());
+ }
}
break;
case INT32:
@@ -159,22 +165,26 @@ public class TableModeAccumulator implements
TableAccumulator {
if (intCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Integer> maxKey =
- intCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeInt);
+ Map.Entry<Integer, Long> maxEntry =
+
intCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeInt(maxEntry.getKey());
+ }
}
break;
case FLOAT:
if (floatCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Float> maxKey =
- floatCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeFloat);
+ Map.Entry<Float, Long> maxEntry =
+
floatCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeFloat(maxEntry.getKey());
+ }
}
break;
case INT64:
@@ -182,22 +192,26 @@ public class TableModeAccumulator implements
TableAccumulator {
if (longCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Long> maxKey =
- longCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeLong);
+ Map.Entry<Long, Long> maxEntry =
+
longCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeLong(maxEntry.getKey());
+ }
}
break;
case DOUBLE:
if (doubleCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Double> maxKey =
- doubleCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeDouble);
+ Map.Entry<Double, Long> maxEntry =
+
doubleCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeDouble(maxEntry.getKey());
+ }
}
break;
case TEXT:
@@ -206,11 +220,13 @@ public class TableModeAccumulator implements
TableAccumulator {
if (binaryCountMap.isEmpty()) {
columnBuilder.appendNull();
} else {
- Optional<Binary> maxKey =
- binaryCountMap.entrySet().stream()
- .max(Map.Entry.comparingByValue())
- .map(Map.Entry::getKey);
- maxKey.ifPresent(columnBuilder::writeBinary);
+ Map.Entry<Binary, Long> maxEntry =
+
binaryCountMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCount) {
+ columnBuilder.appendNull();
+ } else {
+ columnBuilder.writeBinary(maxEntry.getKey());
+ }
}
break;
default:
@@ -249,15 +265,22 @@ public class TableModeAccumulator implements
TableAccumulator {
if (binaryCountMap != null) {
binaryCountMap.clear();
}
+ nullCount = 0;
}
+ // haveNull | nullCount (optional) | countMap
private byte[] serializeCountMap() {
byte[] bytes;
- int offset = 0;
+ int offset = 1 + (nullCount == 0 ? 0 : Long.BYTES);
+ ;
switch (seriesDataType) {
case BOOLEAN:
- bytes = new byte[4 + (1 + 8) * booleanCountMap.size()];
+ bytes = new byte[offset + 4 + (1 + 8) * booleanCountMap.size()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(booleanCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Boolean, Long> entry : booleanCountMap.entrySet()) {
@@ -269,7 +292,11 @@ public class TableModeAccumulator implements
TableAccumulator {
break;
case INT32:
case DATE:
- bytes = new byte[4 + (4 + 8) * intCountMap.size()];
+ bytes = new byte[offset + 4 + (4 + 8) * intCountMap.size()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(intCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Integer, Long> entry : intCountMap.entrySet()) {
@@ -280,7 +307,11 @@ public class TableModeAccumulator implements
TableAccumulator {
}
break;
case FLOAT:
- bytes = new byte[4 + (4 + 8) * floatCountMap.size()];
+ bytes = new byte[offset + 4 + (4 + 8) * floatCountMap.size()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(floatCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Float, Long> entry : floatCountMap.entrySet()) {
@@ -292,7 +323,11 @@ public class TableModeAccumulator implements
TableAccumulator {
break;
case INT64:
case TIMESTAMP:
- bytes = new byte[4 + (8 + 8) * longCountMap.size()];
+ bytes = new byte[offset + 4 + (8 + 8) * longCountMap.size()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(longCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Long, Long> entry : longCountMap.entrySet()) {
@@ -303,7 +338,11 @@ public class TableModeAccumulator implements
TableAccumulator {
}
break;
case DOUBLE:
- bytes = new byte[4 + (8 + 8) * doubleCountMap.size()];
+ bytes = new byte[offset + 4 + (8 + 8) * doubleCountMap.size()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(doubleCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Double, Long> entry : doubleCountMap.entrySet()) {
@@ -318,11 +357,16 @@ public class TableModeAccumulator implements
TableAccumulator {
case BLOB:
bytes =
new byte
- [4
+ [offset
+ + 4
+ (8 + 4) * binaryCountMap.size()
+ binaryCountMap.keySet().stream()
.mapToInt(key -> key.getValues().length)
.sum()];
+ BytesUtils.boolToBytes(nullCount != 0, bytes, 0);
+ if (nullCount != 0) {
+ BytesUtils.longToBytes(nullCount, bytes, 1);
+ }
BytesUtils.intToBytes(binaryCountMap.size(), bytes, offset);
offset += 4;
for (Map.Entry<Binary, Long> entry : binaryCountMap.entrySet()) {
@@ -343,6 +387,11 @@ public class TableModeAccumulator implements
TableAccumulator {
private void deserializeAndMergeCountMap(byte[] bytes) {
int offset = 0;
+ if (bytesToBool(bytes, 0)) {
+ nullCount += bytesToLongFromOffset(bytes, Long.BYTES, 1);
+ offset += Long.BYTES;
+ }
+ offset++;
int size = BytesUtils.bytesToInt(bytes, offset);
offset += 4;
@@ -378,7 +427,7 @@ public class TableModeAccumulator implements
TableAccumulator {
case INT64:
case TIMESTAMP:
for (int i = 0; i < size; i++) {
- long key = BytesUtils.bytesToLong(bytes, offset);
+ long key = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES,
offset);
offset += 8;
long count = BytesUtils.bytesToLongFromOffset(bytes, 8, offset);
offset += 8;
@@ -420,6 +469,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (booleanCountMap.size() > MAP_SIZE_THRESHOLD) {
checkMapSize(booleanCountMap.size());
}
+ } else {
+ nullCount++;
}
}
}
@@ -429,6 +480,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (!column.isNull(i)) {
intCountMap.compute(column.getInt(i), (k, v) -> v == null ? 1 : v + 1);
checkMapSize(intCountMap.size());
+ } else {
+ nullCount++;
}
}
}
@@ -438,6 +491,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (!column.isNull(i)) {
floatCountMap.compute(column.getFloat(i), (k, v) -> v == null ? 1 : v
+ 1);
checkMapSize(floatCountMap.size());
+ } else {
+ nullCount++;
}
}
}
@@ -447,6 +502,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (!column.isNull(i)) {
longCountMap.compute(column.getLong(i), (k, v) -> v == null ? 1 : v +
1);
checkMapSize(longCountMap.size());
+ } else {
+ nullCount++;
}
}
}
@@ -456,6 +513,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (!column.isNull(i)) {
doubleCountMap.compute(column.getDouble(i), (k, v) -> v == null ? 1 :
v + 1);
checkMapSize(doubleCountMap.size());
+ } else {
+ nullCount++;
}
}
}
@@ -465,6 +524,8 @@ public class TableModeAccumulator implements
TableAccumulator {
if (!column.isNull(i)) {
binaryCountMap.compute(column.getBinary(i), (k, v) -> v == null ? 1 :
v + 1);
checkMapSize(binaryCountMap.size());
+ } else {
+ nullCount++;
}
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Utils.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Utils.java
index 99252b9606d..37f2f5d912d 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Utils.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/Utils.java
@@ -73,7 +73,7 @@ public class Utils {
public static void serializeBinaryValue(Binary binary, byte[] valueBytes,
int offset) {
BytesUtils.intToBytes(binary.getValues().length, valueBytes, offset);
- offset += 4;
+ offset += Integer.BYTES;
System.arraycopy(binary.getValues(), 0, valueBytes, offset,
binary.getValues().length);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedModeAccumulator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedModeAccumulator.java
new file mode 100644
index 00000000000..f3a5a316760
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedModeAccumulator.java
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.MapBigArray;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.BytesUtils;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.utils.TsPrimitiveType;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Utils.UNSUPPORTED_TYPE_MESSAGE;
+import static
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.Utils.serializeBinaryValue;
+import static org.apache.tsfile.utils.BytesUtils.bytesToBool;
+import static org.apache.tsfile.utils.BytesUtils.bytesToLongFromOffset;
+import static org.apache.tsfile.utils.TsPrimitiveType.getByType;
+
+public class GroupedModeAccumulator implements GroupedAccumulator {
+
+ private final int MAP_SIZE_THRESHOLD =
+ IoTDBDescriptor.getInstance().getConfig().getModeMapSizeThreshold();
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(GroupedModeAccumulator.class);
+ private final TSDataType seriesDataType;
+
+ private final MapBigArray countMaps = new MapBigArray();
+
+ private final LongBigArray nullCounts = new LongBigArray();
+
+ public GroupedModeAccumulator(TSDataType seriesDataType) {
+ this.seriesDataType = seriesDataType;
+ }
+
+ @Override
+ public long getEstimatedSize() {
+ return INSTANCE_SIZE + countMaps.sizeOf() + nullCounts.sizeOf();
+ }
+
+ @Override
+ public void setGroupCount(long groupCount) {
+ countMaps.ensureCapacity(groupCount);
+ nullCounts.ensureCapacity(groupCount);
+ }
+
+ @Override
+ public void addInput(int[] groupIds, Column[] arguments) {
+ switch (seriesDataType) {
+ case BOOLEAN:
+ addBooleanInput(groupIds, arguments[0]);
+ break;
+ case INT32:
+ case DATE:
+ addIntInput(groupIds, arguments[0]);
+ break;
+ case FLOAT:
+ addFloatInput(groupIds, arguments[0]);
+ break;
+ case INT64:
+ case TIMESTAMP:
+ addLongInput(groupIds, arguments[0]);
+ break;
+ case DOUBLE:
+ addDoubleInput(groupIds, arguments[0]);
+ break;
+ case TEXT:
+ case STRING:
+ case BLOB:
+ addBinaryInput(groupIds, arguments[0]);
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ String.format(UNSUPPORTED_TYPE_MESSAGE, seriesDataType));
+ }
+ }
+
+ @Override
+ public void addIntermediate(int[] groupIds, Column argument) {
+ for (int i = 0; i < argument.getPositionCount(); i++) {
+ if (argument.isNull(i)) {
+ continue;
+ }
+
+ byte[] bytes = argument.getBinary(i).getValues();
+ deserializeAndMergeCountMap(groupIds[i], bytes);
+ }
+ }
+
+ @Override
+ public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
+ columnBuilder.writeBinary(new Binary(serializeCountMap(groupId)));
+ }
+
+ @Override
+ public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupId);
+ if (countMap.isEmpty()) {
+ columnBuilder.appendNull();
+ return;
+ }
+ // must be present
+ Map.Entry<TsPrimitiveType, Long> maxEntry =
+ countMap.entrySet().stream().max(Map.Entry.comparingByValue()).get();
+ if (maxEntry.getValue() < nullCounts.get(groupId)) {
+ columnBuilder.appendNull();
+ return;
+ }
+
+ switch (seriesDataType) {
+ case BOOLEAN:
+ columnBuilder.writeBoolean(maxEntry.getKey().getBoolean());
+ break;
+ case INT32:
+ case DATE:
+ columnBuilder.writeInt(maxEntry.getKey().getInt());
+ break;
+ case FLOAT:
+ columnBuilder.writeFloat(maxEntry.getKey().getFloat());
+ break;
+ case INT64:
+ case TIMESTAMP:
+ columnBuilder.writeLong(maxEntry.getKey().getLong());
+ break;
+ case DOUBLE:
+ columnBuilder.writeDouble(maxEntry.getKey().getDouble());
+ break;
+ case TEXT:
+ case STRING:
+ case BLOB:
+ columnBuilder.writeBinary(maxEntry.getKey().getBinary());
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ String.format(UNSUPPORTED_TYPE_MESSAGE, seriesDataType));
+ }
+ }
+
+ @Override
+ public void prepareFinal() {}
+
+ // haveNull | nullCount (optional) | countMap
+ private byte[] serializeCountMap(int groupId) {
+ byte[] bytes;
+ int offset = 1 + (nullCounts.get(groupId) == 0 ? 0 : Long.BYTES);
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupId);
+
+ switch (seriesDataType) {
+ case BOOLEAN:
+ bytes = new byte[offset + Integer.BYTES + (1 + Long.BYTES) *
countMap.size()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += 4;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ BytesUtils.boolToBytes(entry.getKey().getBoolean(), bytes, offset);
+ offset += 1;
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ case INT32:
+ case DATE:
+ bytes = new byte[offset + Integer.BYTES + (Integer.BYTES + Long.BYTES)
* countMap.size()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += Integer.BYTES;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ BytesUtils.intToBytes(entry.getKey().getInt(), bytes, offset);
+ offset += Integer.BYTES;
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ case FLOAT:
+ bytes = new byte[offset + Integer.BYTES + (Float.BYTES + Long.BYTES) *
countMap.size()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += Integer.BYTES;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ BytesUtils.floatToBytes(entry.getKey().getFloat(), bytes, offset);
+ offset += Float.BYTES;
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ case INT64:
+ case TIMESTAMP:
+ bytes = new byte[offset + Integer.BYTES + (Long.BYTES + Long.BYTES) *
countMap.size()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += Integer.BYTES;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ BytesUtils.longToBytes(entry.getKey().getLong(), bytes, offset);
+ offset += Long.BYTES;
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ case DOUBLE:
+ bytes = new byte[offset + Integer.BYTES + (Double.BYTES + Long.BYTES)
* countMap.size()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += Integer.BYTES;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ BytesUtils.doubleToBytes(entry.getKey().getDouble(), bytes, offset);
+ offset += Double.BYTES;
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ case TEXT:
+ case STRING:
+ case BLOB:
+ bytes =
+ new byte
+ [offset
+ + Integer.BYTES
+ + (Integer.BYTES + Long.BYTES) * countMap.size()
+ + countMap.keySet().stream()
+ .mapToInt(key -> key.getBinary().getValues().length)
+ .sum()];
+ BytesUtils.boolToBytes(nullCounts.get(groupId) != 0, bytes, 0);
+ if (nullCounts.get(groupId) != 0) {
+ BytesUtils.longToBytes(nullCounts.get(groupId), bytes, 1);
+ }
+ BytesUtils.intToBytes(countMap.size(), bytes, offset);
+ offset += Integer.BYTES;
+ for (Map.Entry<TsPrimitiveType, Long> entry : countMap.entrySet()) {
+ Binary binary = entry.getKey().getBinary();
+ serializeBinaryValue(binary, bytes, offset);
+ offset += (Integer.BYTES + binary.getLength());
+ BytesUtils.longToBytes(entry.getValue(), bytes, offset);
+ offset += Long.BYTES;
+ }
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ String.format(UNSUPPORTED_TYPE_MESSAGE, seriesDataType));
+ }
+
+ return bytes;
+ }
+
  /**
   * Deserializes one partial-aggregation blob produced by {@code serializeCountMap} and merges it
   * into the count map of {@code groupId}.
   *
   * <p>Blob layout: [hasNullValues: 1 byte][nullCount: 8 bytes, present only when hasNullValues]
   * [entry count: 4 bytes][entries: type-dependent value bytes followed by an 8-byte occurrence
   * count each].
   *
   * <p>NOTE(review): unlike the add*Input paths, this merge does not call checkMapSize, so a
   * merged map may exceed MAP_SIZE_THRESHOLD — confirm whether that is intentional.
   */
  private void deserializeAndMergeCountMap(int groupId, byte[] bytes) {
    int offset = 0;
    // A leading true flag means an 8-byte null counter follows at index 1.
    if (bytesToBool(bytes, 0)) {
      nullCounts.add(groupId, bytesToLongFromOffset(bytes, Long.BYTES, 1));
      offset += Long.BYTES;
    }
    offset++; // skip the hasNullValues flag byte itself
    int size = BytesUtils.bytesToInt(bytes, offset);
    offset += Integer.BYTES;

    HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupId);

    switch (seriesDataType) {
      case BOOLEAN:
        for (int i = 0; i < size; i++) {
          TsPrimitiveType key = new TsPrimitiveType.TsBoolean(bytesToBool(bytes, offset));
          offset += 1;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          // Merge: add occurrence counts coming from different partial results.
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      case INT32:
      case DATE:
        for (int i = 0; i < size; i++) {
          TsPrimitiveType key = new TsPrimitiveType.TsInt(BytesUtils.bytesToInt(bytes, offset));
          offset += Integer.BYTES;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      case FLOAT:
        for (int i = 0; i < size; i++) {
          TsPrimitiveType key = new TsPrimitiveType.TsFloat(BytesUtils.bytesToFloat(bytes, offset));
          offset += Float.BYTES;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      case INT64:
      case TIMESTAMP:
        for (int i = 0; i < size; i++) {
          TsPrimitiveType key =
              new TsPrimitiveType.TsLong(
                  BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset));
          offset += Long.BYTES;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      case DOUBLE:
        for (int i = 0; i < size; i++) {
          TsPrimitiveType key =
              new TsPrimitiveType.TsDouble(BytesUtils.bytesToDouble(bytes, offset));
          offset += Double.BYTES;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      case TEXT:
      case STRING:
      case BLOB:
        // Binary entries are length-prefixed: [length: 4 bytes][raw value bytes].
        for (int i = 0; i < size; i++) {
          int length = BytesUtils.bytesToInt(bytes, offset);
          offset += Integer.BYTES;
          TsPrimitiveType key =
              new TsPrimitiveType.TsBinary(new Binary(BytesUtils.subBytes(bytes, offset, length)));
          offset += length;
          long count = BytesUtils.bytesToLongFromOffset(bytes, Long.BYTES, offset);
          offset += Long.BYTES;
          countMap.compute(key, (k, v) -> v == null ? count : v + count);
        }
        break;
      default:
        throw new UnsupportedOperationException(
            String.format(UNSUPPORTED_TYPE_MESSAGE, seriesDataType));
    }
  }
+
+ private void addBooleanInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getBoolean(i)), (k, v) -> v ==
null ? 1 : v + 1);
+ checkMapSize(countMap.size());
+
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void addIntInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getInt(i)), (k, v) -> v == null ?
1 : v + 1);
+ checkMapSize(countMap.size());
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void addFloatInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getFloat(i)), (k, v) -> v == null
? 1 : v + 1);
+ checkMapSize(countMap.size());
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void addLongInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getLong(i)), (k, v) -> v == null
? 1 : v + 1);
+ checkMapSize(countMap.size());
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void addDoubleInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getDouble(i)), (k, v) -> v ==
null ? 1 : v + 1);
+ checkMapSize(countMap.size());
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void addBinaryInput(int[] groupIds, Column column) {
+ for (int i = 0; i < column.getPositionCount(); i++) {
+ if (!column.isNull(i)) {
+ HashMap<TsPrimitiveType, Long> countMap = countMaps.get(groupIds[i]);
+ countMap.compute(
+ getByType(seriesDataType, column.getBinary(i)), (k, v) -> v ==
null ? 1 : v + 1);
+ checkMapSize(countMap.size());
+ } else {
+ nullCounts.increment(groupIds[i]);
+ }
+ }
+ }
+
+ private void checkMapSize(int size) {
+ if (size > MAP_SIZE_THRESHOLD) {
+ throw new RuntimeException(
+ String.format(
+ "distinct values has exceeded the threshold %s when calculate
Mode in one group",
+ MAP_SIZE_THRESHOLD));
+ }
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedVarianceAccumulator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedVarianceAccumulator.java
new file mode 100644
index 00000000000..c80406e93c4
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedVarianceAccumulator.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
+
+import
org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.BytesUtils;
+import org.apache.tsfile.utils.RamUsageEstimator;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+
/**
 * Grouped accumulator backing the variance family of aggregations (stddev / variance, population
 * and sample variants). Per-group state is the running (count, mean, M2) triple maintained with
 * Welford's online update; serialized partial states are combined with Chan et al.'s parallel
 * merge formula in {@link #addIntermediate}.
 *
 * <p>Only numeric input types (INT32 / INT64 / FLOAT / DOUBLE) are supported.
 */
public class GroupedVarianceAccumulator implements GroupedAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(GroupedVarianceAccumulator.class);
  private final TSDataType seriesDataType;
  // Selects which statistic evaluateFinal emits (stddev/variance, pop/samp).
  private final VarianceAccumulator.VarianceType varianceType;

  // Per-group running state, indexed by group id.
  private final LongBigArray counts = new LongBigArray(); // number of non-null values seen
  private final DoubleBigArray means = new DoubleBigArray(); // running mean
  private final DoubleBigArray m2s = new DoubleBigArray(); // sum of squared deviations from mean

  public GroupedVarianceAccumulator(
      TSDataType seriesDataType, VarianceAccumulator.VarianceType varianceType) {
    this.seriesDataType = seriesDataType;
    this.varianceType = varianceType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE + counts.sizeOf() + means.sizeOf() + m2s.sizeOf();
  }

  @Override
  public void setGroupCount(long groupCount) {
    // Grow all three parallel arrays together so every group id is addressable.
    counts.ensureCapacity(groupCount);
    means.ensureCapacity(groupCount);
    m2s.ensureCapacity(groupCount);
  }

  /**
   * Routes raw input rows to the type-specific accumulation loop; only numeric types are legal.
   */
  @Override
  public void addInput(int[] groupIds, Column[] arguments) {
    switch (seriesDataType) {
      case INT32:
        addIntInput(groupIds, arguments[0]);
        return;
      case INT64:
        addLongInput(groupIds, arguments[0]);
        return;
      case FLOAT:
        addFloatInput(groupIds, arguments[0]);
        return;
      case DOUBLE:
        addDoubleInput(groupIds, arguments[0]);
        return;
      case TEXT:
      case BLOB:
      case BOOLEAN:
      case DATE:
      case STRING:
      case TIMESTAMP:
      default:
        throw new UnSupportedDataTypeException(
            String.format("Unsupported data type in aggregation variance : %s", seriesDataType));
    }
  }

  /**
   * Merges serialized partial states (one binary blob per row, written by
   * {@link #evaluateIntermediate}) into the per-group running state using Chan's parallel formula.
   */
  @Override
  public void addIntermediate(int[] groupIds, Column argument) {
    for (int i = 0; i < argument.getPositionCount(); i++) {
      if (argument.isNull(i)) {
        continue;
      }

      byte[] bytes = argument.getBinary(i).getValues();
      // Blob layout: [count: 8B][mean: 8B][m2: 8B].
      // NOTE(review): bytesToLong(bytes, n) takes a byte LENGTH and reads from index 0, while
      // bytesToDouble(bytes, pos) takes an OFFSET — the two Long.BYTES arguments below mean
      // different things. This matches the writer above, but confirm against BytesUtils' API.
      long intermediateCount = BytesUtils.bytesToLong(bytes, Long.BYTES);
      double intermediateMean = BytesUtils.bytesToDouble(bytes, Long.BYTES);
      double intermediateM2 = BytesUtils.bytesToDouble(bytes, (Long.BYTES + Double.BYTES));

      long newCount = counts.get(groupIds[i]) + intermediateCount;
      // Count-weighted average of the two partial means.
      double newMean =
          ((intermediateCount * intermediateMean)
                  + (counts.get(groupIds[i]) * means.get(groupIds[i])))
              / newCount;
      double delta = intermediateMean - means.get(groupIds[i]);

      // Chan et al.: M2 = M2_a + M2_b + delta^2 * n_a * n_b / n.
      // m2s must be updated before counts/means because it reads the OLD count.
      m2s.add(
          groupIds[i],
          intermediateM2 + delta * delta * intermediateCount * counts.get(groupIds[i]) / newCount);
      counts.set(groupIds[i], newCount);
      means.set(groupIds[i], newMean);
    }
  }

  /**
   * Serializes the group's (count, mean, M2) into a 24-byte blob; emits null for an empty group.
   */
  @Override
  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
    if (counts.get(groupId) == 0) {
      columnBuilder.appendNull();
    } else {
      byte[] bytes = new byte[24]; // Long.BYTES + 2 * Double.BYTES
      BytesUtils.longToBytes(counts.get(groupId), bytes, 0);
      BytesUtils.doubleToBytes(means.get(groupId), bytes, Long.BYTES);
      BytesUtils.doubleToBytes(m2s.get(groupId), bytes, Long.BYTES + Double.BYTES);
      columnBuilder.writeBinary(new Binary(bytes));
    }
  }

  /**
   * Emits the final statistic for the group. Population variants divide M2 by n (null only when
   * the group is empty); sample variants divide by n-1 and need at least two values.
   */
  @Override
  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
    switch (varianceType) {
      case STDDEV_POP:
        if (counts.get(groupId) == 0) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(Math.sqrt(m2s.get(groupId) / counts.get(groupId)));
        }
        break;
      case STDDEV_SAMP:
        if (counts.get(groupId) < 2) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(Math.sqrt(m2s.get(groupId) / (counts.get(groupId) - 1)));
        }
        break;
      case VAR_POP:
        if (counts.get(groupId) == 0) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(m2s.get(groupId) / counts.get(groupId));
        }
        break;
      case VAR_SAMP:
        if (counts.get(groupId) < 2) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(m2s.get(groupId) / (counts.get(groupId) - 1));
        }
        break;
      default:
        throw new EnumConstantNotPresentException(
            VarianceAccumulator.VarianceType.class, varianceType.name());
    }
  }

  // No-op: the running state is already in final-ready form.
  @Override
  public void prepareFinal() {}

  // Welford's single-value update. Note the second delta term intentionally uses the UPDATED mean.
  private void addIntInput(int[] groupIds, Column column) {
    for (int i = 0; i < column.getPositionCount(); i++) {
      if (column.isNull(i)) {
        continue;
      }

      int value = column.getInt(i);
      counts.increment(groupIds[i]);
      double delta = value - means.get(groupIds[i]);
      means.add(groupIds[i], delta / counts.get(groupIds[i]));
      m2s.add(groupIds[i], delta * (value - means.get(groupIds[i])));
    }
  }

  private void addLongInput(int[] groupIds, Column column) {
    for (int i = 0; i < column.getPositionCount(); i++) {
      if (column.isNull(i)) {
        continue;
      }

      long value = column.getLong(i);
      counts.increment(groupIds[i]);
      double delta = value - means.get(groupIds[i]);
      means.add(groupIds[i], delta / counts.get(groupIds[i]));
      m2s.add(groupIds[i], delta * (value - means.get(groupIds[i])));
    }
  }

  private void addFloatInput(int[] groupIds, Column column) {
    for (int i = 0; i < column.getPositionCount(); i++) {
      if (column.isNull(i)) {
        continue;
      }

      float value = column.getFloat(i);
      counts.increment(groupIds[i]);
      double delta = value - means.get(groupIds[i]);
      means.add(groupIds[i], delta / counts.get(groupIds[i]));
      m2s.add(groupIds[i], delta * (value - means.get(groupIds[i])));
    }
  }

  private void addDoubleInput(int[] groupIds, Column column) {
    for (int i = 0; i < column.getPositionCount(); i++) {
      if (column.isNull(i)) {
        continue;
      }

      double value = column.getDouble(i);
      counts.increment(groupIds[i]);
      double delta = value - means.get(groupIds[i]);
      means.add(groupIds[i], delta / counts.get(groupIds[i]));
      m2s.add(groupIds[i], delta * (value - means.get(groupIds[i])));
    }
  }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/MapBigArray.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/MapBigArray.java
new file mode 100644
index 00000000000..ca674dbf2e6
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/array/MapBigArray.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array;
+
+import org.apache.tsfile.utils.TsPrimitiveType;
+
+import java.util.HashMap;
+
+import static org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance;
+import static org.apache.tsfile.utils.RamUsageEstimator.sizeOfObject;
+
/**
 * A big array of per-group {@code HashMap<TsPrimitiveType, Long>} count maps (used by grouped MODE
 * aggregation). Wraps an {@link ObjectBigArray} and keeps a running estimate of the retained size
 * of maps installed through {@link #set}.
 */
public final class MapBigArray {
  private static final long INSTANCE_SIZE = shallowSizeOfInstance(MapBigArray.class);
  private final ObjectBigArray<HashMap<TsPrimitiveType, Long>> array;
  // Estimated retained bytes of maps installed via set(); see NOTE in get().
  private long sizeOfMaps;

  public MapBigArray() {
    array = new ObjectBigArray<>();
  }

  // NOTE(review): this passes a single map instance to ObjectBigArray's fill constructor; if
  // ObjectBigArray reuses the fill value for every slot, mutating one slot's map would be
  // visible at all indexes. Confirm ObjectBigArray's semantics before using this constructor.
  public MapBigArray(HashMap<TsPrimitiveType, Long> slice) {
    array = new ObjectBigArray<>(slice);
  }

  /** Returns the size of this big array in bytes. */
  public long sizeOf() {
    return INSTANCE_SIZE + array.sizeOf() + sizeOfMaps;
  }

  /**
   * Returns the element of this big array at specified index.
   *
   * <p>Lazily installs an empty map on first access so callers can populate it in place.
   * NOTE(review): maps created here (and entries later added to them) are not reflected in
   * sizeOfMaps, so sizeOf() may under-report — confirm whether that is acceptable.
   *
   * @param index a position in this big array.
   * @return the element of this big array at the specified position.
   */
  public HashMap<TsPrimitiveType, Long> get(long index) {
    HashMap<TsPrimitiveType, Long> result = array.get(index);
    if (result == null) {
      result = new HashMap<>();
      array.set(index, result);
    }
    return result;
  }

  /**
   * Sets the element of this big array at specified index.
   *
   * @param index a position in this big array.
   */
  public void set(long index, HashMap<TsPrimitiveType, Long> value) {
    // Keep the retained-size estimate in sync: drop the old value's size, add the new one's.
    updateRetainedSize(index, value);
    array.set(index, value);
  }

  /**
   * Ensures this big array is at least the specified length. If the array is smaller, segments are
   * added until the array is larger than the specified length.
   */
  public void ensureCapacity(long length) {
    array.ensureCapacity(length);
  }

  // Adjusts sizeOfMaps for replacing the map at index with value (either may be null).
  private void updateRetainedSize(long index, HashMap<TsPrimitiveType, Long> value) {
    HashMap<TsPrimitiveType, Long> currentValue = array.get(index);
    if (currentValue != null) {
      sizeOfMaps -= sizeOfObject(currentValue);
    }
    if (value != null) {
      sizeOfMaps += sizeOfObject(value);
    }
  }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 1d02608b317..56d6868a24b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -129,6 +129,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.FunctionCall;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral;
+import
org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
import org.apache.iotdb.db.queryengine.plan.statement.component.Ordering;
import
org.apache.iotdb.db.queryengine.transformation.dag.column.ColumnTransformer;
import
org.apache.iotdb.db.queryengine.transformation.dag.column.leaf.LeafColumnTransformer;
@@ -145,7 +146,6 @@ import org.apache.tsfile.read.common.TimeRange;
import org.apache.tsfile.read.common.type.BinaryType;
import org.apache.tsfile.read.common.type.BlobType;
import org.apache.tsfile.read.common.type.BooleanType;
-import org.apache.tsfile.read.common.type.RowType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.filter.basic.Filter;
import org.apache.tsfile.utils.Binary;
@@ -1368,26 +1368,21 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
TypeProvider typeProvider,
boolean scanAscending) {
List<Integer> argumentChannels = new ArrayList<>();
- List<TSDataType> argumentTypes = new ArrayList<>();
for (Expression argument : aggregation.getArguments()) {
Symbol argumentSymbol = Symbol.from(argument);
argumentChannels.add(childLayout.get(argumentSymbol));
-
- // get argument types
- Type type = typeProvider.getTableModelType(argumentSymbol);
- if (type instanceof RowType) {
- type.getTypeParameters().forEach(subType ->
argumentTypes.add(getTSDataType(subType)));
- } else {
- argumentTypes.add(getTSDataType(type));
- }
}
String functionName =
aggregation.getResolvedFunction().getSignature().getName();
+ List<TSDataType> originalArgumentTypes =
+
aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream()
+ .map(InternalTypeManager::getTSDataType)
+ .collect(Collectors.toList());
TableAccumulator accumulator =
createAccumulator(
functionName,
getAggregationTypeByFuncName(functionName),
- argumentTypes,
+ originalArgumentTypes,
aggregation.getArguments(),
Collections.emptyMap(),
scanAscending);
@@ -1454,26 +1449,21 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
AggregationNode.Step step,
TypeProvider typeProvider) {
List<Integer> argumentChannels = new ArrayList<>();
- List<TSDataType> argumentTypes = new ArrayList<>();
for (Expression argument : aggregation.getArguments()) {
Symbol argumentSymbol = Symbol.from(argument);
argumentChannels.add(childLayout.get(argumentSymbol));
-
- // get argument types
- Type type = typeProvider.getTableModelType(argumentSymbol);
- if (type instanceof RowType) {
- type.getTypeParameters().forEach(subType ->
argumentTypes.add(getTSDataType(subType)));
- } else {
- argumentTypes.add(getTSDataType(type));
- }
}
String functionName =
aggregation.getResolvedFunction().getSignature().getName();
+ List<TSDataType> originalArgumentTypes =
+
aggregation.getResolvedFunction().getSignature().getArgumentTypes().stream()
+ .map(InternalTypeManager::getTSDataType)
+ .collect(Collectors.toList());
GroupedAccumulator accumulator =
createGroupedAccumulator(
functionName,
getAggregationTypeByFuncName(functionName),
- argumentTypes,
+ originalArgumentTypes,
Collections.emptyList(),
Collections.emptyMap(),
true);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
index 202ba867890..b070d78b977 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableBuiltinAggregationFunction.java
@@ -21,10 +21,11 @@ package
org.apache.iotdb.db.queryengine.plan.relational.metadata;
import org.apache.iotdb.common.rpc.thrift.TAggregationType;
-import com.google.common.collect.ImmutableList;
+import org.apache.tsfile.read.common.type.RowType;
import org.apache.tsfile.read.common.type.Type;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@@ -75,18 +76,34 @@ public enum TableBuiltinAggregationFunction {
return NATIVE_FUNCTION_NAMES;
}
- public static List<Type> getIntermediateTypes(String name, List<Type>
originalArgumentTypes) {
- if (COUNT.functionName.equalsIgnoreCase(name)) {
- return ImmutableList.of(INT64);
- } else if (SUM.functionName.equalsIgnoreCase(name)) {
- return ImmutableList.of(DOUBLE);
- } else if (AVG.functionName.equalsIgnoreCase(name)) {
- return ImmutableList.of(DOUBLE, INT64);
- } else if (LAST.functionName.equalsIgnoreCase(name)) {
- return ImmutableList.of(originalArgumentTypes.get(0), INT64);
- } else {
- // TODO(beyyes) consider other aggregations which changed the result type
- return ImmutableList.copyOf(originalArgumentTypes);
+ public static Type getIntermediateType(String name, List<Type>
originalArgumentTypes) {
+ final String functionName = name.toLowerCase();
+ switch (functionName) {
+ case "count":
+ return INT64;
+ case "sum":
+ return DOUBLE;
+ case "avg":
+ case "first":
+ case "first_by":
+ case "last":
+ case "last_by":
+ case "mode":
+ case "max_by":
+ case "min_by":
+ case "stddev":
+ case "stddev_pop":
+ case "stddev_samp":
+ case "variance":
+ case "var_pop":
+ case "var_samp":
+ return RowType.anonymous(Collections.emptyList());
+ case "extreme":
+ case "max":
+ case "min":
+ return originalArgumentTypes.get(0);
+ default:
+ throw new IllegalArgumentException("Invalid Aggregation function: " +
name);
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
index cdd8f01ab12..9a2c9cb55e8 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
@@ -28,12 +28,10 @@ import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationN
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode;
import com.google.common.collect.ImmutableList;
-import org.apache.tsfile.read.common.type.RowType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.utils.Pair;
import java.util.LinkedHashMap;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -58,14 +56,10 @@ public class Util {
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
node.getAggregations().entrySet()) {
AggregationNode.Aggregation originalAggregation = entry.getValue();
ResolvedFunction resolvedFunction =
originalAggregation.getResolvedFunction();
- List<Type> intermediateTypes =
- TableBuiltinAggregationFunction.getIntermediateTypes(
+ Type intermediateType =
+ TableBuiltinAggregationFunction.getIntermediateType(
resolvedFunction.getSignature().getName(),
resolvedFunction.getSignature().getArgumentTypes());
- Type intermediateType =
- intermediateTypes.size() == 1
- ? intermediateTypes.get(0)
- : RowType.anonymous(intermediateTypes);
Symbol intermediateSymbol =
symbolAllocator.newSymbol(resolvedFunction.getSignature().getName(),
intermediateType);
// TODO put symbol and its type to TypeProvide or later process: add all
map contents of
@@ -128,14 +122,10 @@ public class Util {
for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
node.getAggregations().entrySet()) {
AggregationNode.Aggregation originalAggregation = entry.getValue();
ResolvedFunction resolvedFunction =
originalAggregation.getResolvedFunction();
- List<Type> intermediateTypes =
- TableBuiltinAggregationFunction.getIntermediateTypes(
+ Type intermediateType =
+ TableBuiltinAggregationFunction.getIntermediateType(
resolvedFunction.getSignature().getName(),
resolvedFunction.getSignature().getArgumentTypes());
- Type intermediateType =
- intermediateTypes.size() == 1
- ? intermediateTypes.get(0)
- : RowType.anonymous(intermediateTypes);
Symbol intermediateSymbol =
symbolAllocator.newSymbol(resolvedFunction.getSignature().getName(),
intermediateType);