This is an automated email from the ASF dual-hosted git repository. caogaofei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push: new e7ce326ca07 [opt](query) Optimize count(1), count(constant expression) to count(*) e7ce326ca07 is described below commit e7ce326ca07726f570efd7cb37deb51220841d36 Author: Beyyes <cgf1...@foxmail.com> AuthorDate: Wed Apr 2 17:31:41 2025 +0800 [opt](query) Optimize count(1), count(constant expression) to count(*) --- .../db/it/IoTDBMultiTAGsWithAttributesTableIT.java | 122 ++++++++++++++++- .../it/query/recent/IoTDBTableAggregationIT.java | 56 ++++++++ .../relational/analyzer/ExpressionAnalyzer.java | 2 +- .../plan/relational/function/FunctionId.java | 2 + .../plan/relational/metadata/ResolvedFunction.java | 1 - .../plan/relational/planner/ir/IrUtils.java | 7 + .../iterative/rule/SimplifyCountOverConstant.java | 145 +++++++++++++++++++++ .../optimizations/LogicalOptimizeFactory.java | 5 +- ...mQuantifiedComparisonApplyToCorrelatedJoin.java | 13 +- .../relational/planner/optimizations/Util.java | 2 +- .../plan/relational/analyzer/AggregationTest.java | 45 +++++++ 11 files changed, 381 insertions(+), 19 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java index 17e58550b02..caeba55e32e 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java @@ -537,6 +537,16 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(num) as count_num, count(1) as count_star, avg(num) as avg_num, count(num) as count_num,\n" + + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) as count_device,\n" + + "count(attr1) as count_attr1, count(device) as 
count_device, \n" + + "round(avg(floatnum)) as avg_floatnum, count(date) as count_date, " + + "count(time) as count_time, count(1) as count_star " + + "from table0", + expectedHeader, + retArray, + DATABASE_NAME); retArray = new String[] { @@ -553,6 +563,17 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(num) as count_num, count(1) as count_star, avg(num) as avg_num, count(num) as count_num,\n" + + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) as count_device,\n" + + "count(attr1) as count_attr1, count(device) as count_device, \n" + + "round(avg(floatnum)) as avg_floatnum, count(date) as count_date, " + + "count(time) as count_time, count(1) as count_star " + + "from table0 " + + "where time<200 and num>1 ", + expectedHeader, + retArray, + DATABASE_NAME); // TAG has null value expectedHeader = @@ -595,7 +616,7 @@ public class IoTDBMultiTAGsWithAttributesTableIT { String[] expectedHeader = new String[] {"_col0"}; String[] retArray = new String[] {"30,"}; - String sql = "SELECT count(num+1) from table0"; + sql = "SELECT count(num+1) from table0"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); } @@ -603,7 +624,9 @@ public class IoTDBMultiTAGsWithAttributesTableIT { public void countStarTest() { expectedHeader = new String[] {"_col0", "_col1"}; retArray = new String[] {"1,1,"}; - String sql = "select count(*),count(t1) from (select avg(num+1) as t1 from table0)"; + sql = "select count(*),count(t1) from (select avg(num+1) as t1 from table0)"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = "select count(1),count(t1) from (select avg(num+1) as t1 from table0)"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); expectedHeader = new String[] {"count_star"}; @@ -613,6 +636,8 @@ public class IoTDBMultiTAGsWithAttributesTableIT { }; tableResultSetEqualTest( "select count(*) as count_star 
from table0", expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(1) as count_star from table0", expectedHeader, retArray, DATABASE_NAME); retArray = new String[] { @@ -623,6 +648,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(1) as count_star from (select count(1) from table0)", + expectedHeader, + retArray, + DATABASE_NAME); retArray = new String[] { @@ -633,6 +663,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(1) as count_star from (select count(1), avg(num) from table0)", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"sum"}; retArray = @@ -644,6 +679,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count_star + avg_num as sum from (select count(1) as count_star, avg(num) as avg_num from table0)", + expectedHeader, + retArray, + DATABASE_NAME); // TODO select count(*),count(t1) from (select avg(num+1) as t1 from table0) where time < 0 @@ -657,6 +697,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count(1) from (select device from table0 group by device, level)", + expectedHeader, + retArray, + DATABASE_NAME); } @Test @@ -693,6 +738,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num " + "from table0 group by device,level order by device, level"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, " + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num " + + "from table0 group by device,level order by device, level"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); expectedHeader = new String[] { @@ -716,6 +767,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num " + "from table0 where device='d1' and level='l1' group by device order by device"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, " + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num " + + "from table0 where device='d1' and level='l1' group by device order by device"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); expectedHeader = new String[] {"device", "level"}; retArray = @@ -816,6 +873,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + "from table0 group by 3, device, level order by device, level, bin"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, date_bin(1y, time) as bin," + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + + "from table0 group by 3, device, level order by device, level, bin"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); retArray = new String[] { @@ -854,6 +917,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + "from 
table0 group by 3, device, level order by device, level, bin"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, date_bin(1d, time) as bin," + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + + "from table0 group by 3, device, level order by device, level, bin"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); retArray = new String[] { @@ -894,6 +963,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + "from table0 group by 3, device, level order by device, level, bin"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, date_bin(1s, time) as bin," + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + + "from table0 group by 3, device, level order by device, level, bin"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); // only group by date_bin expectedHeader = new String[] {"bin"}; @@ -929,6 +1004,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { "1971-04-26T17:46:40.000Z,1,1,1,0,1,1,1,12.0," }; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select date_bin(1s, time) as bin," + + "count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, avg(num) as avg_num " + + "from table0 where device='d1' and level='l2' group by 1"; + tableResultSetEqualTest(sql, expectedHeader, 
retArray, DATABASE_NAME); // flush multi times, generated multi tsfile expectedHeader = buildHeaders(1); @@ -975,6 +1056,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); retArray = new String[] { @@ -985,6 +1071,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); expectedHeader = new String[] { @@ -1006,12 +1097,22 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32 group by device, level order by device, level"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, count(num) as count_num, count(1) 
as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32 group by device, level order by device, level"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); retArray = new String[] {"d1,l2,1,1,1,0,1,1,1,12.0,12.0,", "d2,l2,1,1,1,0,1,0,1,12.0,12.0,"}; sql = "select device, level, count(num) as count_num, count(*) as count_star, count(device) as count_device, count(date) as count_date, " + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, level"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, count(num) as count_num, count(1) as count_star, count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, level"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); expectedHeader = new String[] { @@ -1035,6 +1136,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32 group by 3, device, level order by device, level"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, date_bin(1d, time) as bin, count(num) as count_num, count(1) as count_star, " + + "count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32 group by 3, device, level order by device, level"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); retArray = new String[] { "d1,l2,1971-04-26T00:00:00.000Z,1,1,1,0,1,1,1,12.0,12.0,", @@ -1046,16 +1153,27 @@ public class IoTDBMultiTAGsWithAttributesTableIT { + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, level"; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select device, level, date_bin(1d, time) as bin, count(num) as count_num, count(1) as count_star, " + + "count(device) as count_device, count(date) as count_date, " + + "count(attr1) as count_attr1, count(attr2) as count_attr2, count(time) as count_time, sum(num) as sum_num," + + "avg(num) as avg_num from table0 where time=32 or time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, level"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); // queried device is not exist expectedHeader = buildHeaders(3); sql = "select count(*), count(num), sum(num) from table0 where device='d_not_exist'"; retArray = new String[] {"0,0,null,"}; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = "select count(1), count(num), sum(num) from table0 where device='d_not_exist'"; + tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); sql = "select count(*), count(num), sum(num) from table0 where device='d_not_exist1' or device='d_not_exist2'"; retArray = new String[] {"0,0,null,"}; tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME); + sql = + "select count(1), count(num), sum(num) from table0 where device='d_not_exist1' or device='d_not_exist2'"; + tableResultSetEqualTest(sql, 
expectedHeader, retArray, DATABASE_NAME); // no data in given time range (push-down) sql = "select count(*), count(num), sum(num) from table0 where time>2100-04-26T18:01:40.000"; diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java index dd47faebe97..70e5734e649 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java @@ -141,6 +141,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select count('a') from table1 where device_id = 'd01'", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"_col0", "end_time", "device_id", "_col3"}; retArray = @@ -156,6 +161,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select date_bin(5s, time), (date_bin(5s, time) + 5000) as end_time, device_id, count(1) from table1 where device_id = 'd01' group by 1,device_id", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"_col0", "province", "city", "region", "device_id", "_col5"}; retArray = @@ -230,6 +240,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select date_bin(5s, time),province,city,region,device_id, count(1) from table1 group by 1,2,3,4,5 order by 2,3,4,5,1", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] { @@ -389,6 +404,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select province,city,region,device_id,count(1) from table1 group by 1,2,3,4 order by 1,2,3,4", + expectedHeader, + 
retArray, + DATABASE_NAME); expectedHeader = new String[] {"province", "city", "region", "_col3"}; retArray = @@ -403,6 +423,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select province,city,region,count(1) from table1 group by 1,2,3 order by 1,2,3", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"province", "city", "_col2"}; retArray = @@ -414,6 +439,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select province,city,count(1) from table1 group by 1,2 order by 1,2", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"province", "_col1"}; retArray = @@ -425,6 +455,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select province,count(1) from table1 group by 1 order by 1", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"_col0"}; retArray = @@ -432,6 +467,7 @@ public class IoTDBTableAggregationIT { "64,", }; tableResultSetEqualTest("select count(*) from table1", expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest("select count(1) from table1", expectedHeader, retArray, DATABASE_NAME); } @Test @@ -3116,6 +3152,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select color,type, date_bin(5s, time), count(1) from table1 group by 1,2,3 order by 1,2,3", + expectedHeader, + retArray, + DATABASE_NAME); } @Test @@ -3489,6 +3530,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "select province,city,region,device_id,s7,count(1) from table1 group by 1,2,3,4,5 order by 1,2,3,4,5", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"province", "city", "region", "device_id", "s8", "_col5"}; retArray = @@ 
-3973,6 +4019,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "SELECT color, device_id FROM (SELECT date_bin(5s, time), color, device_id, avg(s4) as avg_s4 FROM table1 WHERE type='A' AND (time >= 2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) GROUP BY 1,2,3) WHERE avg_s4 > 1.0 GROUP BY color, device_id HAVING count(1) >= 2 ORDER BY color, device_id", + expectedHeader, + retArray, + DATABASE_NAME); expectedHeader = new String[] {"_col0", "city", "type", "_col3"}; retArray = @@ -3996,6 +4047,11 @@ public class IoTDBTableAggregationIT { expectedHeader, retArray, DATABASE_NAME); + tableResultSetEqualTest( + "SELECT date_bin(10s, five_seconds), city, type, sum(five_seconds_count) / 2 FROM (SELECT date_bin(5s, time) AS five_seconds, city, type, count(1) AS five_seconds_count FROM table1 WHERE (time >= 2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) AND device_id IS NOT NULL GROUP BY 1, city, type, device_id HAVING avg(s1) > 1) GROUP BY 1, city, type order by 2,3,1", + expectedHeader, + retArray, + DATABASE_NAME); } @Test diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java index 8bf9fcc2940..14a60f7cf3e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java @@ -802,7 +802,7 @@ public class ExpressionAnalyzer { ResolvedFunction resolvedFunction = new ResolvedFunction( new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, argumentTypes), - new FunctionId("noop"), + FunctionId.NOOP_FUNCTION_ID, isAggregation ? 
FunctionKind.AGGREGATE : FunctionKind.SCALAR, true, isAggregation diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java index cfc31ab471c..d814d8b4db6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java @@ -24,6 +24,8 @@ import java.util.Locale; import static java.util.Objects.requireNonNull; public class FunctionId { + public static final FunctionId NOOP_FUNCTION_ID = new FunctionId("noop"); + private final String id; public FunctionId(String id) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java index 432488612f2..e783a05fd08 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java @@ -38,7 +38,6 @@ public class ResolvedFunction { private final FunctionId functionId; private final FunctionKind functionKind; private final boolean deterministic; - private final FunctionNullability functionNullability; public ResolvedFunction( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java index d467431e474..1d175c51417 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java @@ -261,6 +261,12 @@ public final class IrUtils { return combineConjuncts(conjuncts); } + /** + * Returns whether expression is effectively literal. An effectively literal expression is a + * simple constant value, or null, in either {@link Literal} form, or other form returned by + * LiteralEncoder. In particular, other constant expressions like a deterministic function call + * with constant arguments are not considered effectively literal. + */ public static boolean isEffectivelyLiteral( Expression expression, PlannerContext plannerContext, SessionInfo session) { if (expression instanceof Literal) { @@ -271,6 +277,7 @@ public final class IrUtils { // a Cast(Literal(...)) can fail, so this requires verification && constantExpressionEvaluatesSuccessfully(plannerContext, session, expression); } + if (expression instanceof FunctionCall) { String functionName = ((FunctionCall) expression).getName().getSuffix(); if (functionName.equals("pi") || functionName.equals("e")) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java new file mode 100644 index 00000000000..e0ea8ffc7ec --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule; + +import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature; +import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId; +import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; +import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures; +import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern; + +import com.google.common.collect.ImmutableList; +import org.apache.tsfile.read.common.type.LongType; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static java.util.Objects.requireNonNull; +import static 
org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind.AGGREGATE; +import static org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability.getAggregationFunctionNullability; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter.evaluateConstantExpression; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.isEffectivelyLiteral; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture; +import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT; + +public class SimplifyCountOverConstant implements Rule<AggregationNode> { + private static final Capture<ProjectNode> CHILD = newCapture(); + + private static final Pattern<AggregationNode> PATTERN = + aggregation().with(source().matching(project().capturedAs(CHILD))); + + private final PlannerContext plannerContext; + + public SimplifyCountOverConstant(PlannerContext plannerContext) { + this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + } + + @Override + public Pattern<AggregationNode> getPattern() { + return PATTERN; + } + + @Override + public Result apply(AggregationNode parent, Captures captures, Context context) { + ProjectNode child = captures.get(CHILD); + + boolean changed = false; + Map<Symbol, AggregationNode.Aggregation> aggregations = null; + ResolvedFunction countWildcardFunction = null; + + for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : + parent.getAggregations().entrySet()) { + Symbol symbol = entry.getKey(); + AggregationNode.Aggregation aggregation = entry.getValue(); + + if (isCountOverConstant(context, aggregation, 
child.getAssignments())) { + changed = true; + + if (countWildcardFunction == null) { + aggregations = new LinkedHashMap<>(parent.getAggregations()); + countWildcardFunction = + new ResolvedFunction( + new BoundSignature(COUNT, LongType.INT64, Collections.emptyList()), + FunctionId.NOOP_FUNCTION_ID, + AGGREGATE, + true, + getAggregationFunctionNullability(1)); + } + + aggregations.put( + symbol, + new AggregationNode.Aggregation( + countWildcardFunction, + ImmutableList.of(), + false, + Optional.empty(), + Optional.empty(), + aggregation.getMask())); + } + } + + if (!changed) { + return Result.empty(); + } + + return Result.ofPlanNode( + AggregationNode.builderFrom(parent) + .setSource(child) + .setAggregations(aggregations) + .setPreGroupedSymbols(ImmutableList.of()) + .build()); + } + + private boolean isCountOverConstant( + Context context, AggregationNode.Aggregation aggregation, Assignments inputs) { + BoundSignature signature = aggregation.getResolvedFunction().getSignature(); + if (!signature.getName().equals(COUNT) || signature.getArgumentTypes().size() != 1) { + return false; + } + + Expression argument = aggregation.getArguments().get(0); + if (argument instanceof SymbolReference) { + argument = inputs.get(Symbol.from(argument)); + } + + if (isEffectivelyLiteral(argument, plannerContext, context.getSessionInfo())) { + Object value = evaluateConstantExpression(argument, plannerContext, context.getSessionInfo()); + verify(!(value instanceof Expression)); + return value != null; + } + + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java index f3010fb3539..59df50bd861 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java @@ -68,6 +68,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Re import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveTrivialFilters; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarSubqueries; +import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyCountOverConstant; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyExpressions; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SingleDistinctAggregationToGroupBy; import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedDistinctAggregationWithProjection; @@ -205,10 +206,10 @@ public class LogicalOptimizeFactory { // Our AggregationPushDown does not support AggregationNode with distinct, // so there is no need to put it after AggregationPushDown, // put it here to avoid extra ColumnPruning. 
- new MultipleDistinctAggregationToMarkDistinct() + new MultipleDistinctAggregationToMarkDistinct(), // new MergeLimitWithDistinct(), // new PruneCountAggregationOverScalar(metadata), - // new SimplifyCountOverConstant(plannerContext), + new SimplifyCountOverConstant(plannerContext) // new // PreAggregateCaseAggregations(plannerContext, typeAnalyzer))) )) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index 00fecc1a69a..d7aaeda4f85 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -21,10 +21,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations; import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode; -import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature; -import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId; -import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind; -import org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction; import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments; @@ -54,7 +50,6 @@ import org.apache.tsfile.read.common.type.Type; import java.util.EnumSet; import java.util.List; -import java.util.Locale; import 
java.util.Optional; import java.util.function.Function; @@ -200,13 +195,7 @@ public class TransformQuantifiedComparisonApplyToCorrelatedJoin implements PlanO private ResolvedFunction getResolvedBuiltInAggregateFunction( String functionName, List<Type> argumentTypes) { // The same as the code in ExpressionAnalyzer - Type type = metadata.getFunctionReturnType(functionName, argumentTypes); - return new ResolvedFunction( - new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, argumentTypes), - new FunctionId("noop"), - FunctionKind.AGGREGATE, - true, - FunctionNullability.getAggregationFunctionNullability(argumentTypes.size())); + return Util.getResolvedBuiltInAggregateFunction(metadata, functionName, argumentTypes); } public Expression rewriteUsingBounds( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java index 523dd09f9f9..3423168f5c8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java @@ -231,7 +231,7 @@ public class Util { Type type = metadata.getFunctionReturnType(functionName, argumentTypes); return new ResolvedFunction( new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, argumentTypes), - new FunctionId("noop"), + FunctionId.NOOP_FUNCTION_ID, FunctionKind.AGGREGATE, true, FunctionNullability.getAggregationFunctionNullability(argumentTypes.size())); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java index eba94b76ca5..2227a26f93e 100644 --- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java @@ -659,4 +659,49 @@ public class AggregationTest { ImmutableSet.of("tag1", "tag2", "tag3", "s1"))); } } + /** Verifies that count over a non-null constant argument, i.e. count(1) and count('a'), plans exactly like count(*): a SINGLE-step aggregation table scan on testdb.table1 that outputs "count" and reads only the "time" column. The rewrite is expected to come from the SimplifyCountOverConstant rule registered in LogicalOptimizeFactory. */ + @Test + public void countConstantTest() { + PlanTester planTester = new PlanTester(); + /* Baseline: the plan produced for a plain count(*). */ + LogicalQueryPlan ret = + planTester.createPlan("SELECT count(*) FROM table1 where tag1='beijing' and tag2='A1'"); + assertPlan( + ret, + output( + aggregationTableScan( + singleGroupingSet(), + ImmutableList.of(), // UnStreamable + Optional.empty(), + SINGLE, + "testdb.table1", + ImmutableList.of("count"), + ImmutableSet.of("time")))); + /* count(1): an integer literal argument must yield the same plan as count(*). */ + ret = planTester.createPlan("SELECT count(1) FROM table1 where tag1='beijing' and tag2='A1'"); + assertPlan( + ret, + output( + aggregationTableScan( + singleGroupingSet(), + ImmutableList.of(), // UnStreamable + Optional.empty(), + SINGLE, + "testdb.table1", + ImmutableList.of("count"), + ImmutableSet.of("time")))); + /* count('a'): a string literal argument must also yield the count(*) plan. */ + ret = planTester.createPlan("SELECT count('a') FROM table1 where tag1='beijing' and tag2='A1'"); + assertPlan( + ret, + output( + aggregationTableScan( + singleGroupingSet(), + ImmutableList.of(), // UnStreamable + Optional.empty(), + SINGLE, + "testdb.table1", + ImmutableList.of("count"), + ImmutableSet.of("time")))); + } }