This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 5cda97b25d5 Support aggregation functions in RPR
5cda97b25d5 is described below
commit 5cda97b25d5836391ddcb65ad5a2a13361c99334
Author: Le Yang <[email protected]>
AuthorDate: Wed Jul 2 17:01:20 2025 +0800
Support aggregation functions in RPR
---
.../it/db/it/IoTDBPatternAggregationIT.java | 492 +++++++++++++++++++++
.../process/PatternRecognitionOperator.java | 8 +
.../rowpattern/PatternAggregationTracker.java | 156 +++++++
.../process/rowpattern/PatternAggregator.java | 178 ++++++++
.../process/rowpattern/PatternAggregators.java | 74 ++++
.../rowpattern/PatternPartitionExecutor.java | 13 +
.../rowpattern/PatternVariableRecognizer.java | 8 +-
.../rowpattern/PhysicalAggregationPointer.java | 32 ++
.../process/rowpattern/expression/Computation.java | 15 +-
.../expression/PatternExpressionComputation.java | 18 +-
.../process/rowpattern/matcher/Matcher.java | 29 +-
.../plan/planner/TableOperatorGenerator.java | 141 +++++-
.../relational/analyzer/ExpressionAnalyzer.java | 155 ++++++-
.../plan/relational/planner/RelationPlanner.java | 57 ++-
.../planner/optimizations/SymbolMapper.java | 39 +-
.../planner/rowpattern/AggregationLabelSet.java | 100 +++++
.../rowpattern/AggregationValuePointer.java | 161 +++++++
.../rowpattern/ExpressionAndValuePointers.java | 16 +
.../operator/process/rowpattern/MatcherTest.java | 4 +-
19 files changed, 1654 insertions(+), 42 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPatternAggregationIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPatternAggregationIT.java
new file mode 100644
index 00000000000..772002a93d8
--- /dev/null
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPatternAggregationIT.java
@@ -0,0 +1,492 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.relational.it.db.it;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.TableClusterIT;
+import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.Statement;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
+import static org.junit.Assert.fail;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
+public class IoTDBPatternAggregationIT {
+ private static final String DATABASE_NAME = "test";
+ private static final String[] sqls =
+ new String[] {
+ "CREATE DATABASE " + DATABASE_NAME,
+ "USE " + DATABASE_NAME,
+ // TABLE: beidou
+ "CREATE TABLE beidou(device_id STRING TAG, department STRING FIELD,
altitude DOUBLE FIELD)",
+ // d1 and DEP1s
+ "INSERT INTO beidou VALUES (2025-01-01T00:00:00, 'd1', 'DEP1', 480.5)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:01:00, 'd1', 'DEP1', 510.2)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:02:00, 'd1', 'DEP1', 508.7)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:04:00, 'd1', 'DEP1', 495.0)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:05:00, 'd1', 'DEP1', 523.0)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:06:00, 'd1', 'DEP1', 517.4)",
+ // d2 and DEP1
+ "INSERT INTO beidou VALUES (2025-01-01T00:07:00, 'd2', 'DEP1', 530.1)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:08:00, 'd2', 'DEP1', 540.4)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:09:00, 'd2', 'DEP1', 498.2)",
+ // DEP2
+ "INSERT INTO beidou VALUES (2025-01-01T00:10:00, 'd3', 'DEP2', 470.0)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:11:00, 'd3', 'DEP2', 505.0)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:12:00, 'd3', 'DEP2', 480.0)",
+ // altitude lower than 500
+ "INSERT INTO beidou VALUES (2025-01-01T00:13:00, 'd4', 'DEP_1', 450)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:14:00, 'd4', 'DEP_1', 470)",
+ "INSERT INTO beidou VALUES (2025-01-01T00:15:00, 'd4', 'DEP_1', 490)",
+ // outside the time range
+ "INSERT INTO beidou VALUES (2024-01-01T00:30:00, 'd1', 'DEP_1', 600)",
+ "INSERT INTO beidou VALUES (2025-01-01T02:00:00, 'd1', 'DEP_1', 570)",
+
+ // TABLE: t1
+ "CREATE TABLE t1(totalprice DOUBLE FIELD)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:01:00, 10)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:02:00, 20)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:03:00, 30)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:04:00, 40)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:05:00, 10)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:06:00, 20)",
+ "INSERT INTO t1 VALUES (2025-01-01T00:07:00, 30)",
+
+ // TABLE: t2
+ "CREATE TABLE t2(totalprice DOUBLE FIELD)",
+ "INSERT INTO t2 VALUES (2025-01-01T00:01:00, 4)",
+ "INSERT INTO t2 VALUES (2025-01-01T00:02:00, 6)",
+ "INSERT INTO t2 VALUES (2025-01-01T00:03:00, 5)",
+ "INSERT INTO t2 VALUES (2025-01-01T00:04:00, 13)",
+
+ // TABLE: t3
+ "CREATE TABLE t3(totalprice DOUBLE FIELD)",
+ "INSERT INTO t3 VALUES (2025-01-01T00:01:00, 4)",
+ "INSERT INTO t3 VALUES (2025-01-01T00:02:00, 6)",
+ "INSERT INTO t3 VALUES (2025-01-01T00:03:00, 7)",
+ "INSERT INTO t3 VALUES (2025-01-01T00:04:00, 7)",
+ "INSERT INTO t3 VALUES (2025-01-01T00:05:00, -8)",
+ };
+
+ private static void insertData() {
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ for (String sql : sqls) {
+ statement.execute(sql);
+ }
+ } catch (Exception e) {
+ fail("insertData failed.");
+ }
+ }
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().initClusterEnvironment();
+ insertData();
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ /**
+ * Search range: all devices whose department is 'DEP_1', each device's data
is grouped
+ * separately, and the time range is between 2025-01-01T00:00:00 and
2025-01-01T01:00:00.
+ *
+ * <p>Event analysis: Whenever the altitude exceeds 500 and then drops below
500, it is marked as
+ * an event.
+ */
+ @Test
+ public void testEventRecognition() {
+ String[] expectedHeader =
+ new String[] {
+ "device_id", "match", "event_start", "event_end", "max_altitude",
"sum_altitude", "count"
+ };
+ String[] retArray =
+ new String[] {
+
"d1,1,2025-01-01T00:01:00.000Z,2025-01-01T00:02:00.000Z,510.2,1018.9,2,",
+
"d1,2,2025-01-01T00:05:00.000Z,2025-01-01T00:06:00.000Z,523.0,1040.4,2,",
+
"d2,1,2025-01-01T00:07:00.000Z,2025-01-01T00:08:00.000Z,540.4,1070.5,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * "
+ + "FROM ( "
+ + " SELECT time, device_id, altitude "
+ + " FROM beidou "
+ + " WHERE department = 'DEP1' AND time >= 2025-01-01T00:00:00
AND time < 2025-01-01T01:00:00 "
+ + ")"
+ + "MATCH_RECOGNIZE ( "
+ + " PARTITION BY device_id "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " RPR_FIRST(A.time) AS event_start, "
+ + " RPR_LAST(A.time) AS event_end, "
+ + " MAX(A.altitude) AS max_altitude, "
+ + " SUM(A.altitude) AS sum_altitude, "
+ + " COUNT(A.altitude) AS count "
+ + " ONE ROW PER MATCH "
+ + " PATTERN (A+) "
+ + " DEFINE "
+ + " A AS A.altitude > 500 "
+ + ") AS m "
+ + "ORDER BY device_id, match ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test1() {
+ String[] expectedHeader =
+ new String[] {
+ "time",
+ "match",
+ "count1",
+ "count2",
+ "max",
+ "min",
+ "sum1",
+ "sum2",
+ "avg1",
+ "avg2",
+ "totalprice"
+ };
+ String[] retArray =
+ new String[] {
+
"2025-01-01T00:01:00.000Z,1,1,4,10.0,10.0,10.0,100.0,10.0,25.0,10.0,",
+
"2025-01-01T00:02:00.000Z,1,2,4,20.0,10.0,30.0,100.0,15.0,25.0,20.0,",
+
"2025-01-01T00:03:00.000Z,1,3,4,30.0,10.0,60.0,100.0,20.0,25.0,30.0,",
+
"2025-01-01T00:04:00.000Z,1,4,4,40.0,10.0,100.0,100.0,25.0,25.0,40.0,",
+ "2025-01-01T00:05:00.000Z,2,1,3,10.0,10.0,10.0,60.0,10.0,20.0,10.0,",
+ "2025-01-01T00:06:00.000Z,2,2,3,20.0,10.0,30.0,60.0,15.0,20.0,20.0,",
+ "2025-01-01T00:07:00.000Z,2,3,3,30.0,10.0,60.0,60.0,20.0,20.0,30.0,",
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.count1, m.count2, m.max, m.min, m.sum1,
m.sum2, m.avg1, m.avg2, m.totalprice "
+ + "FROM t1 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " COUNT(totalprice) AS count1, "
+ + " FINAL COUNT(totalprice) AS count2, "
+ + " MAX(totalprice) AS max, "
+ + " MIN(totalprice) AS min, "
+ + " SUM(totalprice) AS sum1, "
+ + " FINAL SUM(totalprice) AS sum2, "
+ + " AVG(totalprice) AS avg1, "
+ + " FINAL AVG(totalprice) AS avg2 "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A B C D?) "
+ + " DEFINE "
+ + " A AS A.totalprice = 10, "
+ + " B AS B.totalprice = 20, "
+ + " C AS C.totalprice = 30, "
+ + " D AS D.totalprice = 40 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test2() {
+ String[] expectedHeader =
+ new String[] {
+ "time",
+ "match",
+ "label",
+ "count_0",
+ "count_1",
+ "count_2",
+ "count_c",
+ "final_count_c",
+ "count_u",
+ "final_count_u",
+ "totalprice"
+ };
+ String[] retArray =
+ new String[] {
+ "2025-01-01T00:01:00.000Z,1,A,1,1,1,0,1,0,2,10.0,",
+ "2025-01-01T00:02:00.000Z,1,B,2,2,2,0,1,1,2,20.0,",
+ "2025-01-01T00:03:00.000Z,1,C,3,3,3,1,1,1,2,30.0,",
+ "2025-01-01T00:04:00.000Z,1,D,4,4,4,1,1,2,2,40.0,",
+ "2025-01-01T00:05:00.000Z,2,A,1,1,1,0,1,0,1,10.0,",
+ "2025-01-01T00:06:00.000Z,2,B,2,2,2,0,1,1,1,20.0,",
+ "2025-01-01T00:07:00.000Z,2,C,3,3,3,1,1,1,1,30.0,",
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.label, m.count_0, m.count_1, m.count_2,
m.count_c, m.final_count_c, m.count_u, m.final_count_u, m.totalprice "
+ + "FROM t1 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " CLASSIFIER() AS label, "
+ + " COUNT() AS count_0, "
+ + " COUNT(*) AS count_1, "
+ + " COUNT(totalprice) AS count_2, "
+ + " COUNT(C.totalprice) AS count_c, "
+ + " FINAL COUNT(C.totalprice) AS final_count_c, "
+ + " COUNT(U.totalprice) AS count_u, "
+ + " FINAL COUNT(U.totalprice) AS final_count_u "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A B C D?) "
+ + " SUBSET U = (B, D) "
+ + " DEFINE "
+ + " A AS A.totalprice = 10, "
+ + " B AS B.totalprice = 20, "
+ + " C AS C.totalprice = 30, "
+ + " D AS D.totalprice = 40 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test3() {
+ String[] expectedHeader =
+ new String[] {
+ "time",
+ "match",
+ "label",
+ "sum",
+ "sum_c",
+ "final_sum_c",
+ "sum_u",
+ "final_sum_u",
+ "totalprice"
+ };
+ String[] retArray =
+ new String[] {
+ "2025-01-01T00:01:00.000Z,1,A,10.0,0.0,30.0,0.0,60.0,10.0,",
+ "2025-01-01T00:02:00.000Z,1,B,30.0,0.0,30.0,20.0,60.0,20.0,",
+ "2025-01-01T00:03:00.000Z,1,C,60.0,30.0,30.0,20.0,60.0,30.0,",
+ "2025-01-01T00:04:00.000Z,1,D,100.0,30.0,30.0,60.0,60.0,40.0,",
+ "2025-01-01T00:05:00.000Z,2,A,10.0,0.0,30.0,0.0,20.0,10.0,",
+ "2025-01-01T00:06:00.000Z,2,B,30.0,0.0,30.0,20.0,20.0,20.0,",
+ "2025-01-01T00:07:00.000Z,2,C,60.0,30.0,30.0,20.0,20.0,30.0,",
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.label, m.sum, m.sum_c, m.final_sum_c,
m.sum_u, m.final_sum_u, m.totalprice "
+ + "FROM t1 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " CLASSIFIER() AS label, "
+ + " SUM(totalprice) AS sum, "
+ + " SUM(C.totalprice) AS sum_c, "
+ + " FINAL SUM(C.totalprice) AS final_sum_c, "
+ + " SUM(U.totalprice) AS sum_u, "
+ + " FINAL SUM(U.totalprice) AS final_sum_u "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A B C D?) "
+ + " SUBSET U = (B, D) "
+ + " DEFINE "
+ + " A AS A.totalprice = 10, "
+ + " B AS B.totalprice = 20, "
+ + " C AS C.totalprice = 30, "
+ + " D AS D.totalprice = 40 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test4() {
+ String[] expectedHeader =
+ new String[] {
+ "time",
+ "match",
+ "label",
+ "avg",
+ "avg_c",
+ "final_avg_c",
+ "avg_u",
+ "final_avg_u",
+ "totalprice"
+ };
+ String[] retArray =
+ new String[] {
+ "2025-01-01T00:01:00.000Z,1,A,10.0,0.0,30.0,0.0,30.0,10.0,",
+ "2025-01-01T00:02:00.000Z,1,B,15.0,0.0,30.0,20.0,30.0,20.0,",
+ "2025-01-01T00:03:00.000Z,1,C,20.0,30.0,30.0,20.0,30.0,30.0,",
+ "2025-01-01T00:04:00.000Z,1,D,25.0,30.0,30.0,30.0,30.0,40.0,",
+ "2025-01-01T00:05:00.000Z,2,A,10.0,0.0,30.0,0.0,20.0,10.0,",
+ "2025-01-01T00:06:00.000Z,2,B,15.0,0.0,30.0,20.0,20.0,20.0,",
+ "2025-01-01T00:07:00.000Z,2,C,20.0,30.0,30.0,20.0,20.0,30.0,",
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.label, m.avg, m.avg_c, m.final_avg_c,
m.avg_u, m.final_avg_u, m.totalprice "
+ + "FROM t1 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " CLASSIFIER() AS label, "
+ + " AVG(totalprice) AS avg, "
+ + " AVG(C.totalprice) AS avg_c, "
+ + " FINAL AVG(C.totalprice) AS final_avg_c, "
+ + " AVG(U.totalprice) AS avg_u, "
+ + " FINAL AVG(U.totalprice) AS final_avg_u "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A B C D?) "
+ + " SUBSET U = (B, D) "
+ + " DEFINE "
+ + " A AS A.totalprice = 10, "
+ + " B AS B.totalprice = 20, "
+ + " C AS C.totalprice = 30, "
+ + " D AS D.totalprice = 40 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test5() {
+ String[] expectedHeader =
+ new String[] {
+ "time", "match", "firstTime", "lastTime", "maxTime", "minTime",
"firstVal", "lastVal"
+ };
+ String[] retArray =
+ new String[] {
+
"2025-01-01T00:01:00.000Z,1,2025-01-01T00:01:00.000Z,2025-01-01T00:01:00.000Z,2025-01-01T00:01:00.000Z,2025-01-01T00:01:00.000Z,10.0,10.0,",
+
"2025-01-01T00:02:00.000Z,1,2025-01-01T00:01:00.000Z,2025-01-01T00:02:00.000Z,2025-01-01T00:02:00.000Z,2025-01-01T00:01:00.000Z,10.0,20.0,",
+
"2025-01-01T00:03:00.000Z,1,2025-01-01T00:01:00.000Z,2025-01-01T00:03:00.000Z,2025-01-01T00:03:00.000Z,2025-01-01T00:01:00.000Z,10.0,30.0,",
+
"2025-01-01T00:04:00.000Z,1,2025-01-01T00:01:00.000Z,2025-01-01T00:04:00.000Z,2025-01-01T00:04:00.000Z,2025-01-01T00:01:00.000Z,10.0,40.0,",
+
"2025-01-01T00:05:00.000Z,2,2025-01-01T00:05:00.000Z,2025-01-01T00:05:00.000Z,2025-01-01T00:05:00.000Z,2025-01-01T00:05:00.000Z,10.0,10.0,",
+
"2025-01-01T00:06:00.000Z,2,2025-01-01T00:05:00.000Z,2025-01-01T00:06:00.000Z,2025-01-01T00:06:00.000Z,2025-01-01T00:05:00.000Z,10.0,20.0,",
+
"2025-01-01T00:07:00.000Z,2,2025-01-01T00:05:00.000Z,2025-01-01T00:07:00.000Z,2025-01-01T00:07:00.000Z,2025-01-01T00:05:00.000Z,10.0,30.0,"
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.firstTime, m.lastTime, m.maxTime,
m.minTime, m.firstVal, m.lastVal "
+ + "FROM t1 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " FIRST_BY(time, totalprice) AS firstTime, "
+ + " LAST_BY(time, totalprice) AS lastTime, "
+ + " MAX_BY(time, totalprice) AS maxTime, "
+ + " MIN_BY(time, totalprice) AS minTime, "
+ + " FIRST(totalprice) AS firstVal, "
+ + " LAST(totalprice) AS lastVal "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A B C D?) "
+ + " DEFINE "
+ + " A AS totalprice = 10, "
+ + " B AS totalprice = 20, "
+ + " C AS totalprice = 30, "
+ + " D AS totalprice = 40 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void test6() {
+ String[] expectedHeader =
+ new String[] {
+ "time", "match", "mode", "extreme", "avg", "var_0", "var_1",
"var_2", "std_0", "std_1",
+ "std_2"
+ };
+ String[] retArray =
+ new String[] {
+ // [4]
+ "2025-01-01T00:01:00.000Z,1,4.0,4.0,4.0,0.0,0.0,0.0,0.0,0.0,0.0,",
+ // [4, 6]
+
"2025-01-01T00:02:00.000Z,1,4.0,6.0,5.0,2.0,2.0,1.0,1.414214,1.414214,1.0,",
+ // [4, 6, 7]
+
"2025-01-01T00:03:00.000Z,1,4.0,7.0,5.666667,2.333333,2.333333,1.555556,1.527525,1.527525,1.247219,",
+ // [4, 6, 7, 7]
+
"2025-01-01T00:04:00.000Z,1,7.0,7.0,6.0,2.0,2.0,1.5,1.414214,1.414214,1.224745,",
+ // [4, 6, 7, 7, -8]
+
"2025-01-01T00:05:00.000Z,1,7.0,-8.0,3.2,40.7,40.7,32.56,6.379655,6.379655,5.706137,"
+ };
+
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.mode, m.extreme, "
+ + "ROUND(m.avg, 6) AS avg, "
+ + "ROUND(m.var_0, 6) AS var_0, ROUND(m.var_1, 6) AS var_1,
ROUND(m.var_2, 6) AS var_2, "
+ + "ROUND(m.std_0, 6) AS std_0, ROUND(m.std_1, 6) AS std_1,
ROUND(m.std_2, 6) AS std_2 "
+ + "FROM t3 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " MODE(totalprice) AS mode, "
+ + " EXTREME(totalprice) AS extreme, "
+ + " AVG(totalprice) AS avg, "
+ + " VARIANCE(totalprice) AS var_0, "
+ + " VAR_SAMP(totalprice) AS var_1, "
+ + " VAR_POP(totalprice) AS var_2, "
+ + " STDDEV(totalprice) AS std_0, "
+ + " STDDEV_SAMP(totalprice) AS std_1, "
+ + " STDDEV_POP(totalprice) AS std_2 "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN (A+) "
+ + " DEFINE "
+ + " A AS true "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testAggregationsInDefineClause() {
+ String[] expectedHeader =
+ new String[] {"time", "match", "label", "avg", "running_avg_b",
"totalprice"};
+ String[] retArray =
+ new String[] {
+ "2025-01-01T00:01:00.000Z,1,B,4.0,4.0,4.0,",
+ "2025-01-01T00:02:00.000Z,1,A,5.0,4.0,6.0,",
+ "2025-01-01T00:03:00.000Z,1,A,5.0,4.0,5.0,",
+ "2025-01-01T00:04:00.000Z,1,B,7.0,8.5,13.0,",
+ };
+ tableResultSetEqualTest(
+ "SELECT m.time, m.match, m.label, m.avg, m.running_avg_b, m.totalprice
"
+ + "FROM t2 "
+ + "MATCH_RECOGNIZE ( "
+ + " MEASURES "
+ + " MATCH_NUMBER() AS match, "
+ + " CLASSIFIER() AS label, "
+ + " RUNNING AVG(totalprice) AS avg, "
+ + " RUNNING AVG(B.totalprice) AS running_avg_b "
+ + " ALL ROWS PER MATCH "
+ + " PATTERN ((A | B)*) "
+ + " DEFINE "
+ + " A AS AVG(totalprice) = 5 "
+ + ") AS m ",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/PatternRecognitionOperator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/PatternRecognitionOperator.java
index 322cc26164d..0275b44e145 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/PatternRecognitionOperator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/PatternRecognitionOperator.java
@@ -23,6 +23,7 @@ import
org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
import org.apache.iotdb.db.queryengine.execution.operator.Operator;
import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.LogicalIndexNavigation;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregator;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternPartitionExecutor;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternVariableRecognizer;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.expression.PatternExpressionComputation;
@@ -83,6 +84,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
private final Matcher matcher;
private final List<PatternVariableRecognizer.PatternVariableComputation>
labelPatternVariableComputations;
+ private final List<PatternAggregator> patternAggregators;
private final List<PatternExpressionComputation> measureComputations;
private final List<String> labelNames;
@@ -103,6 +105,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
Optional<LogicalIndexNavigation> skipToNavigation,
Matcher matcher,
List<PatternVariableRecognizer.PatternVariableComputation>
labelPatternVariableComputations,
+ List<PatternAggregator> patternAggregators,
List<PatternExpressionComputation> measureComputations,
List<String> labelNames) {
this.operatorContext = operatorContext;
@@ -132,6 +135,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
this.skipToNavigation = skipToNavigation;
this.matcher = matcher;
this.labelPatternVariableComputations =
ImmutableList.copyOf(labelPatternVariableComputations);
+ this.patternAggregators = ImmutableList.copyOf(patternAggregators);
this.measureComputations = ImmutableList.copyOf(measureComputations);
this.labelNames = ImmutableList.copyOf(labelNames);
@@ -206,6 +210,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
skipToNavigation,
matcher,
labelPatternVariableComputations,
+ patternAggregators,
measureComputations,
labelNames);
cachedPartitionExecutors.addLast(partitionExecutor);
@@ -267,6 +272,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
skipToNavigation,
matcher,
labelPatternVariableComputations,
+ patternAggregators,
measureComputations,
labelNames);
@@ -305,6 +311,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
skipToNavigation,
matcher,
labelPatternVariableComputations,
+ patternAggregators,
measureComputations,
labelNames);
} else {
@@ -324,6 +331,7 @@ public class PatternRecognitionOperator implements
ProcessOperator {
skipToNavigation,
matcher,
labelPatternVariableComputations,
+ patternAggregators,
measureComputations,
labelNames);
// Clear TsBlock of last partition
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregationTracker.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregationTracker.java
new file mode 100644
index 00000000000..aa0c25b816e
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregationTracker.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern;
+
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher.ArrayView;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher.IntList;
+
+import java.util.Set;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class returns a set of positions to aggregate over for an aggregate function in row pattern
+ * matching context. Aggregations in row pattern matching context have RUNNING or FINAL semantics,
+ * and they apply only to rows matched with certain pattern variables. For example, for a match "A B
+ * A A B", the aggregation `sum(B.x)` only applies to the second and the last position.
+ *
+ * <p>This tracker is stateful. It requires a reset for every new match. The method
+ * `resolveNewPositions()` returns a portion of positions corresponding to the new portion of the
+ * match since the last call. It is thus assumed that a sequence of calls since the instance
+ * creation or `reset()` applies to the same match (i.e. the `matchedLabels` passed as an argument
+ * to one call is a prefix of `matchedLabels` passed in the next call).
+ *
+ * <p>Also, a full list of positions for the current match is kept, and it can be obtained via the
+ * `getAllPositions()` method. This is for the purpose of comparing pattern matching threads in
+ * ThreadEquivalence.
+ */
+public class PatternAggregationTracker {
+  private static final int DEFAULT_CAPACITY = 10;
+
+  private final Set<Integer> labels;
+  private final boolean running;
+
+  // length of the aggregated prefix of the match
+  private int aggregated;
+
+  // length of the prefix of the match where all applicable positions were identified,
+  // and stored in `allPositions`. During the pattern matching phase it might exceed
+  // the `aggregated` prefix due to `getAllPositions()` calls; only `reset()` shrinks it.
+  private int evaluated;
+
+  // all identified applicable positions (0-based, relative to the match start) for the current
+  // match up to the `evaluated` position; updated by `resolveNewPositions()` and, only during the
+  // pattern matching phase, by `getAllPositions()`
+  private final IntList allPositions;
+
+  public PatternAggregationTracker(Set<Integer> labels, boolean running) {
+    this.labels = requireNonNull(labels, "labels is null");
+    this.running = running;
+    this.allPositions = new IntList(DEFAULT_CAPACITY);
+  }
+
+  // for copying
+  private PatternAggregationTracker(
+      Set<Integer> labels, boolean running, int aggregated, int evaluated, IntList allPositions) {
+    this.labels = labels;
+    this.running = running;
+    this.aggregated = aggregated;
+    this.evaluated = evaluated;
+    this.allPositions = allPositions;
+  }
+
+  public void reset() {
+    aggregated = 0;
+    evaluated = 0;
+    allPositions.clear();
+  }
+
+  /**
+   * Resolves positions to aggregate over that are new since the last call. This is used during
+   * pattern matching (the evaluated label has then already been appended to `matchedLabels`) and
+   * when computing row pattern measures after a non-empty match is found. Search is limited up to
+   * the current row for RUNNING semantics and to the entire match for FINAL semantics.
+   *
+   * <p>If `evaluated` exceeds `aggregated`, the positions below `evaluated` are already stored in
+   * `allPositions`, so they are returned but not stored again.
+   *
+   * @return array of new matching positions since the last call, relative to partition start
+   */
+  public ArrayView resolveNewPositions(
+      int currentRow, ArrayView matchedLabels, int partitionStart, int patternStart) {
+    checkArgument(
+        currentRow >= patternStart && currentRow < patternStart + matchedLabels.length(),
+        "current row is out of bounds of the match");
+    checkState(
+        aggregated <= evaluated && evaluated <= matchedLabels.length(),
+        "PatternAggregationTracker in inconsistent state");
+
+    IntList positions = new IntList(DEFAULT_CAPACITY);
+    int last = running ? currentRow - patternStart : matchedLabels.length() - 1;
+
+    // return positions exceeding the `aggregated` prefix
+    for (int position = aggregated; position <= last; position++) {
+      if (appliesToLabel(matchedLabels.get(position))) {
+        positions.add(position + patternStart - partitionStart);
+        if (aggregated >= evaluated) {
+          // after exceeding the `evaluated` prefix, store resolved positions
+          allPositions.add(position);
+        }
+      }
+      aggregated++;
+    }
+    evaluated = Math.max(evaluated, aggregated); // keep the stored prefix intact
+
+    return positions.toArrayView();
+  }
+
+  // for ThreadEquivalence
+  // return all positions to aggregate in `matchedLabels` starting from 0, scanning only the
+  // suffix beyond the already-evaluated prefix
+  public ArrayView getAllPositions(ArrayView matchedLabels) {
+    checkState(
+        evaluated <= matchedLabels.length(), "PatternAggregationTracker in inconsistent state");
+
+    for (int position = evaluated; position < matchedLabels.length(); position++) {
+      if (appliesToLabel(matchedLabels.get(position))) {
+        allPositions.add(position);
+      }
+    }
+    evaluated = matchedLabels.length();
+
+    return allPositions.toArrayView();
+  }
+
+  private boolean appliesToLabel(int label) {
+    return labels.isEmpty() || labels.contains(label);
+  }
+
+  public PatternAggregationTracker copy() {
+    return new PatternAggregationTracker(
+        labels, running, aggregated, evaluated, allPositions.copy());
+  }
+
+  public long getAllPositionsSizeInBytes() {
+    return allPositions.getSizeInBytes();
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregator.java
new file mode 100644
index 00000000000..55fc4425041
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregator.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern;
+
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher.ArrayView;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.window.partition.Partition;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
+import
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.TableAccumulator;
+import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+
+import java.util.List;
+
+import static java.util.Objects.requireNonNull;
+import static
org.apache.iotdb.db.queryengine.execution.operator.source.relational.AbstractTableScanOperator.TIME_COLUMN_TEMPLATE;
+
+/**
+ * Computes the result of an aggregate function in a row pattern recognition context.
+ *
+ * <p>Expressions in the DEFINE and MEASURES clauses can contain aggregate functions. Each such
+ * aggregate function is represented by an instance of {@code PatternAggregator}. Whenever the
+ * aggregate needs to be evaluated, {@link #aggregate} is called, and the returned value is used to
+ * evaluate the enclosing expression.
+ *
+ * <p>An aggregate is evaluated in two situations:
+ *
+ * <ol>
+ *   <li>During pattern matching. For example, with the defining condition {@code DEFINE A AS
+ *       avg(B.x) > 0}, {@code avg} is evaluated over all rows matched so far to label {@code B}
+ *       every time the matcher tries to match label {@code A}.
+ *   <li>During row pattern measures computation. For example, with {@code MEASURES M1 AS RUNNING
+ *       sum(A.x)} the running sum is evaluated over all rows matched to {@code A} up to each row
+ *       of the match; with {@code MEASURES M2 AS FINAL sum(A.x)} the overall sum over the entire
+ *       match is computed and propagated to every output row.
+ * </ol>
+ *
+ * <p>To avoid duplicate computation this class is stateful. Its state consists of the {@link
+ * TableAccumulator}, which holds the partial result, and the {@link PatternAggregationTracker},
+ * which determines the positions that are new since the previous call. An instance that is reused
+ * for a different match must be {@link #reset()} first.
+ */
+public class PatternAggregator {
+  // the bound signature of the aggregate function: name, argument types and return type
+  private final BoundSignature boundSignature;
+  private final TableAccumulator accumulator;
+
+  // partition channels holding the arguments of this aggregate function, in argument order;
+  // empty for argument-less functions such as count(*)
+  private final List<Integer> argumentChannels;
+  private final PatternAggregationTracker patternAggregationTracker;
+
+  public PatternAggregator(
+      BoundSignature boundSignature,
+      TableAccumulator accumulator,
+      List<Integer> argumentChannels,
+      PatternAggregationTracker patternAggregationTracker) {
+    this(boundSignature, accumulator, argumentChannels, patternAggregationTracker, true);
+  }
+
+  // `resetAccumulator` is false only when copying: the copied accumulator's partial result must
+  // survive, because the copied tracker will not hand out the already aggregated rows again
+  private PatternAggregator(
+      BoundSignature boundSignature,
+      TableAccumulator accumulator,
+      List<Integer> argumentChannels,
+      PatternAggregationTracker patternAggregationTracker,
+      boolean resetAccumulator) {
+    this.boundSignature = requireNonNull(boundSignature, "boundSignature is null");
+    this.accumulator = requireNonNull(accumulator, "accumulator is null");
+    this.argumentChannels = requireNonNull(argumentChannels, "argumentChannels is null");
+    this.patternAggregationTracker =
+        requireNonNull(patternAggregationTracker, "patternAggregationTracker is null");
+    if (resetAccumulator) {
+      accumulator.reset();
+    }
+  }
+
+  /** Clears the accumulator and the tracker; called before computing measures for a new match. */
+  public void reset() {
+    accumulator.reset();
+    patternAggregationTracker.reset();
+  }
+
+  /**
+   * Adds the positions that became relevant since the previous call to the accumulator and returns
+   * the overall aggregation result.
+   *
+   * <p>Used both for evaluating labels during pattern matching (the label being evaluated has
+   * already been appended to {@code matchedLabels}) and for computing row pattern measures after a
+   * non-empty match is found.
+   *
+   * @param currentRow the row being evaluated; with RUNNING semantics rows are aggregated up to it
+   * @param matchedLabels labels assigned to the rows of the match so far
+   * @param partition the partition containing the rows
+   * @param partitionStart start of the partition, in the same coordinates as {@code currentRow}
+   * @param patternStart start of the current match, in the same coordinates as {@code currentRow}
+   * @return the single value produced by the accumulator's final evaluation
+   */
+  public Object aggregate(
+      int currentRow,
+      ArrayView matchedLabels,
+      Partition partition,
+      int partitionStart,
+      int patternStart) {
+    // row positions (relative to the partition start) that are new since the last call
+    ArrayView positions =
+        patternAggregationTracker.resolveNewPositions(
+            currentRow, matchedLabels, partitionStart, patternStart);
+
+    int newPositionCount = positions.length();
+    if (newPositionCount > 0) {
+      // materialize only the new rows instead of the whole partition on every call; copying the
+      // full partition each time would make repeated evaluation during matching quadratic
+      accumulator.addInput(
+          buildArgumentColumns(partition, positions, newPositionCount),
+          AggregationMask.createSelectAll(newPositionCount));
+    }
+
+    // the result of the aggregate function is a single value
+    ColumnBuilder resultBuilder = boundSignature.getReturnType().createColumnBuilder(1);
+    accumulator.evaluateFinal(resultBuilder);
+    return resultBuilder.build().getObject(0);
+  }
+
+  /**
+   * Builds one column per argument containing only the rows at the first {@code count} entries of
+   * {@code positions}, in order. Argument-less functions such as count(*) get a single
+   * run-length-encoded placeholder column of {@code count} rows.
+   */
+  private Column[] buildArgumentColumns(Partition partition, ArrayView positions, int count) {
+    if (argumentChannels.isEmpty()) {
+      return new Column[] {new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, count)};
+    }
+
+    Column[] argumentColumns = new Column[argumentChannels.size()];
+    for (int i = 0; i < argumentColumns.length; i++) {
+      int channel = argumentChannels.get(i);
+      // the column type comes from the i-th argument of the function signature
+      ColumnBuilder builder = boundSignature.getArgumentType(i).createColumnBuilder(count);
+      for (int j = 0; j < count; j++) {
+        partition.writeTo(builder, channel, positions.get(j));
+      }
+      argumentColumns[i] = builder.build();
+    }
+    return argumentColumns;
+  }
+
+  // TODO: aggregation over empty input (needed for measures of empty matches, where the SQL
+  // specification requires count() to return 0 and other aggregates to return null) is not
+  // implemented yet.
+
+  /**
+   * Creates an independent copy of this aggregator that keeps both the accumulated partial result
+   * and the tracker state. Used when the matcher forks a thread.
+   *
+   * @throws SemanticException if the accumulator of this aggregate function does not support
+   *     copying
+   */
+  public PatternAggregator copy() {
+    TableAccumulator accumulatorCopy;
+    try {
+      accumulatorCopy = accumulator.copy();
+    } catch (UnsupportedOperationException e) {
+      throw new SemanticException(
+          String.format(
+              "aggregate function %s does not support copying", boundSignature.getName()));
+    }
+
+    // must not reset: the copied tracker will skip the rows already folded into accumulatorCopy
+    return new PatternAggregator(
+        boundSignature,
+        accumulatorCopy,
+        argumentChannels,
+        patternAggregationTracker.copy(),
+        false);
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregators.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregators.java
new file mode 100644
index 00000000000..d25363f2304
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternAggregators.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static com.google.common.base.Preconditions.checkState;
+
+/** Per-thread sets of {@link PatternAggregator}s used while the matcher evaluates labels. */
+public class PatternAggregators {
+  // values[threadId][i] holds the i-th aggregate function of the thread `threadId`; a null entry
+  // means the thread has not needed its aggregators yet (or has been released)
+  private PatternAggregator[] values;
+  // prototypes from which each thread's aggregators are derived; they are never evaluated directly
+  private final List<PatternAggregator> patternAggregators;
+
+  public PatternAggregators(int capacity, List<PatternAggregator> patternAggregators) {
+    this.values = new PatternAggregator[capacity][];
+    this.patternAggregators = patternAggregators;
+  }
+
+  /**
+   * Returns the aggregators of the given thread, creating them on first access.
+   *
+   * <p>Aggregators are stateful, so every thread gets its own reset copies of the prototypes.
+   * Handing out the prototype instances themselves would let partial results leak between threads
+   * and between match attempts. The array is non-null even when there are no aggregate functions.
+   */
+  public PatternAggregator[] get(int key) {
+    ensureCapacity(key);
+    if (values[key] == null) {
+      PatternAggregator[] aggregations = new PatternAggregator[patternAggregators.size()];
+      for (int i = 0; i < aggregations.length; i++) {
+        PatternAggregator aggregator = patternAggregators.get(i).copy();
+        // start from a clean state regardless of the prototype's state
+        aggregator.reset();
+        aggregations[i] = aggregator;
+      }
+      values[key] = aggregations;
+    }
+    return values[key];
+  }
+
+  /** Drops the aggregators of a terminated thread so that its id can be reused. */
+  public void release(int key) {
+    if (key < values.length) {
+      values[key] = null;
+    }
+  }
+
+  /** Gives a forked child thread its own copies of the parent's aggregators, if it has any. */
+  public void copy(int parent, int child) {
+    ensureCapacity(child);
+    checkState(values[child] == null, "overriding aggregations for child thread");
+
+    PatternAggregator[] parentAggregations = values[parent];
+    if (parentAggregations != null) {
+      PatternAggregator[] aggregations = new PatternAggregator[parentAggregations.length];
+      for (int i = 0; i < aggregations.length; i++) {
+        aggregations[i] = parentAggregations[i].copy();
+      }
+      values[child] = aggregations;
+    }
+  }
+
+  // grows `values` (at least doubling) so that `key` is a valid index
+  private void ensureCapacity(int key) {
+    if (key >= values.length) {
+      values = Arrays.copyOf(values, Math.max(values.length * 2, key + 1));
+    }
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternPartitionExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternPartitionExecutor.java
index 25585b3cd11..d76fe1232c5 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternPartitionExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternPartitionExecutor.java
@@ -68,6 +68,11 @@ public final class PatternPartitionExecutor {
private final Optional<LogicalIndexNavigation> skipToNavigation;
private final Matcher matcher;
private final List<PatternVariableComputation> patternVariableComputations;
+ // an array of all MatchAggregations from all row pattern measures,
+ // used to reset the MatchAggregations for every new match.
+ // each of MeasureComputations also has access to the MatchAggregations,
+ // and uses them to compute the result values
+ private final PatternAggregator[] patternAggregators;
private final List<PatternExpressionComputation> measureComputations;
private final List<String> labelNames;
@@ -87,6 +92,7 @@ public final class PatternPartitionExecutor {
Optional<LogicalIndexNavigation> skipToNavigation,
Matcher matcher,
List<PatternVariableComputation> patternVariableComputations,
+ List<PatternAggregator> patternAggregators,
List<PatternExpressionComputation> measureComputations,
List<String> labelNames) {
// Partition
@@ -103,6 +109,7 @@ public final class PatternPartitionExecutor {
this.skipToNavigation = skipToNavigation;
this.matcher = matcher;
this.patternVariableComputations =
ImmutableList.copyOf(patternVariableComputations);
+ this.patternAggregators = patternAggregators.toArray(new
PatternAggregator[] {});
this.measureComputations = ImmutableList.copyOf(measureComputations);
this.labelNames = ImmutableList.copyOf(labelNames);
@@ -167,6 +174,10 @@ public final class PatternPartitionExecutor {
lastSkippedPosition = currentPosition;
matchNumber++;
} else { // non-empty match
+ for (PatternAggregator patternAggregator : patternAggregators) {
+ patternAggregator.reset();
+ }
+
if (rowsPerMatch.isOneRow()) {
outputOneRowPerMatch(builder, matchResult, patternStart,
searchStart, searchEnd);
} else {
@@ -283,6 +294,7 @@ public final class PatternPartitionExecutor {
// evaluate the MEASURES clause with the last row in the match
patternStart + labels.length() - 1,
labels,
+ patternAggregators,
partitionStart,
searchStart,
searchEnd,
@@ -358,6 +370,7 @@ public final class PatternPartitionExecutor {
measureComputation.compute(
position,
labels,
+ patternAggregators,
partitionStart,
searchStart,
searchEnd,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternVariableRecognizer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternVariableRecognizer.java
index 6cdb8c55c7d..39cd5b163e5 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternVariableRecognizer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PatternVariableRecognizer.java
@@ -80,7 +80,7 @@ public class PatternVariableRecognizer {
// evaluate the last label in matchedLabels. It has been tentatively
appended to the match
// The `evaluateLabel` method is used to determine whether it is feasible to
identify the current
// row as a label
- public boolean evaluateLabel(ArrayView matchedLabels) {
+ public boolean evaluateLabel(ArrayView matchedLabels, PatternAggregator[]
patternAggregators) {
int label = matchedLabels.get(matchedLabels.length() - 1);
// The variable `evaluation` stores the logic for determining whether a
row can be identified as
// `label`
@@ -89,6 +89,7 @@ public class PatternVariableRecognizer {
// `label`
return patternVariableComputation.test(
matchedLabels,
+ patternAggregators,
partitionStart,
searchStart,
searchEnd,
@@ -108,8 +109,9 @@ public class PatternVariableRecognizer {
public PatternVariableComputation(
List<PhysicalValueAccessor> valueAccessors,
Computation computation,
+ List<PatternAggregator> patternAggregators,
List<String> labelNames) {
- super(valueAccessors, computation);
+ super(valueAccessors, computation, patternAggregators);
this.labelNames = requireNonNull(labelNames, "labelNames is null");
}
@@ -121,6 +123,7 @@ public class PatternVariableRecognizer {
*/
public boolean test(
ArrayView matchedLabels,
+ PatternAggregator[] patternAggregators,
int partitionStart,
int searchStart,
int searchEnd,
@@ -133,6 +136,7 @@ public class PatternVariableRecognizer {
this.compute(
currentRow,
matchedLabels,
+ patternAggregators,
partitionStart,
searchStart,
searchEnd,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PhysicalAggregationPointer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PhysicalAggregationPointer.java
new file mode 100644
index 00000000000..ed1ada2ac03
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/PhysicalAggregationPointer.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern;
+
+/**
+ * A {@link PhysicalValueAccessor} that refers to an aggregate function by its index in the
+ * {@code PatternAggregator} array handed to the pattern expression computation.
+ */
+public class PhysicalAggregationPointer implements PhysicalValueAccessor {
+  // position of the referenced aggregator in the aggregator array
+  private final int index;
+
+  public PhysicalAggregationPointer(int aggregationIndex) {
+    index = aggregationIndex;
+  }
+
+  /** Returns the index of the referenced aggregator. */
+  public int getIndex() {
+    return this.index;
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/Computation.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/Computation.java
index a442647d433..4e184b7db7b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/Computation.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/Computation.java
@@ -20,6 +20,8 @@
package
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.expression;
import org.apache.iotdb.db.exception.sql.SemanticException;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers.Assignment;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ArithmeticBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
@@ -64,9 +66,14 @@ public abstract class Computation {
*/
public static class ComputationParser {
- public static Computation parse(Expression expression) {
+ public static Computation parse(ExpressionAndValuePointers
expressionAndValuePointers) {
+ Expression expression = expressionAndValuePointers.getExpression();
+ List<Assignment> assignments =
expressionAndValuePointers.getAssignments();
AtomicInteger counter = new AtomicInteger(0);
Map<String, Integer> symbolToIndex = new HashMap<>();
+ for (int i = 0; i < assignments.size(); i++) {
+ symbolToIndex.put(assignments.get(i).getSymbol().getName(), i);
+ }
return parse(expression, counter, symbolToIndex);
}
@@ -102,15 +109,11 @@ public abstract class Computation {
NaryOperator op = mapLogicalOperator(logicalExpr.getOperator());
return new NaryComputation(computations, op);
} else if (expression instanceof SymbolReference) {
- // Upon encountering a SymbolReference type, it is converted into a
ReferenceComputation.
+ // upon encountering a SymbolReference type, it is converted into a
ReferenceComputation.
// B.value < LAST(B.value) -> b_0 < b_1
// LAST(B.value, 1) < LAST(B.value, 2) -> b_0 < b_1 + 1
SymbolReference symRef = (SymbolReference) expression;
String name = symRef.getName();
- // If an index has not been previously assigned to the symbol, a new
index is allocated.
- if (!symbolToIndex.containsKey(name)) {
- symbolToIndex.put(name, counter.getAndIncrement());
- }
int index = symbolToIndex.get(name);
return new ReferenceComputation(index);
} else if (expression instanceof LongLiteral) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/PatternExpressionComputation.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/PatternExpressionComputation.java
index fff47638ee9..0b4f0ed166a 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/PatternExpressionComputation.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/expression/PatternExpressionComputation.java
@@ -20,6 +20,8 @@
package
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.expression;
import org.apache.iotdb.db.exception.sql.SemanticException;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregator;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalAggregationPointer;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalValueAccessor;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalValuePointer;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher.ArrayView;
@@ -51,15 +53,22 @@ public class PatternExpressionComputation {
// depend on actual data in the TsBlock are delegated to the valueAccessor
for positioning.
private final Computation computation;
+ // It stores all the aggregation functions in the current pattern expression.
+ private final PatternAggregator[] patternAggregators;
+
public PatternExpressionComputation(
- List<PhysicalValueAccessor> valueAccessors, Computation computation) {
+ List<PhysicalValueAccessor> valueAccessors,
+ Computation computation,
+ List<PatternAggregator> patternAggregators) {
this.valueAccessors = valueAccessors;
this.computation = computation;
+ this.patternAggregators = patternAggregators.toArray(new
PatternAggregator[] {});
}
public Object compute(
int currentRow,
ArrayView matchedLabels, // If the value is i, the currentRow matches
labelNames[i]
+ PatternAggregator[] patternAggregators,
int partitionStart,
int searchStart,
int searchEnd,
@@ -103,6 +112,13 @@ public class PatternExpressionComputation {
values.add(null);
}
}
+ } else if (accessor instanceof PhysicalAggregationPointer) {
+ PatternAggregator aggregator =
+ patternAggregators[((PhysicalAggregationPointer)
accessor).getIndex()];
+
+ values.add(
+ aggregator.aggregate(
+ currentRow, matchedLabels, partition, partitionStart,
patternStart));
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/matcher/Matcher.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/matcher/Matcher.java
index 7e12ee7e5ea..ffb5f7e6fce 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/matcher/Matcher.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/matcher/Matcher.java
@@ -19,17 +19,21 @@
package
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregator;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregators;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternVariableRecognizer;
import org.apache.tsfile.utils.RamUsageEstimator;
+import java.util.List;
+
import static
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.matcher.MatchResult.NO_MATCH;
public class Matcher {
private final Program program;
// private final ThreadEquivalence threadEquivalence;
- // private final List<MatchAggregationInstantiator> aggregations;
+ private final List<PatternAggregator> patternAggregators;
private static class Runtime {
private static final long INSTANCE_SIZE =
@@ -51,9 +55,13 @@ public class Matcher {
private final PatternCaptures patternCaptures;
// for each thread, array of MatchAggregations evaluated by this thread
- // private final MatchAggregations aggregations;
+ private final PatternAggregators aggregators;
- public Runtime(Program program, int inputLength, boolean
matchingAtPartitionStart) {
+ public Runtime(
+ Program program,
+ int inputLength,
+ boolean matchingAtPartitionStart,
+ List<PatternAggregator> patternAggregators) {
int initialCapacity = 2 * program.size();
threads = new IntList(initialCapacity);
freeThreadIds = new IntStack(initialCapacity);
@@ -62,6 +70,7 @@ public class Matcher {
initialCapacity, program.getMinSlotCount(),
program.getMinLabelCount());
this.inputLength = inputLength;
this.matchingAtPartitionStart = matchingAtPartitionStart;
+ this.aggregators = new PatternAggregators(initialCapacity,
patternAggregators);
// this.aggregations =
// new MatchAggregations(
// initialCapacity, aggregationInstantiators,
aggregationsMemoryContext);
@@ -73,7 +82,7 @@ public class Matcher {
private int forkThread(int parent) {
int child = newThread();
patternCaptures.copy(parent, child);
- // aggregations.copy(parent, child);
+ aggregators.copy(parent, child);
return child;
}
@@ -98,7 +107,7 @@ public class Matcher {
private void killThread(int threadId) {
freeThreadIds.push(threadId);
patternCaptures.release(threadId);
- // aggregations.release(threadId);
+ aggregators.release(threadId);
}
private long getSizeInBytes() {
@@ -108,12 +117,13 @@ public class Matcher {
+ threads.getSizeInBytes()
+ freeThreadIds.getSizeInBytes()
+ patternCaptures.getSizeInBytes();
- // + aggregations.getSizeInBytes();
+ // + patternAggregators.getSizeInBytes();
}
}
- public Matcher(Program program) {
+ public Matcher(Program program, List<PatternAggregator> patternAggregators) {
this.program = program;
+ this.patternAggregators = patternAggregators;
}
public MatchResult run(PatternVariableRecognizer patternVariableRecognizer) {
@@ -123,7 +133,8 @@ public class Matcher {
int inputLength = patternVariableRecognizer.getInputLength();
boolean matchingAtPartitionStart =
patternVariableRecognizer.isMatchingAtPartitionStart();
- Runtime runtime = new Runtime(program, inputLength,
matchingAtPartitionStart);
+ Runtime runtime =
+ new Runtime(program, inputLength, matchingAtPartitionStart,
patternAggregators);
advanceAndSchedule(current, runtime.newThread(), 0, 0, runtime);
@@ -159,7 +170,7 @@ public class Matcher {
// incorrectly saved label does not matter
runtime.patternCaptures.saveLabel(threadId, label);
if (patternVariableRecognizer.evaluateLabel(
- runtime.patternCaptures.getLabels(threadId))) {
+ runtime.patternCaptures.getLabels(threadId),
runtime.aggregators.get(threadId))) {
advanceAndSchedule(next, threadId, pointer + 1, index + 1,
runtime);
} else {
runtime.scheduleKill(threadId);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 74f4cd6c027..132f5ebcc00 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -84,7 +84,10 @@ import
org.apache.iotdb.db.queryengine.execution.operator.process.join.merge.Sin
import
org.apache.iotdb.db.queryengine.execution.operator.process.join.merge.comparator.JoinKeyComparatorFactory;
import
org.apache.iotdb.db.queryengine.execution.operator.process.last.LastQueryUtil;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.LogicalIndexNavigation;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregationTracker;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternAggregator;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PatternVariableRecognizer;
+import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalAggregationPointer;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalValueAccessor;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.PhysicalValuePointer;
import
org.apache.iotdb.db.queryengine.execution.operator.process.rowpattern.expression.Computation;
@@ -146,6 +149,7 @@ import
org.apache.iotdb.db.queryengine.plan.planner.plan.node.sink.IdentitySinkN
import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation;
import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.SeriesScanOptions;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.predicate.ConvertPredicateToTimeFilterVisitor;
+import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
@@ -206,6 +210,8 @@ import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.schema.TableDeviceFetchNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.schema.TableDeviceQueryCountNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.schema.TableDeviceQueryScanNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationLabelSet;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ClassifierValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.IrLabel;
@@ -264,6 +270,7 @@ import org.apache.tsfile.write.schema.MeasurementSchema;
import javax.validation.constraints.NotNull;
import java.io.File;
+import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@@ -3212,6 +3219,31 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
}
}
+  /**
+   * Creates the {@link PatternAggregator} for one aggregate function referenced in a row pattern
+   * expression, backed by an accumulator obtained from {@code createBuiltinAccumulator}.
+   */
+  private PatternAggregator buildPatternAggregator(
+      ResolvedFunction resolvedFunction,
+      List<Map.Entry<Expression, Type>> arguments,
+      List<Integer> argumentChannels,
+      PatternAggregationTracker patternAggregationTracker) {
+    BoundSignature signature = resolvedFunction.getSignature();
+
+    // the accumulator factory works on TSDataType, so convert the signature's argument types
+    List<TSDataType> argumentDataTypes = new ArrayList<>();
+    for (Type argumentType : signature.getArgumentTypes()) {
+      argumentDataTypes.add(InternalTypeManager.getTSDataType(argumentType));
+    }
+
+    List<Expression> inputExpressions = new ArrayList<>();
+    for (Map.Entry<Expression, Type> argument : arguments) {
+      inputExpressions.add(argument.getKey());
+    }
+
+    TableAccumulator accumulator =
+        createBuiltinAccumulator(
+            getAggregationTypeByFuncName(signature.getName()),
+            argumentDataTypes,
+            inputExpressions,
+            Collections.emptyMap(),
+            true);
+
+    return new PatternAggregator(
+        signature, accumulator, argumentChannels, patternAggregationTracker);
+  }
+
@Override
public Operator visitPatternRecognition(
PatternRecognitionNode node, LocalExecutionPlanContext context) {
@@ -3296,6 +3328,18 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
// 3. DEFINE: prepare patternVariableComputation
(PatternVariableRecognizer is to be
// instantiated once per partition)
+
+ // during pattern matching, each thread will have a list of aggregations
necessary for label
+ // evaluations.
+ // the list of aggregations for a thread will be produced at thread
creation time from this
+ // supplier list, respecting the order.
+ // pointers in LabelEvaluator and ThreadEquivalence will access
aggregations by position in
+ // list.
+ int matchAggregationIndex = 0;
+ ImmutableList.Builder<PatternAggregator>
variableRecognizerAggregatorBuilder =
+ ImmutableList.builder();
+ List<PatternAggregator> variableRecognizerAggregators = ImmutableList.of();
+
ImmutableList.Builder<PatternVariableRecognizer.PatternVariableComputation>
evaluationsBuilder =
ImmutableList.builder();
@@ -3328,19 +3372,56 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
ImmutableList.of(scalarPointer.getInputSymbol()),
childLayout)),
context.getTypeProvider().getTableModelType(scalarPointer.getInputSymbol()),
scalarPointer.getLogicalIndexPointer().toLogicalIndexNavigation(mapping)));
+ } else if (pointer instanceof AggregationValuePointer) {
+ AggregationValuePointer aggregationPointer =
(AggregationValuePointer) pointer;
+
+ ResolvedFunction resolvedFunction = aggregationPointer.getFunction();
+
+ ImmutableList.Builder<Map.Entry<Expression, Type>> builder =
ImmutableList.builder();
+ List<Type> signatureTypes =
resolvedFunction.getSignature().getArgumentTypes();
+ for (int i = 0; i < aggregationPointer.getArguments().size(); i++) {
+ builder.add(
+ new AbstractMap.SimpleEntry<>(
+ aggregationPointer.getArguments().get(i),
signatureTypes.get(i)));
+ }
+ List<Map.Entry<Expression, Type>> arguments = builder.build();
+
+ List<Integer> valueChannels = new ArrayList<>();
+
+ for (Map.Entry<Expression, Type> argumentWithType : arguments) {
+ Expression argument = argumentWithType.getKey();
+ valueChannels.add(childLayout.get(Symbol.from(argument)));
+ }
+
+ AggregationLabelSet labelSet = aggregationPointer.getSetDescriptor();
+ Set<Integer> labels =
+
labelSet.getLabels().stream().map(mapping::get).collect(Collectors.toSet());
+ PatternAggregationTracker patternAggregationTracker =
+ new PatternAggregationTracker(
+ labels, aggregationPointer.getSetDescriptor().isRunning());
+
+ PatternAggregator variableRecognizerAggregator =
+ buildPatternAggregator(
+ resolvedFunction, arguments, valueChannels,
patternAggregationTracker);
+
+
variableRecognizerAggregatorBuilder.add(variableRecognizerAggregator);
+
+ valueAccessors.add(new
PhysicalAggregationPointer(matchAggregationIndex));
+ matchAggregationIndex++;
}
}
+ variableRecognizerAggregators =
variableRecognizerAggregatorBuilder.build();
+
// transform the symbolic expression tree in the logical planning stage
into a parametric
// expression tree
- Computation computation =
-
Computation.ComputationParser.parse(expressionAndValuePointers.getExpression());
+ Computation computation =
Computation.ComputationParser.parse(expressionAndValuePointers);
// construct a `PatternVariableComputation` object, where valueAccessors
is a parameter list
// and computation is a parametric expression tree, encapsulating the
computation logic
PatternVariableRecognizer.PatternVariableComputation
patternVariableComputation =
new PatternVariableRecognizer.PatternVariableComputation(
- valueAccessors, computation, labelNames);
+ valueAccessors, computation, ImmutableList.of(), labelNames);
evaluationsBuilder.add(patternVariableComputation);
}
@@ -3349,6 +3430,11 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
ImmutableList.Builder<PatternExpressionComputation>
measureComputationsBuilder =
ImmutableList.builder();
+ matchAggregationIndex = 0;
+ ImmutableList.Builder<PatternAggregator> measurePatternAggregatorBuilder =
+ ImmutableList.builder();
+ List<PatternAggregator> measurePatternAggregators = ImmutableList.of();
+
for (Measure measure : node.getMeasures().values()) {
ExpressionAndValuePointers expressionAndValuePointers =
measure.getExpressionAndValuePointers();
@@ -3377,19 +3463,55 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
ImmutableList.of(scalarPointer.getInputSymbol()),
childLayout)),
context.getTypeProvider().getTableModelType(scalarPointer.getInputSymbol()),
scalarPointer.getLogicalIndexPointer().toLogicalIndexNavigation(mapping)));
+ } else if (pointer instanceof AggregationValuePointer) {
+ AggregationValuePointer aggregationPointer =
(AggregationValuePointer) pointer;
+
+ ResolvedFunction resolvedFunction = aggregationPointer.getFunction();
+
+ ImmutableList.Builder<Map.Entry<Expression, Type>> builder =
ImmutableList.builder();
+ List<Type> signatureTypes =
resolvedFunction.getSignature().getArgumentTypes();
+ for (int i = 0; i < aggregationPointer.getArguments().size(); i++) {
+ builder.add(
+ new AbstractMap.SimpleEntry<>(
+ aggregationPointer.getArguments().get(i),
signatureTypes.get(i)));
+ }
+ List<Map.Entry<Expression, Type>> arguments = builder.build();
+
+ List<Integer> valueChannels = new ArrayList<>();
+
+ for (Map.Entry<Expression, Type> argumentWithType : arguments) {
+ Expression argument = argumentWithType.getKey();
+ valueChannels.add(childLayout.get(Symbol.from(argument)));
+ }
+
+ AggregationLabelSet labelSet = aggregationPointer.getSetDescriptor();
+ Set<Integer> labels =
+
labelSet.getLabels().stream().map(mapping::get).collect(Collectors.toSet());
+ PatternAggregationTracker patternAggregationTracker =
+ new PatternAggregationTracker(
+ labels, aggregationPointer.getSetDescriptor().isRunning());
+
+ PatternAggregator measurePatternAggregator =
+ buildPatternAggregator(
+ resolvedFunction, arguments, valueChannels,
patternAggregationTracker);
+
+ measurePatternAggregatorBuilder.add(measurePatternAggregator);
+
+ valueAccessors.add(new
PhysicalAggregationPointer(matchAggregationIndex));
+ matchAggregationIndex++;
}
}
+ measurePatternAggregators = measurePatternAggregatorBuilder.build();
+
// transform the symbolic expression tree in the logical planning stage
into a parametric
// expression tree
- Computation computation =
-
Computation.ComputationParser.parse(expressionAndValuePointers.getExpression());
+ Computation computation =
Computation.ComputationParser.parse(expressionAndValuePointers);
// construct a `PatternExpressionComputation` object, where
valueAccessors is a parameter
- // list
- // and computation is a parametric expression tree, encapsulating the
computation logic
+ // list and computation is a parametric expression tree, encapsulating
the computation logic.
PatternExpressionComputation measureComputation =
- new PatternExpressionComputation(valueAccessors, computation);
+ new PatternExpressionComputation(valueAccessors, computation,
measurePatternAggregators);
measureComputationsBuilder.add(measureComputation);
}
@@ -3415,8 +3537,9 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
node.getRowsPerMatch(),
node.getSkipToPosition(),
skipToNavigation,
- new Matcher(program),
+ new Matcher(program, variableRecognizerAggregators),
evaluationsBuilder.build(),
+ measurePatternAggregators,
measureComputationsBuilder.build(),
labelNames);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
index fad9bc32d2f..13923444e0c 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
@@ -25,6 +25,7 @@ import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.execution.warnings.WarningCollector;
import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis.Range;
+import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.AggregationDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.ClassifierDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.MatchNumberDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.Navigation;
@@ -100,6 +101,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundExceptio
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
@@ -111,20 +113,24 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
+import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
+import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Iterators.getOnlyElement;
import static java.util.Collections.unmodifiableMap;
@@ -968,13 +974,17 @@ public class ExpressionAnalyzer {
FunctionCall node, StackableAstVisitorContext<Context> context) {
String functionName = node.getName().getSuffix();
boolean isAggregation = metadata.isAggregationFunction(session,
functionName, accessControl);
+ boolean isRowPatternCount =
+ context.getContext().isPatternRecognition()
+ && isAggregation
+ && node.getName().getSuffix().equalsIgnoreCase("count");
// argument of the form `label.*` is only allowed for row pattern count
function
node.getArguments().stream()
.filter(DereferenceExpression::isQualifiedAllFieldsReference)
.findAny()
.ifPresent(
allRowsReference -> {
- if (node.getArguments().size() > 1) {
+ if (!isRowPatternCount || node.getArguments().size() > 1) {
throw new SemanticException(
"label.* syntax is only supported as the only argument
of row pattern count function");
}
@@ -993,7 +1003,12 @@ public class ExpressionAnalyzer {
}
if (context.getContext().isPatternRecognition()) {
- if (isPatternRecognitionFunction(node)) {
+ if (isAggregation) {
+ if (node.isDistinct()) {
+ throw new SemanticException(
+ "Cannot use DISTINCT with aggregate function in pattern
recognition context");
+ }
+ } else if (isPatternRecognitionFunction(node)) {
validatePatternRecognitionFunction(node);
String name = node.getName().getSuffix().toUpperCase(ENGLISH);
@@ -1011,12 +1026,6 @@ public class ExpressionAnalyzer {
default:
throw new SemanticException("unexpected pattern recognition
function " + name);
}
-
- } else if (isAggregation) {
- if (node.isDistinct()) {
- throw new SemanticException(
- "Cannot use DISTINCT with aggregate function in pattern
recognition context");
- }
}
}
@@ -1090,6 +1099,12 @@ public class ExpressionAnalyzer {
true,
functionNullability);
resolvedFunctions.put(NodeRef.of(node), resolvedFunction);
+
+ // must run after arguments are processed and labels are recorded
+ if (context.getContext().isPatternRecognition() && isAggregation) {
+ analyzePatternAggregation(node, resolvedFunction);
+ }
+
return setExpressionType(node, type);
}
@@ -1375,6 +1390,95 @@ public class ExpressionAnalyzer {
return identifier.getCanonicalValue();
}
+    /**
+     * Determines the row pattern label that a single argument of an aggregation call is bound to.
+     *
+     * <p>Both sources of labels are collected from the argument at {@code argumentIndex} only, so
+     * they are checked against the same expression tree:
+     *
+     * <ul>
+     *   <li>labels recorded in {@code labels} for sub-expressions;
+     *   <li>labels named by {@code CLASSIFIER(...)} calls.
+     * </ul>
+     *
+     * <p>Mismatches between different arguments are reported by the caller instead.
+     *
+     * @param node the aggregation call being analyzed
+     * @param argumentIndex index of the argument to inspect
+     * @return {@link ArgumentLabel#noLabel()} if the argument references no label, {@link
+     *     ArgumentLabel#universalLabel()} if the only label found is empty, otherwise the explicit
+     *     label
+     * @throws SemanticException if the argument references more than one distinct label
+     */
+    private ArgumentLabel validateLabelConsistency(FunctionCall node, int argumentIndex) {
+      List<Expression> argument = ImmutableList.of(node.getArguments().get(argumentIndex));
+
+      // labels recorded during analysis for sub-expressions of this argument
+      Set<Optional<String>> referenceLabels =
+          extractExpressions(argument, Expression.class).stream()
+              .map(child -> labels.get(NodeRef.of(child)))
+              .filter(Objects::nonNull)
+              .collect(toImmutableSet());
+
+      // labels named by CLASSIFIER(...) calls in this argument; CLASSIFIER() without an argument
+      // yields an empty label
+      Set<Optional<String>> classifierLabels =
+          extractExpressions(argument, FunctionCall.class).stream()
+              .filter(this::isClassifierFunction)
+              .map(
+                  functionCall ->
+                      functionCall.getArguments().stream()
+                          .findFirst()
+                          .map(classifierArgument -> label((Identifier) classifierArgument)))
+              .collect(toImmutableSet());
+
+      Set<Optional<String>> allLabels =
+          ImmutableSet.<Optional<String>>builder()
+              .addAll(referenceLabels)
+              .addAll(classifierLabels)
+              .build();
+
+      if (allLabels.isEmpty()) {
+        return ArgumentLabel.noLabel();
+      }
+
+      if (allLabels.size() > 1) {
+        String name = node.getName().getSuffix();
+        throw new SemanticException(
+            String.format("All labels and classifiers inside the call to '%s' must match", name));
+      }
+
+      Optional<String> label = Iterables.getOnlyElement(allLabels);
+      return label.map(ArgumentLabel::explicitLabel).orElseGet(ArgumentLabel::universalLabel);
+    }
+
+    /**
+     * Computes the explicit labels an aggregation call in a pattern recognition context applies to.
+     *
+     * @param node the aggregation call
+     * @return the single explicit label shared by all label-bound arguments, or an empty set when
+     *     the call has no arguments or no argument is bound to an explicit label
+     * @throws SemanticException if the arguments are bound to different labels
+     */
+    private Set<String> analyzeAggregationLabels(FunctionCall node) {
+      List<Expression> arguments = node.getArguments();
+      if (arguments.isEmpty()) {
+        return ImmutableSet.of();
+      }
+
+      // distinct bindings of the arguments that reference any label (explicit or empty)
+      Set<Optional<String>> boundLabels = new HashSet<>();
+      for (int index = 0; index < arguments.size(); index++) {
+        ArgumentLabel binding = validateLabelConsistency(node, index);
+        if (binding.hasLabel()) {
+          boundLabels.add(binding.getLabel());
+        }
+      }
+      if (boundLabels.size() > 1) {
+        throw new SemanticException(
+            "All aggregate function arguments must apply to rows matched with the same label");
+      }
+
+      // an empty binding contributes no explicit label
+      Set<String> explicitLabels = new HashSet<>();
+      for (Optional<String> boundLabel : boundLabels) {
+        boundLabel.ifPresent(explicitLabels::add);
+      }
+      return explicitLabels;
+    }
+
+    /**
+     * Validates an aggregation call found in a pattern recognition context and registers it in
+     * {@code patternRecognitionInputs} as an {@link AggregationDescriptor}.
+     *
+     * @param node the aggregation call
+     * @param function the resolved aggregation function
+     */
+    private void analyzePatternAggregation(FunctionCall node, ResolvedFunction function) {
+      checkNoNestedAggregations(node);
+      checkNoNestedNavigations(node);
+      Set<String> aggregationLabels = analyzeAggregationLabels(node);
+
+      // one traversal of the arguments serves both the MATCH_NUMBER and CLASSIFIER lookups
+      List<FunctionCall> nestedCalls =
+          extractExpressions(node.getArguments(), FunctionCall.class).stream()
+              .collect(toImmutableList());
+      List<FunctionCall> matchNumberCalls =
+          nestedCalls.stream().filter(this::isMatchNumberFunction).collect(toImmutableList());
+      List<FunctionCall> classifierCalls =
+          nestedCalls.stream().filter(this::isClassifierFunction).collect(toImmutableList());
+
+      AggregationDescriptor descriptor =
+          new AggregationDescriptor(
+              function,
+              node.getArguments(),
+              mapProcessingMode(node.getProcessingMode()),
+              aggregationLabels,
+              matchNumberCalls,
+              classifierCalls);
+      patternRecognitionInputs.add(new PatternFunctionAnalysis(node, descriptor));
+    }
+
private void checkNoNestedAggregations(FunctionCall node) {
extractExpressions(node.getArguments(), FunctionCall.class).stream()
.filter(
@@ -2191,6 +2295,10 @@ public class ExpressionAnalyzer {
}
}
+ /**
+ * Checks if the given function call is a specific function for pattern
recognition, excluding
+ * aggregation functions.
+ */
public static boolean isPatternRecognitionFunction(FunctionCall node) {
QualifiedName qualifiedName = node.getName();
if (qualifiedName.getParts().size() > 1) {
@@ -2476,4 +2584,35 @@ public class ExpressionAnalyzer {
return column;
}
}
+
+  /**
+   * The label binding of a single aggregation argument in a row pattern recognition context.
+   *
+   * <p>An argument can be in one of three states:
+   *
+   * <ul>
+   *   <li>it references no label;
+   *   <li>it is bound with an empty label (the universal binding);
+   *   <li>it is bound to one explicit label.
+   * </ul>
+   */
+  private static class ArgumentLabel {
+    private static final ArgumentLabel NO_LABEL = new ArgumentLabel(false, Optional.empty());
+    private static final ArgumentLabel UNIVERSAL = new ArgumentLabel(true, Optional.empty());
+
+    // whether the argument is bound with a label at all
+    private final boolean bound;
+    // the explicit label name; empty for the universal binding and for unbound arguments
+    private final Optional<String> labelName;
+
+    private ArgumentLabel(boolean bound, Optional<String> labelName) {
+      this.bound = bound;
+      this.labelName = labelName;
+    }
+
+    public static ArgumentLabel noLabel() {
+      return NO_LABEL;
+    }
+
+    public static ArgumentLabel universalLabel() {
+      return UNIVERSAL;
+    }
+
+    public static ArgumentLabel explicitLabel(String label) {
+      return new ArgumentLabel(true, Optional.of(label));
+    }
+
+    public boolean hasLabel() {
+      return bound;
+    }
+
+    /** Returns the bound label; must only be called when {@link #hasLabel()} is true. */
+    public Optional<String> getLabel() {
+      checkState(bound, "no label available");
+      return labelName;
+    }
+  }
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/RelationPlanner.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/RelationPlanner.java
index 59f0c40f8d5..fed7e954d40 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/RelationPlanner.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/RelationPlanner.java
@@ -40,6 +40,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Field;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis;
+import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.AggregationDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.ClassifierDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.MatchNumberDescriptor;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.PatternRecognitionAnalysis.Navigation;
@@ -65,6 +66,8 @@ import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.SkipToPositi
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableFunctionNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableScanNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TreeDeviceViewScanNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationLabelSet;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ClassifierValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers.Assignment;
@@ -81,6 +84,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AstVisitor;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CoalesceExpression;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Delete;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Except;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Identifier;
@@ -158,6 +162,8 @@ import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Join.Type.
import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation.RowsPerMatch.ONE;
import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SkipTo.Position.PAST_LAST;
import static
org.apache.iotdb.db.queryengine.plan.relational.utils.NodeUtils.getSortItemsFromOrderBy;
+import static org.apache.tsfile.read.common.type.LongType.INT64;
+import static org.apache.tsfile.read.common.type.StringType.STRING;
public class RelationPlanner extends AstVisitor<RelationPlan, Void> {
@@ -875,7 +881,7 @@ public class RelationPlanner extends
AstVisitor<RelationPlan, Void> {
case ALL_WITH_UNMATCHED:
return RowsPerMatch.ALL_WITH_UNMATCHED;
default:
- throw new IllegalArgumentException("Unexpected value: " +
rowsPerMatch);
+ throw new SemanticException("Unexpected rows per match: " +
rowsPerMatch);
}
}
@@ -992,8 +998,53 @@ public class RelationPlanner extends
AstVisitor<RelationPlan, Void> {
new ScalarValuePointer(
planValuePointer(descriptor.getLabel(),
descriptor.getNavigation(), subsets),
Symbol.from(translations.rewrite(accessor.getExpression())));
+ } else if (accessor.getDescriptor() instanceof AggregationDescriptor) {
+ AggregationDescriptor descriptor = (AggregationDescriptor)
accessor.getDescriptor();
+
+ Map<NodeRef<Expression>, Symbol> mappings = new HashMap<>();
+
+ Optional<Symbol> matchNumberSymbol = Optional.empty();
+ if (!descriptor.getMatchNumberCalls().isEmpty()) {
+ Symbol symbol = symbolAllocator.newSymbol("match_number", INT64);
+ for (Expression call : descriptor.getMatchNumberCalls()) {
+ mappings.put(NodeRef.of(call), symbol);
+ }
+ matchNumberSymbol = Optional.of(symbol);
+ }
+
+ Optional<Symbol> classifierSymbol = Optional.empty();
+ if (!descriptor.getClassifierCalls().isEmpty()) {
+ Symbol symbol = symbolAllocator.newSymbol("classifier", STRING);
+
+ for (Expression call : descriptor.getClassifierCalls()) {
+ mappings.put(NodeRef.of(call), symbol);
+ }
+ classifierSymbol = Optional.of(symbol);
+ }
+
+ TranslationMap argumentTranslation =
translations.withAdditionalIdentityMappings(mappings);
+
+ Set<IrLabel> labels =
+ descriptor.getLabels().stream()
+ .flatMap(label -> planLabels(Optional.of(label),
subsets).stream())
+ .collect(Collectors.toSet());
+
+ pointer =
+ new AggregationValuePointer(
+ descriptor.getFunction(),
+ new AggregationLabelSet(labels, descriptor.getMode() ==
RUNNING),
+ descriptor.getArguments().stream()
+ .filter(
+ argument ->
!DereferenceExpression.isQualifiedAllFieldsReference(argument))
+ .map(
+ argument ->
+ coerceIfNecessary(
+ analysis, argument,
argumentTranslation.rewrite(argument)))
+ .collect(Collectors.toList()),
+ classifierSymbol,
+ matchNumberSymbol);
} else {
- throw new IllegalArgumentException(
+ throw new SemanticException(
"Unexpected descriptor type: " +
accessor.getDescriptor().getClass().getName());
}
@@ -1039,7 +1090,7 @@ public class RelationPlanner extends
AstVisitor<RelationPlan, Void> {
case LAST:
return SkipToPosition.LAST;
default:
- throw new IllegalArgumentException("Unexpected value: " + position);
+ throw new SemanticException("Unexpected skip to position: " +
position);
}
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SymbolMapper.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SymbolMapper.java
index db198c8ce70..08f08deef5a 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SymbolMapper.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SymbolMapper.java
@@ -35,6 +35,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.Measure;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.PatternRecognitionNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.AggregationValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ClassifierValuePointer;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.ExpressionAndValuePointers;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern.IrLabel;
@@ -349,13 +350,47 @@ public class SymbolMapper {
expressionAndValuePointers.getAssignments()) {
ValuePointer newPointer;
if (assignment.getValuePointer() instanceof ClassifierValuePointer) {
- newPointer = (ClassifierValuePointer) assignment.getValuePointer();
+ newPointer = assignment.getValuePointer();
} else if (assignment.getValuePointer() instanceof
MatchNumberValuePointer) {
- newPointer = (MatchNumberValuePointer) assignment.getValuePointer();
+ newPointer = assignment.getValuePointer();
} else if (assignment.getValuePointer() instanceof ScalarValuePointer) {
ScalarValuePointer pointer = (ScalarValuePointer)
assignment.getValuePointer();
newPointer =
new ScalarValuePointer(pointer.getLogicalIndexPointer(),
map(pointer.getInputSymbol()));
+ } else if (assignment.getValuePointer() instanceof
AggregationValuePointer) {
+ AggregationValuePointer pointer = (AggregationValuePointer)
assignment.getValuePointer();
+ List<Expression> newArguments =
+ pointer.getArguments().stream()
+ .map(
+ expression ->
+ ExpressionTreeRewriter.rewriteWith(
+ new ExpressionRewriter<Void>() {
+ @Override
+ public Expression rewriteSymbolReference(
+ SymbolReference node,
+ Void context,
+ ExpressionTreeRewriter<Void> treeRewriter) {
+ if (pointer.getClassifierSymbol().isPresent()
+ && Symbol.from(node)
+
.equals(pointer.getClassifierSymbol().get())
+ ||
pointer.getMatchNumberSymbol().isPresent()
+ && Symbol.from(node)
+
.equals(pointer.getMatchNumberSymbol().get())) {
+ return node;
+ }
+ return map(node);
+ }
+ },
+ expression))
+ .collect(toImmutableList());
+
+ newPointer =
+ new AggregationValuePointer(
+ pointer.getFunction(),
+ pointer.getSetDescriptor(),
+ newArguments,
+ pointer.getClassifierSymbol(),
+ pointer.getMatchNumberSymbol());
} else {
throw new IllegalArgumentException(
"Unsupported ValuePointer type: " +
assignment.getValuePointer().getClass().getName());
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationLabelSet.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationLabelSet.java
new file mode 100644
index 00000000000..f002f3d0c1a
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationLabelSet.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.Objects;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Describes which matched rows an aggregation inside a row pattern recognition expression is
+ * computed over, and whether the aggregation uses RUNNING semantics.
+ */
+public class AggregationLabelSet {
+  // A set of labels to identify all rows to aggregate over:
+  // avg(A.price) => this is an aggregation over rows with label A, so labels = {A}
+  // avg(Union.price) => this is an aggregation over rows matching a union variable Union, so for
+  // SUBSET Union = (A, B, C), labels = {A, B, C}
+  // avg(price) => this is an aggregation over "universal pattern variable", which is effectively
+  // over all rows, no matter the assigned labels. In such case labels = {}
+  private final Set<IrLabel> labels;
+
+  // true when the aggregation uses RUNNING semantics
+  private final boolean running;
+
+  /**
+   * @param labels labels whose rows are aggregated (empty for the universal pattern variable);
+   *     copied defensively, must not be null
+   * @param running whether the aggregation uses RUNNING semantics
+   */
+  public AggregationLabelSet(Set<IrLabel> labels, boolean running) {
+    requireNonNull(labels, "labels is null");
+    // Unmodifiable, insertion-ordered copy. The set takes part in equals/hashCode and defines the
+    // serialization order, so later mutation through the caller's reference must not affect it.
+    this.labels = Collections.unmodifiableSet(new LinkedHashSet<>(labels));
+    this.running = running;
+  }
+
+  /** Returns an unmodifiable view of the labels; empty means the universal pattern variable. */
+  public Set<IrLabel> getLabels() {
+    return labels;
+  }
+
+  public boolean isRunning() {
+    return running;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    AggregationLabelSet that = (AggregationLabelSet) o;
+    return labels.equals(that.labels) && running == that.running;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(labels, running);
+  }
+
+  @Override
+  public String toString() {
+    return "AggregationLabelSet{labels=" + labels + ", running=" + running + "}";
+  }
+
+  /**
+   * Writes the label count, each label, then the running flag as a single byte (1 or 0).
+   *
+   * @param set the label set to write
+   * @param byteBuffer destination buffer
+   */
+  public static void serialize(AggregationLabelSet set, ByteBuffer byteBuffer) {
+    byteBuffer.putInt(set.labels.size());
+    for (IrLabel label : set.labels) {
+      IrLabel.serialize(label, byteBuffer);
+    }
+    byteBuffer.put((byte) (set.running ? 1 : 0));
+  }
+
+  /**
+   * Stream counterpart of {@link #serialize(AggregationLabelSet, ByteBuffer)}; {@code
+   * writeBoolean} also writes the flag as a single 1/0 byte.
+   *
+   * @param set the label set to write
+   * @param stream destination stream
+   * @throws IOException if writing to the stream fails
+   */
+  public static void serialize(AggregationLabelSet set, DataOutputStream stream)
+      throws IOException {
+    stream.writeInt(set.labels.size());
+    for (IrLabel label : set.labels) {
+      IrLabel.serialize(label, stream);
+    }
+    stream.writeBoolean(set.running);
+  }
+
+  /**
+   * Reads the format produced by {@code serialize}.
+   *
+   * @param byteBuffer source buffer, positioned at the start of a serialized label set
+   * @return the deserialized label set
+   */
+  public static AggregationLabelSet deserialize(ByteBuffer byteBuffer) {
+    int size = byteBuffer.getInt();
+    Set<IrLabel> labels = new HashSet<>(size);
+    for (int i = 0; i < size; i++) {
+      labels.add(IrLabel.deserialize(byteBuffer));
+    }
+    // any non-zero byte means true, matching DataInput#readBoolean
+    boolean running = byteBuffer.get() != 0;
+    return new AggregationLabelSet(labels, running);
+  }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationValuePointer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationValuePointer.java
new file mode 100644
index 00000000000..3636e2673db
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/AggregationValuePointer.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.rowpattern;
+
+import
org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+
+public final class AggregationValuePointer implements ValuePointer {
+ private final ResolvedFunction function;
+ private final AggregationLabelSet setDescriptor;
+ private final List<Expression> arguments;
+ private final Optional<Symbol> classifierSymbol;
+ private final Optional<Symbol> matchNumberSymbol;
+
+ public AggregationValuePointer(
+ ResolvedFunction function,
+ AggregationLabelSet setDescriptor,
+ List<Expression> arguments,
+ Optional<Symbol> classifierSymbol,
+ Optional<Symbol> matchNumberSymbol) {
+ this.function = requireNonNull(function, "function is null");
+ this.setDescriptor = requireNonNull(setDescriptor, "setDescriptor is
null");
+ this.arguments = requireNonNull(arguments, "arguments is null");
+ this.classifierSymbol = requireNonNull(classifierSymbol, "classifierSymbol
is null");
+ this.matchNumberSymbol = requireNonNull(matchNumberSymbol,
"matchNumberSymbol is null");
+ }
+
+ public ResolvedFunction getFunction() {
+ return function;
+ }
+
+ public AggregationLabelSet getSetDescriptor() {
+ return setDescriptor;
+ }
+
+ public List<Expression> getArguments() {
+ return arguments;
+ }
+
+ public Optional<Symbol> getClassifierSymbol() {
+ return classifierSymbol;
+ }
+
+ public Optional<Symbol> getMatchNumberSymbol() {
+ return matchNumberSymbol;
+ }
+
+ public List<Symbol> getInputSymbols() {
+ return arguments.stream()
+ .map(SymbolsExtractor::extractAll)
+ .flatMap(Collection::stream)
+ .filter(
+ symbol ->
+ (!classifierSymbol.isPresent() ||
!classifierSymbol.get().equals(symbol))
+ && (!matchNumberSymbol.isPresent() ||
!matchNumberSymbol.get().equals(symbol)))
+ .collect(toImmutableList());
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if ((obj == null) || (getClass() != obj.getClass())) {
+ return false;
+ }
+ AggregationValuePointer o = (AggregationValuePointer) obj;
+ return Objects.equals(function, o.function)
+ && Objects.equals(setDescriptor, o.setDescriptor)
+ && Objects.equals(arguments, o.arguments)
+ && Objects.equals(classifierSymbol, o.classifierSymbol)
+ && Objects.equals(matchNumberSymbol, o.matchNumberSymbol);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(function, setDescriptor, arguments, classifierSymbol,
matchNumberSymbol);
+ }
+
+ public static void serialize(AggregationValuePointer pointer, ByteBuffer
byteBuffer) {
+ pointer.function.serialize(byteBuffer);
+ AggregationLabelSet.serialize(pointer.setDescriptor, byteBuffer);
+ byteBuffer.putInt(pointer.arguments.size());
+ for (Expression arg : pointer.arguments) {
+ Expression.serialize(arg, byteBuffer);
+ }
+ byteBuffer.put(pointer.classifierSymbol.isPresent() ? (byte) 1 : (byte) 0);
+ if (pointer.classifierSymbol.isPresent()) {
+ Symbol.serialize(pointer.classifierSymbol.get(), byteBuffer);
+ }
+ byteBuffer.put(pointer.matchNumberSymbol.isPresent() ? (byte) 1 : (byte)
0);
+ if (pointer.matchNumberSymbol.isPresent()) {
+ Symbol.serialize(pointer.matchNumberSymbol.get(), byteBuffer);
+ }
+ }
+
+ public static void serialize(AggregationValuePointer pointer,
DataOutputStream stream)
+ throws IOException {
+ pointer.function.serialize(stream);
+ AggregationLabelSet.serialize(pointer.setDescriptor, stream);
+ stream.writeInt(pointer.arguments.size());
+ for (Expression arg : pointer.arguments) {
+ Expression.serialize(arg, stream);
+ }
+ stream.writeBoolean(pointer.classifierSymbol.isPresent());
+ if (pointer.classifierSymbol.isPresent()) {
+ Symbol.serialize(pointer.classifierSymbol.get(), stream);
+ }
+ stream.writeBoolean(pointer.matchNumberSymbol.isPresent());
+ if (pointer.matchNumberSymbol.isPresent()) {
+ Symbol.serialize(pointer.matchNumberSymbol.get(), stream);
+ }
+ }
+
+ public static AggregationValuePointer deserialize(ByteBuffer byteBuffer) {
+ ResolvedFunction function = ResolvedFunction.deserialize(byteBuffer);
+ AggregationLabelSet setDescriptor =
AggregationLabelSet.deserialize(byteBuffer);
+ int argCount = byteBuffer.getInt();
+ List<Expression> arguments = new ArrayList<>(argCount);
+ for (int i = 0; i < argCount; i++) {
+ arguments.add(Expression.deserialize(byteBuffer));
+ }
+ Optional<Symbol> classifierSymbol =
+ byteBuffer.get() == 1 ? Optional.of(Symbol.deserialize(byteBuffer)) :
Optional.empty();
+ Optional<Symbol> matchNumberSymbol =
+ byteBuffer.get() == 1 ? Optional.of(Symbol.deserialize(byteBuffer)) :
Optional.empty();
+ return new AggregationValuePointer(
+ function, setDescriptor, arguments, classifierSymbol,
matchNumberSymbol);
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/ExpressionAndValuePointers.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/ExpressionAndValuePointers.java
index 37644973dbf..288fae8fca3 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/ExpressionAndValuePointers.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/rowpattern/ExpressionAndValuePointers.java
@@ -173,6 +173,8 @@ public class ExpressionAndValuePointers {
ReadWriteIOUtils.write(1, byteBuffer);
} else if (assignment.valuePointer instanceof ScalarValuePointer) {
ReadWriteIOUtils.write(2, byteBuffer);
+ } else if (assignment.valuePointer instanceof AggregationValuePointer) {
+ ReadWriteIOUtils.write(3, byteBuffer);
} else {
throw new IllegalArgumentException("Unknown ValuePointer type");
}
@@ -185,6 +187,11 @@ public class ExpressionAndValuePointers {
(ClassifierValuePointer) assignment.valuePointer, byteBuffer);
} else if (assignment.valuePointer instanceof ScalarValuePointer) {
ScalarValuePointer.serialize((ScalarValuePointer)
assignment.valuePointer, byteBuffer);
+ } else if (assignment.valuePointer instanceof AggregationValuePointer) {
+ AggregationValuePointer.serialize(
+ (AggregationValuePointer) assignment.valuePointer, byteBuffer);
+ } else {
+ throw new IllegalArgumentException("Unknown ValuePointer type");
}
}
@@ -198,6 +205,8 @@ public class ExpressionAndValuePointers {
ReadWriteIOUtils.write(1, stream);
} else if (assignment.valuePointer instanceof ScalarValuePointer) {
ReadWriteIOUtils.write(2, stream);
+ } else if (assignment.valuePointer instanceof AggregationValuePointer) {
+ ReadWriteIOUtils.write(3, stream);
} else {
throw new IllegalArgumentException("Unknown ValuePointer type");
}
@@ -209,6 +218,11 @@ public class ExpressionAndValuePointers {
ClassifierValuePointer.serialize((ClassifierValuePointer)
assignment.valuePointer, stream);
} else if (assignment.valuePointer instanceof ScalarValuePointer) {
ScalarValuePointer.serialize((ScalarValuePointer)
assignment.valuePointer, stream);
+ } else if (assignment.valuePointer instanceof AggregationValuePointer) {
+ AggregationValuePointer.serialize(
+ (AggregationValuePointer) assignment.valuePointer, stream);
+ } else {
+ throw new IllegalArgumentException("Unknown ValuePointer type");
}
}
@@ -224,6 +238,8 @@ public class ExpressionAndValuePointers {
valuePointer = ClassifierValuePointer.deserialize(byteBuffer);
} else if (type == 2) {
valuePointer = ScalarValuePointer.deserialize(byteBuffer);
+ } else if (type == 3) {
+ valuePointer = AggregationValuePointer.deserialize(byteBuffer);
} else {
throw new IllegalArgumentException("Unknown ValuePointer type");
}
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/MatcherTest.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/MatcherTest.java
index 3b468eb2816..5c6c6da62cd 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/MatcherTest.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/rowpattern/MatcherTest.java
@@ -157,7 +157,7 @@ public class MatcherTest {
private static MatchResult match(IrRowPattern pattern, String input) {
Program program = IrRowPatternToProgramRewriter.rewrite(pattern,
LABEL_MAPPING);
- Matcher matcher = new Matcher(program);
+ Matcher matcher = new Matcher(program, ImmutableList.of());
int[] mappedInput = new int[input.length()];
for (int i = 0; i < input.length(); i++) {
@@ -195,7 +195,7 @@ public class MatcherTest {
* the label based on the definition of the label in the DEFINE clause.
*/
@Override
- public boolean evaluateLabel(ArrayView matchedLabels) {
+ public boolean evaluateLabel(ArrayView matchedLabels, PatternAggregator[]
patternAggregators) {
int position = matchedLabels.length() - 1;
return input[position] == matchedLabels.get(position);
}