This is an automated email from the ASF dual-hosted git repository.

caogaofei pushed a commit to branch fix_count_1
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 380fe29d60c9c4af7e138c82e0b77d2e7a868c4a
Author: Beyyes <[email protected]>
AuthorDate: Fri Mar 14 17:59:32 2025 +0800

    optimize count(constant) to count(*)
---
 .../db/it/IoTDBMultiTAGsWithAttributesTableIT.java | 122 +++++++++++++++++++-
 .../it/query/recent/IoTDBTableAggregationIT.java   |  56 +++++++++
 .../relational/analyzer/ExpressionAnalyzer.java    |   2 +-
 .../plan/relational/function/FunctionId.java       |   2 +
 .../plan/relational/metadata/ResolvedFunction.java |   7 +-
 .../plan/relational/planner/ir/IrUtils.java        |   7 ++
 .../iterative/rule/SimplifyCountOverConstant.java  | 126 +++++++++++++++++++++
 .../optimizations/LogicalOptimizeFactory.java      |   5 +-
 ...mQuantifiedComparisonApplyToCorrelatedJoin.java |  13 +--
 .../relational/planner/optimizations/Util.java     |   2 +-
 .../plan/relational/analyzer/AggregationTest.java  |  14 ++-
 11 files changed, 333 insertions(+), 23 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
index 667ec3e06d4..20e50c0a089 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBMultiTAGsWithAttributesTableIT.java
@@ -536,6 +536,16 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(num) as count_num, count(1) as count_star, avg(num) as 
avg_num, count(num) as count_num,\n"
+            + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) 
as count_device,\n"
+            + "count(attr1) as count_attr1, count(device) as count_device, \n"
+            + "round(avg(floatnum)) as avg_floatnum, count(date) as 
count_date, "
+            + "count(time) as count_time, count(1) as count_star "
+            + "from table0",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -552,6 +562,17 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(num) as count_num, count(1) as count_star, avg(num) as 
avg_num, count(num) as count_num,\n"
+            + "count(attr2) as count_attr2, avg(num) as avg_num, count(device) 
as count_device,\n"
+            + "count(attr1) as count_attr1, count(device) as count_device, \n"
+            + "round(avg(floatnum)) as avg_floatnum, count(date) as 
count_date, "
+            + "count(time) as count_time, count(1) as count_star "
+            + "from table0 "
+            + "where time<200 and num>1 ",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     // TAG has null value
     expectedHeader =
@@ -594,7 +615,7 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
     String[] expectedHeader = new String[] {"_col0"};
     String[] retArray = new String[] {"30,"};
 
-    String sql = "SELECT count(num+1) from table0";
+    sql = "SELECT count(num+1) from table0";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
   }
 
@@ -602,7 +623,9 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
   public void countStarTest() {
     expectedHeader = new String[] {"_col0", "_col1"};
     retArray = new String[] {"1,1,"};
-    String sql = "select count(*),count(t1) from (select avg(num+1) as t1 from 
table0)";
+    sql = "select count(*),count(t1) from (select avg(num+1) as t1 from 
table0)";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql = "select count(1),count(t1) from (select avg(num+1) as t1 from 
table0)";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader = new String[] {"count_star"};
@@ -612,6 +635,8 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         };
     tableResultSetEqualTest(
         "select count(*) as count_star from table0", expectedHeader, retArray, 
DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from table0", expectedHeader, retArray, 
DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -622,6 +647,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from (select count(1) from table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -632,6 +662,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) as count_star from (select count(1), avg(num) from 
table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"sum"};
     retArray =
@@ -643,6 +678,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count_star + avg_num as sum from (select count(1) as 
count_star, avg(num) as avg_num from table0)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     // TODO select count(*),count(t1) from (select avg(num+1) as t1 from 
table0) where time < 0
 
@@ -656,6 +696,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) from (select device from table0 group by device, 
level)",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
@@ -692,6 +737,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
             + "from table0 group by device,level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, "
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
+            + "from table0 group by device,level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -715,6 +766,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
             + "from table0 where device='d1' and level='l1' group by device 
order by device";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, "
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num "
+            + "from table0 where device='d1' and level='l1' group by device 
order by device";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader = new String[] {"device", "level"};
     retArray =
@@ -815,6 +872,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1y, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -853,6 +916,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -893,6 +962,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
             + "from table0 group by 3, device, level order by device, level, 
bin";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1s, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 group by 3, device, level order by device, level, 
bin";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // only group by date_bin
     expectedHeader = new String[] {"bin"};
@@ -928,6 +1003,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
           "1971-04-26T17:46:40.000Z,1,1,1,0,1,1,1,12.0,"
         };
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select date_bin(1s, time) as bin,"
+            + "count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, avg(num) as avg_num "
+            + "from table0 where device='d1' and level='l2' group by 1";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // flush multi times, generated multi tsfile
     expectedHeader = buildHeaders(1);
@@ -974,6 +1055,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     retArray =
         new String[] {
@@ -984,6 +1070,11 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(num) as count_num, count(1) as count_star, count(device) 
as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -1005,12 +1096,22 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 group by device, 
level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, count(num) as count_num, count(1) as 
count_star, count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 group by device, 
level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     retArray = new String[] {"d1,l2,1,1,1,0,1,1,1,12.0,12.0,", 
"d2,l2,1,1,1,0,1,0,1,12.0,12.0,"};
     sql =
         "select device, level, count(num) as count_num, count(*) as 
count_star, count(device) as count_device, count(date) as count_date, "
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, 
level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, count(num) as count_num, count(1) as 
count_star, count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by device, level order by device, 
level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -1034,6 +1135,12 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 group by 3, 
device, level order by device, level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin, count(num) as 
count_num, count(1) as count_star, "
+            + "count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 group by 3, 
device, level order by device, level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     retArray =
         new String[] {
           "d1,l2,1971-04-26T00:00:00.000Z,1,1,1,0,1,1,1,12.0,12.0,",
@@ -1045,16 +1152,27 @@ public class IoTDBMultiTAGsWithAttributesTableIT {
             + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
             + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, 
level";
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select device, level, date_bin(1d, time) as bin, count(num) as 
count_num, count(1) as count_star, "
+            + "count(device) as count_device, count(date) as count_date, "
+            + "count(attr1) as count_attr1, count(attr2) as count_attr2, 
count(time) as count_time, sum(num) as sum_num,"
+            + "avg(num) as avg_num from table0 where time=32 or 
time=1971-04-27T01:46:40.000+08:00 group by 3, device, level order by device, 
level";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // queried device is not exist
     expectedHeader = buildHeaders(3);
     sql = "select count(*), count(num), sum(num) from table0 where 
device='d_not_exist'";
     retArray = new String[] {"0,0,null,"};
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql = "select count(1), count(num), sum(num) from table0 where 
device='d_not_exist'";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
     sql =
         "select count(*), count(num), sum(num) from table0 where 
device='d_not_exist1' or device='d_not_exist2'";
     retArray = new String[] {"0,0,null,"};
     tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
+    sql =
+        "select count(1), count(num), sum(num) from table0 where 
device='d_not_exist1' or device='d_not_exist2'";
+    tableResultSetEqualTest(sql, expectedHeader, retArray, DATABASE_NAME);
 
     // no data in given time range (push-down)
     sql = "select count(*), count(num), sum(num) from table0 where 
time>2100-04-26T18:01:40.000";
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
index dd47faebe97..2225317d2f4 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java
@@ -141,6 +141,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select count(1) from table1 where device_id = 'd01'",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "end_time", "device_id", "_col3"};
     retArray =
@@ -156,6 +161,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select date_bin(5s, time), (date_bin(5s, time) + 5000) as end_time, 
device_id, count(1) from table1 where device_id = 'd01' group by 1,device_id",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "province", "city", "region", 
"device_id", "_col5"};
     retArray =
@@ -230,6 +240,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select date_bin(5s, time),province,city,region,device_id, count(1) 
from table1 group by 1,2,3,4,5 order by 2,3,4,5,1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader =
         new String[] {
@@ -389,6 +404,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,device_id,count(1) from table1 group by 
1,2,3,4 order by 1,2,3,4",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "region", "_col3"};
     retArray =
@@ -403,6 +423,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,count(1) from table1 group by 1,2,3 order 
by 1,2,3",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "_col2"};
     retArray =
@@ -414,6 +439,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,count(1) from table1 group by 1,2 order by 1,2",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "_col1"};
     retArray =
@@ -425,6 +455,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,count(1) from table1 group by 1 order by 1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0"};
     retArray =
@@ -432,6 +467,7 @@ public class IoTDBTableAggregationIT {
           "64,",
         };
     tableResultSetEqualTest("select count(*) from table1", expectedHeader, 
retArray, DATABASE_NAME);
+    tableResultSetEqualTest("select count(1) from table1", expectedHeader, 
retArray, DATABASE_NAME);
   }
 
   @Test
@@ -3116,6 +3152,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select color,type, date_bin(5s, time), count(1) from table1 group by 
1,2,3 order by 1,2,3",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
@@ -3489,6 +3530,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "select province,city,region,device_id,s7,count(1) from table1 group 
by 1,2,3,4,5 order by 1,2,3,4,5",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"province", "city", "region", "device_id", 
"s8", "_col5"};
     retArray =
@@ -3973,6 +4019,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "SELECT color, device_id FROM (SELECT date_bin(5s, time), color, 
device_id, avg(s4) as avg_s4 FROM table1 WHERE type='A' AND (time >= 
2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) GROUP 
BY 1,2,3) WHERE avg_s4 > 1.0 GROUP BY color, device_id HAVING count(1) >= 2 
ORDER BY color, device_id",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
 
     expectedHeader = new String[] {"_col0", "city", "type", "_col3"};
     retArray =
@@ -3996,6 +4047,11 @@ public class IoTDBTableAggregationIT {
         expectedHeader,
         retArray,
         DATABASE_NAME);
+    tableResultSetEqualTest(
+        "SELECT date_bin(10s, five_seconds), city, type, 
sum(five_seconds_count) / 2 FROM (SELECT date_bin(5s, time) AS five_seconds, 
city, type, count(1) AS five_seconds_count FROM table1 WHERE (time >= 
2024-09-24T06:15:30.000+00:00 AND time <= 2024-09-24T06:15:59.999+00:00) AND 
device_id IS NOT NULL GROUP BY 1, city, type, device_id HAVING avg(s1) > 1) 
GROUP BY 1, city, type order by 2,3,1",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
   }
 
   @Test
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
index 1798a0ed0b7..cddba6f1e74 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionAnalyzer.java
@@ -801,7 +801,7 @@ public class ExpressionAnalyzer {
       ResolvedFunction resolvedFunction =
           new ResolvedFunction(
               new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), 
type, argumentTypes),
-              new FunctionId("noop"),
+              FunctionId.NOOP_FUNCTION_ID,
               isAggregation ? FunctionKind.AGGREGATE : FunctionKind.SCALAR,
               true,
               isAggregation
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
index cfc31ab471c..d814d8b4db6 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/FunctionId.java
@@ -24,6 +24,8 @@ import java.util.Locale;
 import static java.util.Objects.requireNonNull;
 
 public class FunctionId {
+  public static final FunctionId NOOP_FUNCTION_ID = new FunctionId("noop");
+
   private final String id;
 
   public FunctionId(String id) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
index 432488612f2..94298ea1bc8 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/ResolvedFunction.java
@@ -38,7 +38,6 @@ public class ResolvedFunction {
   private final FunctionId functionId;
   private final FunctionKind functionKind;
   private final boolean deterministic;
-
   private final FunctionNullability functionNullability;
 
   public ResolvedFunction(
@@ -48,10 +47,10 @@ public class ResolvedFunction {
       boolean deterministic,
       FunctionNullability functionNullability) {
     this.signature = requireNonNull(signature, "signature is null");
-    this.functionId = requireNonNull(functionId, "functionId is null");
-    this.functionKind = requireNonNull(functionKind, "functionKind is null");
+    this.functionId = functionId;
+    this.functionKind = functionKind;
     this.deterministic = deterministic;
-    this.functionNullability = requireNonNull(functionNullability, 
"functionNullability is null");
+    this.functionNullability = functionNullability;
   }
 
   public BoundSignature getSignature() {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
index 037360104ba..fb395f0f12a 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/ir/IrUtils.java
@@ -256,6 +256,12 @@ public final class IrUtils {
     return combineConjuncts(conjuncts);
   }
 
+  /**
+   * Returns whether expression is effectively literal. An effectively literal 
expression is a
+   * simple constant value, or null, in either {@link Literal} form, or other 
form returned by
+   * LiteralEncoder. In particular, other constant expressions like a 
deterministic function call
+   * with constant arguments are not considered effectively literal.
+   */
   public static boolean isEffectivelyLiteral(
       Expression expression, PlannerContext plannerContext, SessionInfo 
session) {
     if (expression instanceof Literal) {
@@ -266,6 +272,7 @@ public final class IrUtils {
           // a Cast(Literal(...)) can fail, so this requires verification
           && constantExpressionEvaluatesSuccessfully(plannerContext, session, 
expression);
     }
+
     if (expression instanceof FunctionCall) {
       String functionName = ((FunctionCall) expression).getName().getSuffix();
       if (functionName.equals("pi") || functionName.equals("e")) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java
new file mode 100644
index 00000000000..b42ec8ffbdb
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/SimplifyCountOverConstant.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
+import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.tsfile.read.common.type.LongType;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.google.common.base.Verify.verify;
+import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind.AGGREGATE;
+import static org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability.getAggregationFunctionNullability;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter.evaluateConstantExpression;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.isEffectivelyLiteral;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.aggregation;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT;
+
+/**
+ * Rewrites {@code count(<constant>)} to {@code count(*)} when the constant is effectively literal
+ * and non-null: counting a non-null constant counts every input row, so the argument can be
+ * dropped. DISTINCT aggregations are skipped ({@code count(DISTINCT 1)} is not the row count);
+ * FILTER and mask are preserved so the set of counted rows does not change.
+ */
+public class SimplifyCountOverConstant implements Rule<AggregationNode> {
+  private static final Capture<ProjectNode> CHILD = newCapture();
+
+  private static final Pattern<AggregationNode> PATTERN =
+      aggregation().with(source().matching(project().capturedAs(CHILD)));
+
+  /** The resolved {@code count(*)}; it does not depend on the aggregation being rewritten. */
+  private static final ResolvedFunction COUNT_WILDCARD_FUNCTION =
+      new ResolvedFunction(
+          new BoundSignature(COUNT, LongType.INT64, Collections.emptyList()),
+          FunctionId.NOOP_FUNCTION_ID,
+          AGGREGATE,
+          true,
+          // must match the signature's (empty) argument list
+          getAggregationFunctionNullability(0));
+
+  private final PlannerContext plannerContext;
+
+  public SimplifyCountOverConstant(PlannerContext plannerContext) {
+    this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+  }
+
+  @Override
+  public Pattern<AggregationNode> getPattern() {
+    return PATTERN;
+  }
+
+  @Override
+  public Result apply(AggregationNode parent, Captures captures, Context context) {
+    ProjectNode child = captures.get(CHILD);
+
+    // Copied lazily on the first rewrite; stays null when nothing matches.
+    Map<Symbol, AggregationNode.Aggregation> aggregations = null;
+
+    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry :
+        parent.getAggregations().entrySet()) {
+      AggregationNode.Aggregation aggregation = entry.getValue();
+      if (!isCountOverConstant(context, aggregation, child.getAssignments())) {
+        continue;
+      }
+
+      if (aggregations == null) {
+        aggregations = new LinkedHashMap<>(parent.getAggregations());
+      }
+      aggregations.put(
+          entry.getKey(),
+          new AggregationNode.Aggregation(
+              COUNT_WILDCARD_FUNCTION,
+              ImmutableList.of(),
+              false,
+              aggregation.getFilter(),
+              Optional.empty(),
+              aggregation.getMask()));
+    }
+
+    if (aggregations == null) {
+      return Result.empty();
+    }
+
+    return Result.ofPlanNode(
+        AggregationNode.builderFrom(parent)
+            .setSource(child)
+            .setAggregations(aggregations)
+            .setPreGroupedSymbols(ImmutableList.of())
+            .build());
+  }
+
+  /**
+   * Returns whether {@code aggregation} is a non-DISTINCT single-argument {@code count} whose
+   * argument, directly or through the child projection, is an effectively literal non-null value.
+   */
+  private boolean isCountOverConstant(
+      Context context, AggregationNode.Aggregation aggregation, Assignments inputs) {
+    BoundSignature signature = aggregation.getResolvedFunction().getSignature();
+    if (!signature.getName().equals(COUNT) || signature.getArgumentTypes().size() != 1) {
+      return false;
+    }
+    // count(DISTINCT c) is at most 1, not the row count, so it must not become count(*).
+    if (aggregation.isDistinct()) {
+      return false;
+    }
+
+    Expression argument = aggregation.getArguments().get(0);
+    if (argument instanceof SymbolReference) {
+      argument = inputs.get(Symbol.from(argument));
+    }
+    if (argument == null) {
+      return false;
+    }
+
+    if (isEffectivelyLiteral(argument, plannerContext, context.getSessionInfo())) {
+      Object value = evaluateConstantExpression(argument, plannerContext, context.getSessionInfo());
+      verify(!(value instanceof Expression));
+      // count(NULL) is always 0, so only a non-null constant is equivalent to count(*).
+      return value != null;
+    }
+
+    return false;
+  }
+}
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
index 36f0213da1c..748b9f829d7 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
@@ -67,6 +67,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Re
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveTrivialFilters;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarApplyNodes;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveUnreferencedScalarSubqueries;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyCountOverConstant;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SimplifyExpressions;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.SingleDistinctAggregationToGroupBy;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.TransformCorrelatedDistinctAggregationWithProjection;
@@ -203,10 +204,10 @@ public class LogicalOptimizeFactory {
                         // Our AggregationPushDown does not support 
AggregationNode with distinct,
                         // so there is no need to put it after 
AggregationPushDown,
                         // put it here to avoid extra ColumnPruning.
-                        new MultipleDistinctAggregationToMarkDistinct()
+                        new MultipleDistinctAggregationToMarkDistinct(),
                         //                        new MergeLimitWithDistinct(),
                         //                        new 
PruneCountAggregationOverScalar(metadata),
-                        //                        new 
SimplifyCountOverConstant(plannerContext),
+                        new SimplifyCountOverConstant(plannerContext)
                         //                        new
                         // PreAggregateCaseAggregations(plannerContext, 
typeAnalyzer)))
                         ))
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
index 00fecc1a69a..d7aaeda4f85 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
@@ -21,10 +21,6 @@ package 
org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;
 
 import org.apache.iotdb.db.queryengine.common.QueryId;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
-import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
-import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId;
-import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind;
-import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability;
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
 import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
@@ -54,7 +50,6 @@ import org.apache.tsfile.read.common.type.Type;
 
 import java.util.EnumSet;
 import java.util.List;
-import java.util.Locale;
 import java.util.Optional;
 import java.util.function.Function;
 
@@ -200,13 +195,7 @@ public class 
TransformQuantifiedComparisonApplyToCorrelatedJoin implements PlanO
     private ResolvedFunction getResolvedBuiltInAggregateFunction(
         String functionName, List<Type> argumentTypes) {
       // The same as the code in ExpressionAnalyzer
-      Type type = metadata.getFunctionReturnType(functionName, argumentTypes);
-      return new ResolvedFunction(
-          new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, 
argumentTypes),
-          new FunctionId("noop"),
-          FunctionKind.AGGREGATE,
-          true,
-          
FunctionNullability.getAggregationFunctionNullability(argumentTypes.size()));
+      return Util.getResolvedBuiltInAggregateFunction(metadata, functionName, 
argumentTypes);
     }
 
     public Expression rewriteUsingBounds(
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
index 523dd09f9f9..3423168f5c8 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/Util.java
@@ -231,7 +231,7 @@ public class Util {
     Type type = metadata.getFunctionReturnType(functionName, argumentTypes);
     return new ResolvedFunction(
         new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type, 
argumentTypes),
-        new FunctionId("noop"),
+        FunctionId.NOOP_FUNCTION_ID,
         FunctionKind.AGGREGATE,
         true,
         
FunctionNullability.getAggregationFunctionNullability(argumentTypes.size()));
diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
index 6d99582aac9..96730ed0f5f 100644
--- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
+++ 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/AggregationTest.java
@@ -677,7 +677,7 @@ public class AggregationTest {
   public void deviceWithNumerousRegionTest() {
     PlanTester planTester = new PlanTester();
     LogicalQueryPlan logicalQueryPlan =
-        planTester.createPlan("SELECT count(s1) FROM table1 where tag2='B2'");
+        planTester.createPlan("SELECT count(s1+1) FROM table1 where 
tag2='B2'");
     // complete push-down when do logical optimize
     assertPlan(
         logicalQueryPlan,
@@ -754,4 +754,16 @@ public class AggregationTest {
                             ImmutableSet.of("tag1", "tag2", "tag3", "s1")),
                         exchange())))));
   }
+
+  @Test
+  public void countConstantTest() {
+    // Checks that planning succeeds for count over constants (rewritten to count(*)).
+    // TODO: also assertPlan that these aggregations end up taking no arguments.
+    PlanTester planTester = new PlanTester();
+    planTester.createPlan("SELECT count(1) FROM table1");
+    planTester.createPlan("SELECT count(2) FROM table1");
+    planTester.createPlan("SELECT count('a') FROM table1");
+    planTester.createPlan("SELECT count(pi()), count(e()) FROM table1");
+    planTester.createPlan("SELECT count(CAST(1 AS INT64)) FROM table1");
+  }
 }


Reply via email to