This is an automated email from the ASF dual-hosted git repository. lancelly pushed a commit to branch max_by in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 955edb759d0f6323a55ec9f004b62967d9ac189b Author: lancelly <[email protected]> AuthorDate: Wed Jan 24 10:23:54 2024 +0800 partial IT --- .../db/it/aggregation/IoTDBAggregationIT.java | 53 +++ .../maxby/IoTDBMaxByAlignedSeriesIT.java | 62 ++++ .../db/it/aggregation/maxby/IoTDBMaxByIT.java | 362 +++++++++++++++++++++ .../execution/aggregation/Aggregator.java | 1 + .../execution/aggregation/MaxByAccumulator.java | 4 +- .../queryengine/plan/analyze/AnalyzeVisitor.java | 21 +- .../plan/analyze/ExpressionTypeAnalyzer.java | 18 +- .../db/queryengine/plan/analyze/TypeProvider.java | 4 + ...catDeviceAndBindSchemaForExpressionVisitor.java | 7 +- .../db/queryengine/plan/parser/ASTVisitor.java | 8 +- .../plan/planner/LogicalPlanBuilder.java | 23 +- .../plan/parameter/AggregationDescriptor.java | 2 +- .../CrossSeriesAggregationDescriptor.java | 2 +- .../org/apache/iotdb/db/utils/SchemaUtils.java | 2 +- .../iotdb/db/utils/constant/TestConstant.java | 4 + .../execution/aggregation/AccumulatorTest.java | 54 ++- .../thrift-commons/src/main/thrift/common.thrift | 4 +- 17 files changed, 599 insertions(+), 32 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBAggregationIT.java index 563347bee48..1dd8fca19cd 100644 --- a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBAggregationIT.java @@ -43,6 +43,7 @@ import static org.apache.iotdb.db.utils.constant.TestConstant.avg; import static org.apache.iotdb.db.utils.constant.TestConstant.count; import static org.apache.iotdb.db.utils.constant.TestConstant.firstValue; import static org.apache.iotdb.db.utils.constant.TestConstant.lastValue; +import static org.apache.iotdb.db.utils.constant.TestConstant.maxBy; import static org.apache.iotdb.db.utils.constant.TestConstant.maxTime; import static org.apache.iotdb.db.utils.constant.TestConstant.maxValue; import static org.apache.iotdb.db.utils.constant.TestConstant.minTime; @@ -981,4 +982,56 @@ public class IoTDBAggregationIT { expectedHeader, retArray); } + + @Test + public void maxByTest() { + String[] retArray = new String[] {"8499,500.0"}; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + int cnt; + try (ResultSet resultSet = + statement.executeQuery( + "SELECT max_value(time, s0) " + + "FROM root.vehicle.d0 WHERE time >= 100 AND time < 9000")) { + cnt = 0; + while (resultSet.next()) { + String ans = + resultSet.getString(TIMESTAMP_STR) + "," + resultSet.getString(maxBy("time", d0s0)); + Assert.assertEquals(retArray[cnt], ans); + cnt++; + } + Assert.assertEquals(1, cnt); + } + + try (ResultSet resultSet = + statement.executeQuery( + "SELECT max_value(time,s0) FROM root.vehicle.d0 WHERE time < 2500")) { + while (resultSet.next()) { + String ans = + resultSet.getString(TIMESTAMP_STR) + "," + resultSet.getString(maxBy("time", d0s0)); + Assert.assertEquals(retArray[cnt], ans); + cnt++; + } + Assert.assertEquals(2, cnt); + } + + // keep the correctness of `order by time desc` + cnt = 0; + try (ResultSet resultSet = + statement.executeQuery( + "SELECT max_by(time,s0) FROM root.vehicle.d0 WHERE time >= 100 AND time < 9000 order by time desc")) { + while (resultSet.next()) { + String ans = + resultSet.getString(TIMESTAMP_STR) + "," + resultSet.getString(maxBy("time", d0s0)); + Assert.assertEquals(retArray[cnt], ans); + cnt++; + } + Assert.assertEquals(1, cnt); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } } diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByAlignedSeriesIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByAlignedSeriesIT.java new file mode 100644 index 00000000000..a92eab034bb --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByAlignedSeriesIT.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.it.aggregation.maxby; + +import org.apache.iotdb.it.env.EnvFactory; +import org.junit.BeforeClass; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +public class IoTDBMaxByAlignedSeriesIT extends IoTDBMaxByIT{ + protected static final String[] ALIGNED_DATASET = + new String[] { + // x input + "CREATE ALIGNED TIMESERIES root.db.d1(x1 INT32, x2 INT64, x3 FLOAT, x4 DOUBLE, x5 BOOLEAN, x6 TEXT)", + // y input + "CREATE ALIGNED TIMESERIES root.db.d1(y1 INT32, y2 INT64, y3 FLOAT, y4 DOUBLE, y5 BOOLEAN, y6 TEXT)", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(1, 1, 1, 1, 1, true, \"1\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(2, 2, 2, 2, 2, false, \"2\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(2, 2, 2, 2, 2, true, \"4\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(4, 4, 4, 4, 4, false, \"4\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(8, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(8, 8, 8, 8, 8, false, \"4\")", + "flush", + // For Align By Device + "CREATE ALIGNED TIMESERIES root.db.d2(x1 INT32, x2 INT64, x3 FLOAT, x4 DOUBLE, x5 BOOLEAN, x6 TEXT)", + "CREATE ALIGNED TIMESERIES root.db.d2(y1 INT32, y2 INT64, y3 FLOAT, y4 DOUBLE, y5 BOOLEAN, y6 TEXT)", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(1, 1, 1, 1, 1, true, \"1\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(2, 2, 2, 2, 2, false, \"2\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(2, 2, 2, 2, 2, true, \"4\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(4, 4, 4, 4, 4, false, \"4\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(8, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(8, 8, 8, 8, 8, false, \"4\")", + }; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(ALIGNED_DATASET); + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByIT.java new file mode 100644 index 00000000000..891d5cd96e6 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/maxby/IoTDBMaxByIT.java @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.it.aggregation.maxby; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest; +import static org.apache.iotdb.db.utils.constant.TestConstant.TIMESTAMP_STR; +import static org.apache.iotdb.db.utils.constant.TestConstant.maxBy; +import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({LocalStandaloneIT.class, ClusterIT.class}) +public class IoTDBMaxByIT { + protected static final String[] NON_ALIGNED_DATASET = + new String[] { + "CREATE DATABASE root.db", + // x input + "CREATE TIMESERIES root.db.d1.x1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.x2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.x3 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.x4 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.x5 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.x6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + // y input + "CREATE TIMESERIES root.db.d1.y1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.y2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.y3 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.y4 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.y5 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.y6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(1, 1, 1, 1, 1, true, \"1\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(2, 2, 2, 2, 2, false, \"2\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(2, 2, 2, 2, 2, true, \"4\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(4, 4, 4, 4, 4, false, \"4\")", + "INSERT INTO root.db.d1(timestamp,x1,x2,x3,x4,x5,x6) values(8, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d1(timestamp,y1,y2,y3,y4,y5,y6) values(8, 8, 8, 8, 8, false, \"4\")", + "flush", + + // For Align By Device + "CREATE TIMESERIES root.db.d2.x1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.x2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.x3 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.x4 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.x5 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.x6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + // y input + "CREATE TIMESERIES root.db.d2.y1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.y2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.y3 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.y4 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.y5 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.y6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(1, 1, 1, 1, 1, true, \"1\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(2, 2, 2, 2, 2, false, \"2\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d2(timestamp,x1,x2,x3,x4,x5,x6) values(1, 4, 4, 4, 4, true, \"1\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(2, 2, 2, 2, 2, true, \"4\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(3, 3, 3, 3, 3, false, \"3\")", + "INSERT INTO root.db.d2(timestamp,y1,y2,y3,y4,y5,y6) values(4, 1, 1, 1, 1, false, \"1\")", + "flush" + }; + + protected static final String UNSUPPORTED_TYPE_MESSAGE = "Unsupported data type in MaxBy:"; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(NON_ALIGNED_DATASET); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testMaxByWithUnsupportedYInputTypes() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x1, y5) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x1, y6) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x5, y5) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x5, y6) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x6, y5) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + try { + try (ResultSet resultSet = + statement.executeQuery("SELECT max_by(x6, y6) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(UNSUPPORTED_TYPE_MESSAGE)); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithDifferentXAndYInputTypes() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + Map<String, String[]> expectedHeaders = + generateExpectedHeadersForMaxByTest( + "root.db.d1", new String[] {"x1", "x2", "x3", "x4", "x5", "x6"}, new String[] {"y1", "y2", "y3", "y4"}); + String[] retArray = new String[] {"3,3,3.0,3.0,false,3,"}; + for (Map.Entry<String, String[]> expectedHeader : expectedHeaders.entrySet()) { + String y = expectedHeader.getKey(); + resultSetEqualTest( + String.format( + "select max_by(x1,%s),max_by(x2,%s),max_by(x3,%s),max_by(x4,%s),max_by(x5,%s),max_by(x6,%s) from root.db.d1 where time <= 3", + y, y, y, y, y, y), + expectedHeader.getValue(), + retArray); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithDifferentXAndYInputTypesAndNullXValue() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + Map<String, String[]> expectedHeaders = + generateExpectedHeadersForMaxByTest( + "root.db.d1", new String[] {"x1", "x2", "x3", "x4", "x5", "x6"},new String[] {"y1", "y2", "y3", "y4"}); + String[] retArray = new String[] {"null,null,null,null,null,null,"}; + for (Map.Entry<String, String[]> expectedHeader : expectedHeaders.entrySet()) { + String y = expectedHeader.getKey(); + resultSetEqualTest( + String.format( + "select max_by(x1,%s),max_by(x2,%s),max_by(x3,%s),max_by(x4,%s),max_by(x5,%s),max_by(x6,%s) from root.db.d1 where time <= 4", + y, y, y, y, y, y), + expectedHeader.getValue(), + retArray); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithDifferentYInputTypesAndXAsTime() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + Map<String, String[]> expectedHeaders = + generateExpectedHeadersForMaxByTest( + "root.db.d1", new String[] {"Time", "Time", "Time", "Time", "Time", "Time"},new String[] {"y1", "y2", "y3", "y4"}); + String[] retArray = new String[] {"3,3,3,3,3,3,"}; + for (Map.Entry<String, String[]> expectedHeader : expectedHeaders.entrySet()) { + String y = expectedHeader.getKey(); + resultSetEqualTest( + String.format( + "select max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s) from root.db.d1 where time <= 3", + y, y, y, y, y, y), + expectedHeader.getValue(), + retArray); + } + String[] retArray1 = new String[] {"4,4,4,4,4,4,"}; + for (Map.Entry<String, String[]> expectedHeader : expectedHeaders.entrySet()) { + String y = expectedHeader.getKey(); + resultSetEqualTest( + String.format( + "select max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s),max_by(time,%s) from root.db.d1 where time <= 4", + y, y, y, y, y, y), + expectedHeader.getValue(), + retArray1); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithExpression() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + String[] expectedHeader = new String[]{"max_by(root.db.d1.x1 + 1 - 3, -cos(sin(root.db.d1.y2 / 10)))","max_by(root.db.d1.x2 * 2 / 3, -cos(sin(root.db.d1.y2 / 10)))","max_by(floor(root.db.d1.x3), -cos(sin(root.db.d1.y2 / 10)))","max_by(ceil(root.db.d1.x4), -cos(sin(root.db.d1.y2 / 10)))","max_by(root.db.d1.x5, -cos(sin(root.db.d1.y2 / 10)))","max_by(REPLACE(root.db.d1.x6, '3', '4'), -cos(sin(root.db.d1.y2 / 10)))",}; + String[] retArray = new String[] {"1.0,2.0,3.0,3.0,false,4,"}; + String y = "-cos(sin(y2 / 10))"; + resultSetEqualTest( + String.format( + "select max_by(x1 + 1 - 3,%s),max_by(x2 * 2 / 3,%s),max_by(floor(x3),%s),max_by(ceil(x4),%s),max_by(x5,%s),max_by(replace(x6, '3', '4'),%s) from root.db.d1 where time <= 3", + y, y, y, y, y, y), + expectedHeader, + retArray); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithAlignByDevice() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + String[] expectedHeader = new String[]{DEVICE, "max_by(x1 + 1 - 3, -cos(sin(y2 / 10)))","max_by(x2 * 2 / 3, -cos(sin(y2 / 10)))","max_by(floor(x3), -cos(sin(y2 / 10)))","max_by(ceil(x4), -cos(sin(y2 / 10)))","max_by(x5, -cos(sin(y2 / 10)))","max_by(REPLACE(x6, '3', '4'), -cos(sin(y2 / 10)))",}; + String[] retArray = new String[] {"root.db.d1,1.0,2.0,3.0,3.0,false,4,", "root.db.d2,1.0,2.0,3.0,3.0,false,4,"}; + String y = "-cos(sin(y2 / 10))"; + resultSetEqualTest( + String.format( + "select max_by(x1 + 1 - 3,%s),max_by(x2 * 2 / 3,%s),max_by(floor(x3),%s),max_by(ceil(x4),%s),max_by(x5,%s),max_by(replace(x6, '3', '4'),%s) from root.db.** where time <= 3 align by device", + y, y, y, y, y, y), + expectedHeader, + retArray); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithGroupBy() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + String[] expectedHeader = new String[]{TIMESTAMP_STR, "max_by(root.db.d1.x1, root.db.d1.y2)","max_by(root.db.d1.x2, root.db.d1.y2)","max_by(root.db.d1.x3, root.db.d1.y2)","max_by(root.db.d1.x4, root.db.d1.y2)","max_by(root.db.d1.x5, root.db.d1.y2)","max_by(root.db.d1.x6, root.db.d1.y2)",}; + String[] retArray = new String[] {"0,3,3,3.0,3.0,false,3,", "4,null,null,null,null,null,null,", "8,3,3,3.0,3.0,false,3,"}; + String y = "y2"; + resultSetEqualTest( + String.format( + "select max_by(x1,%s),max_by(x2,%s),max_by(x3,%s),max_by(x4,%s),max_by(x5,%s),max_by(x6,%s) from root.db.d1 group by ([0,9),4ms)", + y, y, y, y, y, y), + expectedHeader, + retArray); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMaxByWithSlidingWindow() { + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + String[] expectedHeader = new String[]{TIMESTAMP_STR, "max_by(root.db.d1.x1, root.db.d1.y2)","max_by(root.db.d1.x2, root.db.d1.y2)","max_by(root.db.d1.x3, root.db.d1.y2)","max_by(root.db.d1.x4, root.db.d1.y2)","max_by(root.db.d1.x5, root.db.d1.y2)","max_by(root.db.d1.x6, root.db.d1.y2)",}; + String[] retArray = new String[] {"0,3,3,3.0,3.0,false,3,", "4,null,null,null,null,null,null,", "8,3,3,3.0,3.0,false,3,"}; + String y = "y2"; + resultSetEqualTest( + String.format( + "select max_by(x1,%s),max_by(x2,%s),max_by(x3,%s),max_by(x4,%s),max_by(x5,%s),max_by(x6,%s) from root.db.d1 group by ([0,9),4ms)", + y, y, y, y, y, y), + expectedHeader, + retArray); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + + // test max_by different types of x + // test max_by time + // test max_by with expression + // test max_by align by device + // test max_by group by time + // test max_by sliding window + // test max_by aligned series + // test max_by multi data region + // test max_by group by level + // test max_by having + + /** @return yInput -> expectedHeader */ + private Map<String, String[]> generateExpectedHeadersForMaxByTest( + String device, String[] xInput, String[] yInput) { + Map<String, String[]> res = new HashMap<>(); + Arrays.stream(yInput) + .forEach( + y -> { + res.put( + y, + Arrays.stream(xInput) + .map(x -> maxBy("Time".equals(x) ? x : device + "." + x, device + "." + y)) + .toArray(String[]::new)); + }); + return res; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Aggregator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Aggregator.java index e99311b8d21..e31c92d4fc1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Aggregator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/Aggregator.java @@ -79,6 +79,7 @@ public class Aggregator { "RawDataAggregateOperator can only process one tsBlock input."); int index = inputLocationList.get(0)[i].getValueColumnIndex(); // for count_time, time column is also its value column + // for max_by, the input column can also be time column. timeAndValueColumn[1 + i] = index == -1 ? timeAndValueColumn[0] : tsBlock.getColumn(index); } accumulator.addInput(timeAndValueColumn, bitMap, lastIndex); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxByAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxByAccumulator.java index 72a00c98b78..21add5ae17e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxByAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/MaxByAccumulator.java @@ -76,12 +76,12 @@ public class MaxByAccumulator implements Accumulator { } } - // partialResult should be like: | partialX | partialY | + // partialResult should be like: | partialMaxByBinary | @Override public void addIntermediate(Column[] partialResult) { checkArgument(partialResult.length == 2, "partialResult of MaxBy should be 2"); // Return if y is null. - if (partialResult[1].isNull(0)) { + if (partialResult[0].isNull(0)) { return; } switch (yDataType) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 7f36a0d099c..698c55eb43c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -136,6 +136,7 @@ import org.apache.iotdb.db.queryengine.plan.statement.sys.ExplainStatement; import org.apache.iotdb.db.queryengine.plan.statement.sys.ShowQueriesStatement; import org.apache.iotdb.db.queryengine.plan.statement.sys.ShowVersionStatement; import org.apache.iotdb.db.schemaengine.template.Template; +import org.apache.iotdb.db.utils.constant.SqlConstant; import org.apache.iotdb.rpc.RpcUtils; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType; @@ -1135,11 +1136,14 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext> sourceTransformExpressions.add(countTimeSourceExpression); } } else { - // We just process first input Expression of AggregationFunction, + // We just process first input Expression of COUNT_IF, // keep other input Expressions as origin - // If AggregationFunction need more than one input series, - // we need to reconsider the process of it - sourceTransformExpressions.add(expression.getExpressions().get(0)); + if (SqlConstant.COUNT_IF.equalsIgnoreCase( + ((FunctionExpression) expression).getFunctionName())) { + sourceTransformExpressions.add(expression.getExpressions().get(0)); + } else { + sourceTransformExpressions.addAll(expression.getExpressions()); + } } } @@ -1211,8 +1215,13 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext> } else { for (Expression aggExpression : analysis.getAggregationExpressions()) { - // for AggregationExpression, only the first Expression of input need to transform - sourceTransformExpressions.add(aggExpression.getExpressions().get(0)); + // for COUNT_IF, only the first Expression of input need to transform + if (SqlConstant.COUNT_IF.equalsIgnoreCase( + ((FunctionExpression) aggExpression).getFunctionName())) { + sourceTransformExpressions.add(aggExpression.getExpressions().get(0)); + } else { + sourceTransformExpressions.addAll(aggExpression.getExpressions()); + } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index 8f4c1d6fe63..813c5b4d8ba 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -407,25 +407,27 @@ public class ExpressionTypeAnalyzer { // based on the data type of their input. // Currently, for all aggregate functions without a fixed output type, the output type is // determined by the first input. - switch (aggregateFunctionName) { + switch (aggregateFunctionName.toLowerCase()) { case SqlConstant.MIN_TIME: case SqlConstant.MAX_TIME: + case SqlConstant.MIN_VALUE: + case SqlConstant.MAX_VALUE: + case SqlConstant.EXTREME: + case SqlConstant.LAST_VALUE: + case SqlConstant.FIRST_VALUE: case SqlConstant.COUNT: - case SqlConstant.TIME_DURATION: - case SqlConstant.COUNT_TIME: case SqlConstant.AVG: case SqlConstant.SUM: + case SqlConstant.COUNT_IF: + case SqlConstant.TIME_DURATION: + case SqlConstant.MODE: + case SqlConstant.COUNT_TIME: case SqlConstant.STDDEV: case SqlConstant.STDDEV_POP: case SqlConstant.STDDEV_SAMP: case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: - case SqlConstant.LAST_VALUE: - case SqlConstant.FIRST_VALUE: - case SqlConstant.MIN_VALUE: - case SqlConstant.MAX_VALUE: - case SqlConstant.MODE: case SqlConstant.MAX_BY: return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); default: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TypeProvider.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TypeProvider.java index d54bf4b1e7a..0c3bcee6466 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TypeProvider.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/TypeProvider.java @@ -29,6 +29,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; +import static org.apache.iotdb.db.queryengine.plan.expression.leaf.TimestampOperand.TIMESTAMP_EXPRESSION_STRING; + public class TypeProvider { private final Map<String, TSDataType> typeMap; @@ -42,6 +44,8 @@ public class TypeProvider { public TypeProvider(Map<String, TSDataType> typeMap, TemplatedInfo templatedInfo) { this.typeMap = typeMap; this.templatedInfo = templatedInfo; + // The type of TimeStampOperand is INT64 + this.typeMap.putIfAbsent(TIMESTAMP_EXPRESSION_STRING, TSDataType.INT64); } public TSDataType getType(String symbol) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/visitor/cartesian/ConcatDeviceAndBindSchemaForExpressionVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/visitor/cartesian/ConcatDeviceAndBindSchemaForExpressionVisitor.java index 8217d0a92fd..b4fec47329a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/visitor/cartesian/ConcatDeviceAndBindSchemaForExpressionVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/visitor/cartesian/ConcatDeviceAndBindSchemaForExpressionVisitor.java @@ -29,6 +29,7 @@ import org.apache.iotdb.db.queryengine.plan.expression.leaf.ConstantOperand; import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimeSeriesOperand; import org.apache.iotdb.db.queryengine.plan.expression.leaf.TimestampOperand; import org.apache.iotdb.db.queryengine.plan.expression.multi.FunctionExpression; +import org.apache.iotdb.db.utils.constant.SqlConstant; import java.util.ArrayList; import java.util.Collection; @@ -56,11 +57,9 @@ public class ConcatDeviceAndBindSchemaForExpressionVisitor extendedExpressions.add(concatExpression); } - // We just process first input Expression of AggregationFunction, + // We just process first input Expression of COUNT_IF, // keep other input Expressions as origin and bind Type - // If AggregationFunction need more than one input series, - // we need to reconsider the process of it - if (functionExpression.isBuiltInAggregationFunctionExpression()) { + if (SqlConstant.COUNT_IF.equalsIgnoreCase(functionExpression.getFunctionName())) { List<Expression> children = functionExpression.getExpressions(); bindTypeForAggregationNonSeriesInputExpressions( functionExpression.getFunctionName(), children, extendedExpressions); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 57889ffdade..7512ef77011 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -2715,15 +2715,15 @@ public class ASTVisitor extends IoTDBSqlParserBaseVisitor<Statement> { return parseFunctionExpression(context, canUseFullPath); } + if (context.time != null) { + return new TimestampOperand(); + } + if (context.fullPathInExpression() != null) { return new TimeSeriesOperand( parseFullPathInExpression(context.fullPathInExpression(), canUseFullPath)); } - if (context.time != null) { - return new TimestampOperand(); - } - if (context.constant() != null && !context.constant().isEmpty()) { return parseConstantOperand(context.constant(0)); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java index 038149b2b9a..4d013814ace 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/LogicalPlanBuilder.java @@ -700,17 +700,34 @@ public class LogicalPlanBuilder { AggregationDescriptor aggregationDescriptor, TypeProvider typeProvider) { List<TAggregationType> splitAggregations = SchemaUtils.splitPartialAggregation(aggregationDescriptor.getAggregationType()); - String inputExpressionStr = - aggregationDescriptor.getInputExpressions().get(0).getExpressionString(); for (TAggregationType aggregation : splitAggregations) { String functionName = aggregation.toString().toLowerCase(); TSDataType aggregationType = SchemaUtils.getAggregationType(functionName); + String inputExpressionStr = + getExpressionStringThatDeterminesReturnType(aggregationDescriptor); typeProvider.setType( - String.format("%s(%s)", functionName, inputExpressionStr), + String.format("%s(%s)", functionName, aggregationDescriptor.getParametersString()), aggregationType == null ? typeProvider.getType(inputExpressionStr) : aggregationType); } } + /** + * For aggregate functions that accept single input, we return the first input Expression. For + * aggregate functions that can accept multiple inputs, if the type of intermediate results is + * determined by the input Expression, we return the corresponding Expression used to determine + * the type of intermediate results. + */ + private static String getExpressionStringThatDeterminesReturnType( + AggregationDescriptor aggregationDescriptor) { + switch (aggregationDescriptor.getAggregationType()) { + case MAX_BY_Y_INPUT: + return aggregationDescriptor.getInputExpressions().get(1).getExpressionString(); + case MAX_BY_X_INPUT: + default: + return aggregationDescriptor.getInputExpressions().get(0).getExpressionString(); + } + } + public static void updateTypeProviderByPartialAggregation( CrossSeriesAggregationDescriptor aggregationDescriptor, TypeProvider typeProvider) { List<TAggregationType> splitAggregations = diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index a0166b0ba7a..3c223cd54ee 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -192,7 +192,7 @@ public class AggregationDescriptor { * * <p>The parameter part -> root.sg.d.s1, sin(root.sg.d.s1) */ - protected String getParametersString() { + public String getParametersString() { if (parametersString == null) { StringBuilder builder = new StringBuilder(); if (!inputExpressions.isEmpty()) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java index 165df8d10c9..f03285e599a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/CrossSeriesAggregationDescriptor.java @@ -95,7 +95,7 @@ public class CrossSeriesAggregationDescriptor extends AggregationDescriptor { * <p>The parameter part -> root.*.*.s1, 3 */ @Override - protected String getParametersString() { + public String getParametersString() { if (parametersString == null) { StringBuilder builder = new StringBuilder(outputExpression.getExpressionString()); for (int i = 1; i < expressionNumOfOneInput; i++) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index b7ba465abf7..4e6e9619669 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -205,7 +205,7 @@ public class SchemaUtils { case VAR_SAMP: return Collections.singletonList(TAggregationType.VAR_SAMP); case MAX_BY: - return Collections.singletonList(TAggregationType.MAX_BY); + return Arrays.asList(TAggregationType.MAX_BY_X_INPUT, TAggregationType.MAX_BY_Y_INPUT); case AVG: return Arrays.asList(TAggregationType.COUNT, TAggregationType.SUM); case TIME_DURATION: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/TestConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/TestConstant.java index 46dcb37ccff..12d881ad760 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/TestConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/TestConstant.java @@ -75,6 +75,10 @@ public class TestConstant { return String.format("min_value(%s)", path); } + public static String maxBy(String x, String y) { + return String.format("max_by(%s, %s)", x, y); + } + private TestConstant() {} public static String getTestTsFilePath( diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorTest.java index d0f92222888..55fe3edeead 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorTest.java @@ -29,6 +29,7 @@ import org.apache.iotdb.tsfile.read.common.block.column.BinaryColumnBuilder; import org.apache.iotdb.tsfile.read.common.block.column.Column; import org.apache.iotdb.tsfile.read.common.block.column.ColumnBuilder; import org.apache.iotdb.tsfile.read.common.block.column.DoubleColumnBuilder; +import org.apache.iotdb.tsfile.read.common.block.column.IntColumnBuilder; import org.apache.iotdb.tsfile.read.common.block.column.LongColumnBuilder; import org.apache.iotdb.tsfile.read.common.block.column.TimeColumnBuilder; import org.apache.iotdb.tsfile.utils.Binary; @@ -39,6 +40,7 @@ import org.junit.Before; import org.junit.Test; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -55,12 +57,14 @@ public class AccumulatorTest { private void initInputTsBlock() { List<TSDataType> dataTypes = new ArrayList<>(); dataTypes.add(TSDataType.DOUBLE); + dataTypes.add(TSDataType.INT32); TsBlockBuilder tsBlockBuilder = new TsBlockBuilder(dataTypes); TimeColumnBuilder timeColumnBuilder = tsBlockBuilder.getTimeColumnBuilder(); ColumnBuilder[] columnBuilders = tsBlockBuilder.getValueColumnBuilders(); for (int i = 0; i < 100; i++) { timeColumnBuilder.writeLong(i); columnBuilders[0].writeDouble(i * 1.0); + columnBuilders[1].writeInt(-i); tsBlockBuilder.declarePosition(); } rawData = tsBlockBuilder.build(); @@ -69,13 +73,21 @@ public class AccumulatorTest { statistics.update(100L, 100d); } - public Column[] getTimeAndValueColumn(int columnIndex) { + private Column[] getTimeAndValueColumn(int columnIndex) { Column[] columns = new Column[2]; columns[0] = rawData.getTimeColumn(); columns[1] = rawData.getColumn(columnIndex); return columns; } + private Column[] getTimeAndTwoValueColumns(int columnIndex1, int columnIndex2) { + Column[] columns = new Column[3]; + columns[0] = rawData.getTimeColumn(); + columns[1] = rawData.getColumn(columnIndex1); + columns[2] = rawData.getColumn(columnIndex2); + return columns; + } + @Test public void avgAccumulatorTest() { Accumulator avgAccumulator = @@ -864,4 +876,44 @@ public class AccumulatorTest { varSampAccumulator.outputFinal(finalResult); Assert.assertEquals(841.6666666666666, finalResult.build().getDouble(0), 0.001); } + + @Test + public void maxByAccumulatorTest() { + Accumulator maxByAccumulator = + AccumulatorFactory.createAccumulator( + TAggregationType.MAX_BY, + Arrays.asList(TSDataType.INT32, TSDataType.DOUBLE), + Collections.emptyList(), + Collections.emptyMap(), + true); + Assert.assertEquals(TSDataType.INT32, maxByAccumulator.getIntermediateType()[0]); + Assert.assertEquals(TSDataType.DOUBLE, maxByAccumulator.getIntermediateType()[1]); + Assert.assertEquals(TSDataType.INT32, maxByAccumulator.getFinalType()); + // Returns null if there's no data + ColumnBuilder[] intermediateResult = new ColumnBuilder[2]; + intermediateResult[0] = new IntColumnBuilder(null, 1); + intermediateResult[1] = new DoubleColumnBuilder(null, 1); + maxByAccumulator.outputIntermediate(intermediateResult); + Assert.assertTrue(intermediateResult[0].build().isNull(0)); + Assert.assertTrue(intermediateResult[1].build().isNull(0)); + ColumnBuilder finalResult = new IntColumnBuilder(null, 1); + maxByAccumulator.outputFinal(finalResult); + Assert.assertTrue(finalResult.build().isNull(0)); + + Column[] timeAndValueColumn = getTimeAndTwoValueColumns(1, 0); + maxByAccumulator.addInput(timeAndValueColumn, null, rawData.getPositionCount() - 1); + Assert.assertFalse(maxByAccumulator.hasFinalResult()); + intermediateResult[0] = new IntColumnBuilder(null, 1); + intermediateResult[1] = new DoubleColumnBuilder(null, 1); + maxByAccumulator.outputIntermediate(intermediateResult); + Assert.assertEquals(-99, intermediateResult[0].build().getInt(0)); + Assert.assertEquals(99d, intermediateResult[1].build().getDouble(0), 0.001); + + // add intermediate result as input + maxByAccumulator.addIntermediate( + new Column[] {intermediateResult[0].build(), intermediateResult[1].build()}); + finalResult = new IntColumnBuilder(null, 1); + maxByAccumulator.outputFinal(finalResult); + Assert.assertEquals(-99, finalResult.build().getInt(0)); + } } diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 5a6325cd008..21a69cf689e 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -197,5 +197,7 @@ enum TAggregationType { VARIANCE, VAR_POP, VAR_SAMP, - MAX_BY + MAX_BY, + MAX_BY_X_INPUT, + MAX_BY_Y_INPUT } \ No newline at end of file
