This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new e7ce326ca07 [opt](query) Optimize count(1), count(constant expression) to count(*)
e7ce326ca07 is described below

commit e7ce326ca07726f570efd7cb37deb51220841d36
Author: Beyyes <cgf1...@foxmail.com>
AuthorDate: Wed Apr 2 17:31:41 2025 +0800

    [opt](query) Optimize count(1), count(constant expression) to count(*)
---
 .../db/it/IoTDBMultiTAGsWithAttributesTableIT.java | 122 ++++++++++++++++-
 .../it/query/recent/IoTDBTableAggregationIT.java   |  56 ++++++++
 .../relational/analyzer/ExpressionAnalyzer.java    |   2 +-
 .../plan/relational/function/FunctionId.java       |   2 +
 .../plan/relational/metadata/ResolvedFunction.java |   1 -
 .../plan/relational/planner/ir/IrUtils.java        |   7 +
 .../iterative/rule/SimplifyCountOverConstant.java  | 145 +++++++++++++++++++++
 .../optimizations/LogicalOptimizeFactory.java      |   5 +-
 ...mQuantifiedComparisonApplyToCorrelatedJoin.java |  13 +-
 .../relational/planner/optimizations/Util.java     |   2 +-
 .../plan/relational/analyzer/AggregationTest.java  |  45 +++++++
 11 files changed, 381 insertions(+), 19 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
index 17e58550b02..caeba55e32e 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
@@ -537,6 +537,16 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(num) as count_num, count(1) as count_star, avg(num) as 
avg_num, count(num) as count_num,\n"
+            + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) 
as count_device,\n"
+            + "count(attr1) as count_attr1, count(device) as count_device, \n"
+            + "round(avg(floatnum)) as avg_floatnum, count(date) as 
count_date, "
+            + "count(time) as count_time, count(1) as count_star "
+            + "from table0",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -553,6 +563,17 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(num) as count_num, count(1) as count_star, avg(num) as 
avg_num, count(num) as count_num,\n"
+            + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) 
as count_device,\n"
+            + "count(attr1) as count_attr1, count(device) as count_device, \n"
+            + "round(avg(floatnum)) as avg_floatnum, count(date) as 
count_date, "
+            + "count(time) as count_time, count(1) as count_star "
+            + "from table0 "
+            + "where time<200 and num>1 ",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     // TAG has null value
     expectedHeader =
@@ -595,7 +616,7 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
     String[] expectedHeader = new String[] {"_col0"};
     String[] retArray = new String[] {"30,"};
 
-    String sql = "SELECT count(num+1) from table0";
+    sql = "SELECT count(num+1) from table0";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
   }
 
@@ -603,7 +624,9 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
   public void countStarTest() {
     expectedHeader = new String[] {"_col0", "_col1"};
     retArray = new String[] {"1,1,"};
-    String sql = "select count(*),count(t1) from (select avg(num+1) as t1 from 
table0)";
+    sql = "select count(*),count(t1) from (select avg(num+1) as t1 from 
table0)";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql = "select count(1),count(t1) from (select avg(num+1) as t1 from 
table0)";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader = new String[] {"count_star"};
@@ -613,6 +636,8 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         };
     tableResultSetEqualTest(
         "select count(*) as count_star from table0", expectedHeader, retArray, 
DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from table0", expectedHeader, retArray, 
DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -623,6 +648,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from (select count(1) from table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -633,6 +663,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from (select count(1), avg(num) from 
table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"sum"};
     retArray =
@@ -644,6 +679,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count_star + avg_num as sum from (select count(1) as 
count_star, avg(num) as avg_num from table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     // TODO select count(*),count(t1) from (select avg(num+1) as t1 from 
table0) where time < 0
 
@@ -657,6 +697,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) from (select device from table0 group by device, 
level)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
@@ -693,6 +738,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
             + "from table0 group by device,level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, "
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
+            + "from table0 group by device,level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -716,6 +767,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
             + "from table0 where device='d1' and level='l1' group by device 
order by device";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, "
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
+            + "from table0 where device='d1' and level='l1' group by device 
order by device";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader = new String[] {"device", "level"};
     retArray =
@@ -816,6 +873,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1y, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -854,6 +917,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -894,6 +963,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1s, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // only group by date_bin
     expectedHeader = new String[] {"bin"};
@@ -929,6 +1004,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
           "1971-04-26T17:46:40.000Z,1,1,1,0,1,1,1,12.0,"
         };
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select date_bin(1s, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 where device='d1' and level='l2' group by 1";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // flush multi times, generated multi tsfile
     expectedHeader = buildHeaders(1);
@@ -975,6 +1056,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -985,6 +1071,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -1006,12 +1097,22 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 group by device, 
level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, count(num) as count_num, count(1) as 
count_star, count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 group by device, 
level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     retArray = new String[] {"d1,l2,1,1,1,0,1,1,1,12.0,12.0,", 
"d2,l2,1,1,1,0,1,0,1,12.0,12.0,"};
     sql =
         "select device, level, count(num) as count_num, count(*) as 
count_star, count(device) as count_device, count(date) as count_date, "
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, 
level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, count(num) as count_num, count(1) as 
count_star, count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, 
level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -1035,6 +1136,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 group by 3, 
device, level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin, count(num) as 
count_num, count(1) as count_star, "
+            + "count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 group by 3, 
device, level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     retArray =
         new String[] {
           "d1,l2,1971-04-26T00:00:00.000Z,1,1,1,0,1,1,1,12.0,12.0,",
@@ -1046,16 +1153,27 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, 
level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin, count(num) as 
count_num, count(1) as count_star, "
+            + "count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, 
level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // queried device is not exist
     expectedHeader = buildHeaders(3);
     sql = "select count(*), count(num), sum(num) from table0 where 
device='d_not_exist'";
     retArray = new String[] {"0,0,null,"};
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql = "select count(1), count(num), sum(num) from table0 where 
device='d_not_exist'";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     sql =
         "select count(*), count(num), sum(num) from table0 where 
device='d_not_exist1' or device='d_not_exist2'";
     retArray = new String[] {"0,0,null,"};
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(1), count(num), sum(num) from table0 where 
device='d_not_exist1' or device='d_not_exist2'";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // no data in given time range (push-down)
     sql = "select count(*), count(num), sum(num) from table0 where 
time>2100-04-26T18:01:40.000";
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
index dd47faebe97..70e5734e649 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
@@ -141,6 +141,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count('a') from table1 where device_id = 'd01'",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "end_time", "device_id", "_col3"};
     retArray =
@@ -156,6 +161,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select date_bin(5s, time), (date_bin(5s, time) + 5000) as end_time, 
device_id, count(1) from table1 where device_id = 'd01' group by 1,device_id",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "province", "city", "region", 
"device_id", "_col5"};
     retArray =
@@ -230,6 +240,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select date_bin(5s, time),province,city,region,device_id, count(1) 
from table1 group by 1,2,3,4,5 order by 2,3,4,5,1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -389,6 +404,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,device_id,count(1) from table1 group by 
1,2,3,4 order by 1,2,3,4",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "region", "_col3"};
     retArray =
@@ -403,6 +423,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,count(1) from table1 group by 1,2,3 order 
by 1,2,3",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "_col2"};
     retArray =
@@ -414,6 +439,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,count(1) from table1 group by 1,2 order by 1,2",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "_col1"};
     retArray =
@@ -425,6 +455,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,count(1) from table1 group by 1 order by 1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0"};
     retArray =
@@ -432,6 +467,7 @@ public class IoTDBTableAggregationIT {
           "64,",
         };
     tableResultSetEqualTest("select count(*) from table1", expectedHeader, 
retArray, DATABASE_NAME);
+    tableResultSetEqualTest("select count(1) from table1", expectedHeader, 
retArray, DATABASE_NAME);
   }
 
   @Test
@@ -3116,6 +3152,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select color,type, date_bin(5s, time), count(1) from table1 group by 
1,2,3 order by 1,2,3",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
@@ -3489,6 +3530,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,device_id,s7,count(1) from table1 group 
by 1,2,3,4,5 order by 1,2,3,4,5",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "region", "device_id", 
"s8", "_col5"};
     retArray =
@@ -3973,6 +4019,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "SELECT color, device_id FROM (SELECT date_bin(5s, time), color, 
device_id, avg(s4) as avg_s4 FROM table1 WHERE type='A' AND (time >= 
2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) GROUP 
BY 1,2,3) WHERE avg_s4 > 1.0 GROUP BY color, device_id HAVING count(1) >= 2 
ORDER BY color, device_id",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "city", "type", "_col3"};
     retArray =
@@ -3996,6 +4047,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "SELECT date_bin(10s, five_seconds), city, type, 
sum(five_seconds_count) / 2 FROM (SELECT date_bin(5s, time) AS five_seconds, 
city, type, count(1) AS five_seconds_count FROM table1 WHERE (time >= 
2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) AND 
device_id IS NOT NULL GROUP BY 1, city, type, device_id HAVING avg(s1) > 1) 
GROUP BY 1, city, type order by 2,3,1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
index 8bf9fcc2940..14a60f7cf3e 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
@@ -802,7 +802,7 @@ public class ExpressionAnalyzer {
       ResolvedFunction resolvedFunction =
           new ResolvedFunction(
               new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), 
type, argumentTypes),
-              new FunctionId("noop"),
+              FunctionId.NOOP_FUNCTION_ID,
               isAggregation ? FunctionKind.AGGREGATE : FunctionKind.SCALAR,
               true,
               isAggregation
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
index cfc31ab471c..d814d8b4db6 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
@@ -24,6 +24,8 @@ import java.util.Locale;
 import static java.util.Objects.requireNonNull;
 
 public class FunctionId {
+  public static final FunctionId NOOP_FUNCTION_ID = new FunctionId("noop");
+
   private final String id;
 
   public FunctionId(String id) {
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
index 432488612f2..e783a05fd08 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
@@ -38,7 +38,6 @@ public class ResolvedFunction {
   private final FunctionId functionId;
   private final FunctionKind functionKind;
   private final boolean deterministic;
-
   private final FunctionNullability functionNullability;
 
   public ResolvedFunction(
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
index d467431e474..1d175c51417 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
@@ -261,6 +261,12 @@ public final class IrUtils {
     return combineConjuncts(conjuncts);
   }
 
+  /**
+   * Returns whether expression is effectively literal. An effectively literal 
expression is a
+   * simple constant value, or null, in either {@link Literal} form, or other 
form returned by
+   * LiteralEncoder. In particular, other constant expressions like a 
deterministic function call
+   * with constant arguments are not considered effectively literal.
+   */
   public static boolean isEffectivelyLiteral(
       Expression expression, PlannerContext plannerContext, SessionInfo 
session) {
     if (expression instanceof Literal) {
@@ -271,6 +277,7 @@ public final class IrUtils {
           // a Cast(Literal(...)) can fail, so this requires verification
           && constantExpressionEvaluatesSuccessfully(plannerContext, session, 
expression);
     }
+
     if (expression instanceof FunctionCall) {
       String functionName = ((FunctionCall) expression).getName().getSuffix();
       if (functionName.equals("pi") || functionName.equals("e")) {
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java
new file mode 100644
index 00000000000..e0ea8ffc7ec
--- /dev/null
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
+import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.tsfile.read.common.type.LongType;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.google.common.base.Verify.verify;
+import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind.AGGREGATE;
+import static org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability.getAggregationFunctionNullability;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter.evaluateConstantExpression;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.isEffectivelyLiteral;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT;
+
+/**
+ * Rewrites {@code count(<constant>)} into the zero-argument {@code count}, i.e. {@code count(*)},
+ * when the constant is effectively literal and evaluates to a non-null value.
+ *
+ * <p>The rule matches an {@link AggregationNode} whose source is a {@link ProjectNode}, so that a
+ * symbol passed to {@code count} can be resolved to the expression that projects it. The mask and
+ * filter of each rewritten aggregation are preserved.
+ *
+ * <p>{@code count(DISTINCT <constant>)} is never rewritten: it returns at most 1, whereas {@code
+ * count(*)} returns the number of rows.
+ */
+public class SimplifyCountOverConstant implements Rule<AggregationNode> {
+  private static final Capture<ProjectNode> CHILD = newCapture();
+
+  private static final Pattern<AggregationNode> PATTERN =
+      aggregation().with(source().matching(project().capturedAs(CHILD)));
+
+  private final PlannerContext plannerContext;
+
+  public SimplifyCountOverConstant(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  @Override
+  public Pattern<AggregationNode> getPattern() {
+    return PATTERN;
+  }
+
+  @Override
+  public Result apply(AggregationNode parent, Captures captures, Context context) {
+    ProjectNode child = captures.get(CHILD);
+
+    // Both stay null until some aggregation is rewritten; unchanged nodes are not copied.
+    Map<Symbol, AggregationNode.Aggregation> aggregations = null;
+    ResolvedFunction countWildcardFunction = null;
+
+    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
+        parent.getAggregations().entrySet()) {
+      AggregationNode.Aggregation aggregation = entry.getValue();
+      if (!isCountOverConstant(context, aggregation, child.getAssignments())) {
+        continue;
+      }
+
+      if (aggregations == null) {
+        aggregations = new LinkedHashMap<>(parent.getAggregations());
+        countWildcardFunction = newCountWildcardFunction();
+      }
+
+      aggregations.put(
+          entry.getKey(),
+          new AggregationNode.Aggregation(
+              countWildcardFunction,
+              ImmutableList.of(),
+              false,
+              // count(c) and count(*) keep exactly the same rows under a FILTER, so retain it.
+              aggregation.getFilter(),
+              // count ignores input order, so no ordering scheme is needed.
+              Optional.empty(),
+              aggregation.getMask()));
+    }
+
+    if (aggregations == null) {
+      return Result.empty();
+    }
+
+    return Result.ofPlanNode(
+        AggregationNode.builderFrom(parent)
+            .setSource(child)
+            .setAggregations(aggregations)
+            .setPreGroupedSymbols(ImmutableList.of())
+            .build());
+  }
+
+  /** Creates the resolved zero-argument {@code count} function, i.e. {@code count(*)}. */
+  private static ResolvedFunction newCountWildcardFunction() {
+    return new ResolvedFunction(
+        new BoundSignature(COUNT, LongType.INT64, Collections.emptyList()),
+        FunctionId.NOOP_FUNCTION_ID,
+        AGGREGATE,
+        true,
+        // The signature has no arguments, so the argument-nullability list is empty too.
+        getAggregationFunctionNullability(0));
+  }
+
+  /**
+   * Returns whether {@code aggregation} is a non-distinct, single-argument {@code count} whose
+   * argument is effectively literal and evaluates to a non-null value.
+   *
+   * @param context rule context whose session is used to evaluate the constant
+   * @param aggregation the aggregation to inspect
+   * @param inputs assignments of the child {@link ProjectNode}, used to resolve a symbol argument
+   *     to the expression that produces it
+   * @return {@code true} if the aggregation can be replaced by {@code count(*)}
+   */
+  private boolean isCountOverConstant(
+      Context context, AggregationNode.Aggregation aggregation, Assignments inputs) {
+    BoundSignature signature = aggregation.getResolvedFunction().getSignature();
+    if (!signature.getName().equals(COUNT) || signature.getArgumentTypes().size() != 1) {
+      return false;
+    }
+    // count(DISTINCT <constant>) is 0 or 1, not the row count, so it must not become count(*).
+    if (aggregation.isDistinct()) {
+      return false;
+    }
+
+    Expression argument = aggregation.getArguments().get(0);
+    if (argument instanceof SymbolReference) {
+      // Look through the child projection to the expression that defines the symbol.
+      argument = inputs.get(Symbol.from(argument));
+      if (argument == null) {
+        return false;
+      }
+    }
+
+    if (!isEffectivelyLiteral(argument, plannerContext, context.getSessionInfo())) {
+      return false;
+    }
+    Object value = evaluateConstantExpression(argument, plannerContext, context.getSessionInfo());
+    verify(!(value instanceof Expression));
+    // count(NULL) is always 0, so only a non-null constant is equivalent to count(*).
+    return value != null;
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
index f3010fb3539..59df50bd861 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
@@ -68,6 +68,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Re
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveTrivialFilters;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarSubqueries;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyCountOverConstant;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyExpressions;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedDistinctAggregationWithProjection;
@@ -205,10 +206,10 @@ public class LogicalOptimizeFactory {
                         // Our AggregationPushDown does not support 
AggregationNode with distinct,
                         // so there is no need to put it after 
AggregationPushDown,
                         // put it here to avoid extra ColumnPruning.
-                        new MultipleDistinctAggregationToMarkDistinct()
+                        new MultipleDistinctAggregationToMarkDistinct(),
                         //                        new MergeLimitWithDistinct(),
                         //                        new 
PruneCountAggregationOverScalar(metadata),
-                        //                        new 
SimplifyCountOverConstant(plannerContext),
+                        new SimplifyCountOverConstant(plannerContext)
                         //                        new
                         // PreAggregateCaseAggregations(plannerContext, 
typeAnalyzer)))
                         ))
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
index 00fecc1a69a..d7aaeda4f85 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
@@ -21,10 +21,6 @@ package 
org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;
 
 import org.apache.iotdb.db.queryengine.common.QueryId;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
-import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
-import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId;
-import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind;
-import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability;
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
 import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
@@ -54,7 +50,6 @@ import org.apache.tsfile.read.common.type.Type;
 
 import java.util.EnumSet;
 import java.util.List;
-import java.util.Locale;
 import java.util.Optional;
 import java.util.function.Function;
 
@@ -200,13 +195,7 @@ public class 
TransformQuantifiedComparisonApplyToCorrelatedJoin implements PlanO
     private ResolvedFunction getResolvedBuiltInAggregateFunction(
         String functionName, List<Type> argumentTypes) {
-      // The same as the code in ExpressionAnalyzer
-      Type type = metadata.getFunctionReturnType(functionName, argumentTypes);
-      return new ResolvedFunction(
-          new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, argumentTypes),
-          new FunctionId("noop"),
-          FunctionKind.AGGREGATE,
-          true,
-          FunctionNullability.getAggregationFunctionNullability(argumentTypes.size()));
+      // Same resolution as ExpressionAnalyzer; the shared implementation now lives in Util.
+      return Util.getResolvedBuiltInAggregateFunction(metadata, functionName, argumentTypes);
     }
 
     public Expression rewriteUsingBounds(
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
index 523dd09f9f9..3423168f5c8 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
@@ -231,7 +231,7 @@ public class Util {
     Type type = metadata.getFunctionReturnType(functionName, argumentTypes);
     return new ResolvedFunction(
         new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, 
argumentTypes),
-        new FunctionId("noop"),
+        FunctionId.NOOP_FUNCTION_ID,
         FunctionKind.AGGREGATE,
         true,
         
FunctionNullability.getAggregationFunctionNullability(argumentTypes.size()));
diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
index eba94b76ca5..2227a26f93e 100644
--- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
+++ 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
@@ -659,4 +659,32 @@ public class AggregationTest {
               ImmutableSet.of("tag1", "tag2", "tag3", "s1")));
     }
   }
+
+  /**
+   * Verifies that {@code count(*)}, {@code count(1)} and {@code count('a')} all produce the same
+   * plan: an {@code output} over a single-grouping-set aggregation table scan of testdb.table1.
+   */
+  @Test
+  public void countConstantTest() {
+    PlanTester planTester = new PlanTester();
+
+    for (String sql :
+        ImmutableList.of(
+            "SELECT count(*) FROM table1 where tag1='beijing' and tag2='A1'",
+            "SELECT count(1) FROM table1 where tag1='beijing' and tag2='A1'",
+            "SELECT count('a') FROM table1 where tag1='beijing' and tag2='A1'")) {
+      LogicalQueryPlan ret = planTester.createPlan(sql);
+      assertPlan(
+          ret,
+          output(
+              aggregationTableScan(
+                  singleGroupingSet(),
+                  ImmutableList.of(), // UnStreamable
+                  Optional.empty(),
+                  SINGLE,
+                  "testdb.table1",
+                  ImmutableList.of("count"),
+                  ImmutableSet.of("time"))));
+    }
+  }
 }


Reply via email to