This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new fea2e4fc8f8 [Table Model Subquery] Support uncorrelated quantified
comparison
fea2e4fc8f8 is described below
commit fea2e4fc8f8b712a5eb471378697988c86f57e69
Author: Liao Lanyu <[email protected]>
AuthorDate: Mon Jan 20 08:57:56 2025 +0800
[Table Model Subquery] Support uncorrelated quantified comparison
---
.../IoTDBUncorrelatedQuantifiedComparisonIT.java | 674 +++++++++++++++++++++
.../relational/aggregation/AccumulatorFactory.java | 2 +
.../aggregation/CountAllAccumulator.java | 97 +++
.../relational/ColumnTransformerBuilder.java | 6 +-
.../plan/planner/TableOperatorGenerator.java | 2 +
.../relational/metadata/TableMetadataImpl.java | 1 +
.../plan/relational/planner/IrTypeAnalyzer.java | 4 +-
.../relational/planner/SimplePlanRewriter.java | 88 +++
.../optimizations/LogicalOptimizeFactory.java | 8 +
.../optimizations/PushPredicateIntoTableScan.java | 31 +-
...mQuantifiedComparisonApplyToCorrelatedJoin.java | 341 +++++++++++
.../iotdb/db/utils/constant/SqlConstant.java | 1 +
.../plan/relational/planner/SubqueryTest.java | 212 +++++++
.../TableBuiltinAggregationFunction.java | 2 +
.../thrift-commons/src/main/thrift/common.thrift | 3 +-
15 files changed, 1468 insertions(+), 4 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java
new file mode 100644
index 00000000000..11eb1feab90
--- /dev/null
+++
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/subquery/uncorrelated/IoTDBUncorrelatedQuantifiedComparisonIT.java
@@ -0,0 +1,674 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.relational.it.query.recent.subquery.uncorrelated;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.TableClusterIT;
+import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
+import static org.apache.iotdb.db.it.utils.TestUtils.tableAssertTestFail;
+import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
+import static
org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.CREATE_SQLS;
+import static
org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.DATABASE_NAME;
+import static
org.apache.iotdb.relational.it.query.recent.subquery.SubqueryDataSetUtils.NUMERIC_MEASUREMENTS;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
+public class IoTDBUncorrelatedQuantifiedComparisonIT {
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ EnvFactory.getEnv().getConfig().getCommonConfig().setSortBufferSize(128 *
1024);
+
EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 *
1024);
+ EnvFactory.getEnv().initClusterEnvironment();
+ prepareTableData(CREATE_SQLS);
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void testAnyAndSomeComparisonInWhereClauseWithoutNull() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: where s > any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s > some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s >= any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s >= some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s < any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where < some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s <= any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s <= any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s <= some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s <= some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s = any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s = any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s = some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s = some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s != any (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s != any (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s != some (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s != some (SELECT %s FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testAllComparisonInWhereClauseWithoutNull() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: where s > all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {"50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s >= all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {"40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s < all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s <= all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s <= all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {"30,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s = all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s = all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s != all (subquery), s does not contain null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s != all (SELECT %s FROM table3 WHERE device_id = 'd01')";
+ retArray = new String[] {"50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testAnyAndSomeComparisonInWhereClauseWithNull() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: where s1 > any (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > any (SELECT s1 FROM table3)";
+ retArray = new String[] {"40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 > some (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > some (SELECT s1 FROM table3)";
+ retArray = new String[] {"40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 >= any (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= any (SELECT s1 FROM table3)";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 >= some (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= some (SELECT s1 FROM table3)";
+ retArray = new String[] {"30,", "40,", "50,", "60,", "70,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 < any (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < any (SELECT s1 FROM table3)";
+ retArray = new String[] {"30,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 < some (subquery), s1 in table3 contains null value
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < some (SELECT s1 FROM table3)";
+ retArray = new String[] {"30,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testAllComparisonInWhereClauseWithNull() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: where s1 > all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s > all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 >= all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s >= all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 < all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s < all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 <= all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s <= all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 = all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and %s = all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: where s1 != all (subquery), s1 in table3 contains null value.
+ sql =
+ "SELECT cast(%s AS INT32) as %s FROM table1 WHERE device_id = 'd01'
and cast(%s as INT32) != all (SELECT s1 FROM table3)";
+ retArray = new String[] {};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testQuantifiedComparisonInWhereWithExpression() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ sql =
+ "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id =
'd01' and %s + 10 > any (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"50,", "60,", "70,", "80,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ sql =
+ "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id =
'd01' and %s + 10 > some (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"50,", "60,", "70,", "80,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ sql =
+ "SELECT cast(%s + 10 AS INT32) as %s FROM table1 WHERE device_id =
'd01' and %s + 10 >= all (SELECT %s + 10 FROM table1 WHERE device_id = 'd01')";
+ retArray = new String[] {"80,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ expectedHeader = new String[] {measurement};
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testQuantifiedComparisonInHavingClause() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: having s >= any(subquery)
+ sql =
+ "SELECT device_id, count(*) from table1 group by device_id having
count(*) + 25 >= any(SELECT cast(s1 as INT64) from table3 where device_id =
'd01')";
+ expectedHeader = new String[] {"device_id", "_col1"};
+ retArray =
+ new String[] {
+ "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,",
"d13,5,", "d15,5,"
+ };
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: having s >= some(subquery)
+ sql =
+ "SELECT device_id, count(*) from table1 group by device_id having
count(*) + 25 >= some(SELECT cast(s1 as INT64) from table3 where device_id =
'd01')";
+ expectedHeader = new String[] {"device_id", "_col1"};
+ retArray =
+ new String[] {
+ "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,",
"d13,5,", "d15,5,"
+ };
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ // Test case: having s >= all(subquery)
+ sql =
+ "SELECT device_id, count(*) from table1 group by device_id having
count(*) + 35 >= all(SELECT cast(s1 as INT64) from table3 where device_id =
'd01')";
+ expectedHeader = new String[] {"device_id", "_col1"};
+ retArray =
+ new String[] {
+ "d01,5,", "d03,5,", "d05,5,", "d07,5,", "d09,5,", "d11,5,",
"d13,5,", "d15,5,"
+ };
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement, measurement,
measurement),
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+ }
+
+ public void testQuantifiedComparisonInSelectClause() {
+ String sql;
+ String[] expectedHeader;
+ String[] retArray;
+
+ // Test case: select s > any(subquery)
+ sql =
+ "SELECT %s > any(SELECT (%s) from table3 WHERE device_id = 'd01') from
table1 where device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"false,", "true,", "true,", "true,", "true,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s > some(subquery)
+ sql =
+ "SELECT %s > some(SELECT (%s) from table3 WHERE device_id = 'd01')
from table1 where device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"false,", "true,", "true,", "true,", "true,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s > all(subquery)
+ sql =
+ "SELECT %s > all(SELECT (%s) from table3 WHERE device_id = 'd01') from
table1 where device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"false,", "false,", "false,", "false,", "false,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s < any(subquery), subquery contains null value
+ sql = "SELECT %s < any(SELECT (%s) from table3) from table1 where
device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"null,", "null,", "null,", "null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s < some(subquery), subquery contains null value
+ sql = "SELECT %s < some(SELECT (%s) from table3) from table1 where
device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"null,", "null,", "null,", "null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s <= any(subquery), subquery contains null value
+ sql = "SELECT %s <= any(SELECT (%s) from table3) from table1 where
device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"true,", "null,", "null,", "null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s <= some(subquery), subquery contains null value
+ sql = "SELECT %s <= some(SELECT (%s) from table3) from table1 where
device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"true,", "null,", "null,", "null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s != all(subquery), subquery contains null value
+ sql = "SELECT %s != all(SELECT (%s) from table3) from table1 where
device_id = 'd01'";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"false,", "false,", "null,", "null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+
+ // Test case: select s != all(subquery), subquery result contains null
value and s not in
+ // non-null
+ // value result set
+ sql =
+ "SELECT %s != all(SELECT (%s) from table3 where device_id = 'd_null')
from table1 where device_id = 'd02' and %s != 30";
+ expectedHeader = new String[] {"_col0"};
+ retArray = new String[] {"null,", "null,"};
+ for (String measurement : NUMERIC_MEASUREMENTS) {
+ tableResultSetEqualTest(
+ String.format(sql, measurement, measurement), expectedHeader,
retArray, DATABASE_NAME);
+ }
+ }
+
+ @Test
+ public void testQuantifiedComparisonLegalityCheck() {
+ // Legality check: only support any/some/all quantifier
+ tableAssertTestFail(
+ "select s1 from table1 where s1 > any_value (select s1 from table3)",
+ "mismatched input",
+ DATABASE_NAME);
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
index 403627ed461..a9b4122375f 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -254,6 +254,8 @@ public class AccumulatorFactory {
switch (aggregationType) {
case COUNT:
return new CountAccumulator();
+ case COUNT_ALL:
+ return new CountAllAccumulator();
case COUNT_IF:
return new CountIfAccumulator();
case AVG:
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java
new file mode 100644
index 00000000000..27867e38730
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/CountAllAccumulator.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.file.metadata.statistics.Statistics;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+public class CountAllAccumulator implements TableAccumulator {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(CountAllAccumulator.class);
+ private long countState = 0;
+
+ @Override
+ public long getEstimatedSize() {
+ return INSTANCE_SIZE;
+ }
+
+ @Override
+ public TableAccumulator copy() {
+ return new CountAllAccumulator();
+ }
+
+ @Override
+ public void addInput(Column[] arguments, AggregationMask mask) {
+ checkArgument(arguments.length == 1, "argument of CountAll should be one
column");
+ int count = mask.getSelectedPositionCount();
+ countState += count;
+ }
+
+ @Override
+ public void removeInput(Column[] arguments) {
+ checkArgument(arguments.length == 1, "argument of Count should be one
column");
+ int count = arguments[0].getPositionCount();
+ countState -= count;
+ }
+
+ @Override
+ public void addIntermediate(Column argument) {
+ for (int i = 0; i < argument.getPositionCount(); i++) {
+ if (argument.isNull(i)) {
+ continue;
+ }
+ countState += argument.getLong(i);
+ }
+ }
+
+ @Override
+ public void evaluateIntermediate(ColumnBuilder columnBuilder) {
+ columnBuilder.writeLong(countState);
+ }
+
+ @Override
+ public void evaluateFinal(ColumnBuilder columnBuilder) {
+ columnBuilder.writeLong(countState);
+ }
+
+ @Override
+ public boolean hasFinalResult() {
+ return false;
+ }
+
+ @Override
+ public void addStatistics(Statistics[] statistics) {
+ throw new UnsupportedOperationException("CountAllAccumulator does not
support statistics.");
+ }
+
+ @Override
+ public void reset() {
+ countState = 0;
+ }
+
+ @Override
+ public boolean removable() {
+ return true;
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
index e2acc16987a..55de5391938 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
@@ -493,7 +493,7 @@ public class ColumnTransformerBuilder
return res;
}
- // currently, we only support Date and Timestamp
+ // currently, we only support Date/Timestamp/INT64
// for Date, GenericLiteral.value is an int value
// for Timestamp, GenericLiteral.value is a long value
private static ConstantColumnTransformer
getColumnTransformerForGenericLiteral(
@@ -506,6 +506,10 @@ public class ColumnTransformerBuilder
return new ConstantColumnTransformer(
TimestampType.TIMESTAMP,
new LongColumn(1, Optional.empty(), new long[]
{Long.parseLong(literal.getValue())}));
+ } else if (INT64.getTypeEnum().name().equals(literal.getType())) {
+ return new ConstantColumnTransformer(
+ INT64,
+ new LongColumn(1, Optional.empty(), new long[]
{Long.parseLong(literal.getValue())}));
} else {
throw new SemanticException("Unsupported type in GenericLiteral: " +
literal.getType());
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index ffd248cc845..e94136cf74c 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -1533,9 +1533,11 @@ public class TableOperatorGenerator extends
PlanVisitor<Operator, LocalExecution
Type sourceJoinKeyType =
context.getTypeProvider().getTableModelType(node.getSourceJoinSymbol());
+
checkIfJoinKeyTypeMatches(
sourceJoinKeyType,
context.getTypeProvider().getTableModelType(node.getFilteringSourceJoinSymbol()));
+
OperatorContext operatorContext =
context
.getDriverContext()
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
index cfac519f769..607b1c37135 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
@@ -628,6 +628,7 @@ public class TableMetadataImpl implements Metadata {
// get return type
switch (functionName.toLowerCase(Locale.ENGLISH)) {
case SqlConstant.COUNT:
+ case SqlConstant.COUNT_ALL:
case SqlConstant.COUNT_IF:
return INT64;
case SqlConstant.FIRST_AGGREGATION:
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java
index 1a19f40132f..1ecf2dad77d 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java
@@ -340,7 +340,7 @@ public class IrTypeAnalyzer {
&& node.getParsedValue() <= Integer.MAX_VALUE) {
return setExpressionType(node, INT32);
}
-
+ // keep the original type
return setExpressionType(node, INT64);
}
@@ -361,6 +361,8 @@ public class IrTypeAnalyzer {
type = DateType.DATE;
} else if
(TimestampType.TIMESTAMP.getTypeEnum().name().equals(node.getType())) {
type = TimestampType.TIMESTAMP;
+ } else if (INT64.getTypeEnum().name().equals(node.getType())) {
+ type = INT64;
} else {
throw new SemanticException("Unsupported type in GenericLiteral: " +
node.getType());
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java
new file mode 100644
index 00000000000..d2a8ec781a1
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SimplePlanRewriter.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
+
+import com.google.common.collect.ImmutableList;
+
+import static com.google.common.base.Verify.verifyNotNull;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ChildReplacer.replaceChildren;
+
+public abstract class SimplePlanRewriter<C>
+ extends PlanVisitor<PlanNode, SimplePlanRewriter.RewriteContext<C>> {
+ public static <C> PlanNode rewriteWith(SimplePlanRewriter<C> rewriter,
PlanNode node) {
+ return node.accept(rewriter, new RewriteContext<>(rewriter, null));
+ }
+
+ public static <C> PlanNode rewriteWith(SimplePlanRewriter<C> rewriter,
PlanNode node, C context) {
+ return node.accept(rewriter, new RewriteContext<>(rewriter, context));
+ }
+
+ @Override
+ public PlanNode visitPlan(PlanNode node, RewriteContext<C> context) {
+ return context.defaultRewrite(node, context.get());
+ }
+
+ public static class RewriteContext<C> {
+ private final C userContext;
+ private final SimplePlanRewriter<C> nodeRewriter;
+
+ private RewriteContext(SimplePlanRewriter<C> nodeRewriter, C userContext) {
+ this.nodeRewriter = nodeRewriter;
+ this.userContext = userContext;
+ }
+
+ public C get() {
+ return userContext;
+ }
+
+ /**
+ * Invoke the rewrite logic recursively on children of the given node and
swap it out with an
+ * identical copy with the rewritten children
+ */
+ public PlanNode defaultRewrite(PlanNode node) {
+ return defaultRewrite(node, null);
+ }
+
+ /**
+ * Invoke the rewrite logic recursively on children of the given node and
swap it out with an
+ * identical copy with the rewritten children
+ */
+ public PlanNode defaultRewrite(PlanNode node, C context) {
+ ImmutableList.Builder<PlanNode> children =
+ ImmutableList.builderWithExpectedSize(node.getChildren().size());
+ node.getChildren().forEach(source -> children.add(rewrite(source,
context)));
+ return replaceChildren(node, children.build());
+ }
+
+ /** This method is meant for invoking the rewrite logic on children while
processing a node */
+ public PlanNode rewrite(PlanNode node, C userContext) {
+ PlanNode result = node.accept(nodeRewriter, new
RewriteContext<>(nodeRewriter, userContext));
+ return verifyNotNull(result, "nodeRewriter returned null for %s",
node.getClass().getName());
+ }
+
+ /** This method is meant for invoking the rewrite logic on children while
processing a node */
+ public PlanNode rewrite(PlanNode node) {
+ return rewrite(node, null);
+ }
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
index 62bdd1bf1b9..b3766383c18 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
@@ -208,6 +208,7 @@ public class LogicalOptimizeFactory {
new UnaliasSymbolReferences(plannerContext.getMetadata()),
columnPruningOptimizer,
inlineProjectionLimitFiltersOptimizer,
+ new TransformQuantifiedComparisonApplyToCorrelatedJoin(metadata),
new IterativeOptimizer(
plannerContext,
ruleStats,
@@ -215,6 +216,13 @@ public class LogicalOptimizeFactory {
new RemoveRedundantEnforceSingleRowNode(), new
RemoveUnreferencedScalarSubqueries(),
new TransformUncorrelatedSubqueryToJoin(),
new TransformUncorrelatedInPredicateSubqueryToSemiJoin())),
+ new IterativeOptimizer(
+ plannerContext,
+ ruleStats,
+ ImmutableSet.of(
+ new InlineProjections(plannerContext), new
RemoveRedundantIdentityProjections()
+ /*new TransformCorrelatedSingleRowSubqueryToProject(),
+ new RemoveAggregationInSemiJoin())*/ )),
new CheckSubqueryNodesAreRewritten(),
new IterativeOptimizer(
plannerContext, ruleStats, ImmutableSet.of(new
PruneDistinctAggregation())),
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
index 7cf45ff3cbd..f039c09a701 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/PushPredicateIntoTableScan.java
@@ -212,7 +212,7 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
Expression predicate = combineConjuncts(node.getPredicate(),
context.inheritedPredicate);
// when exist diff function, predicate can not be pushed down into
DeviceTableScanNode
- if (containsDiffFunction(predicate)) {
+ if (containsDiffFunction(predicate) ||
canNotPushDownBelowProjectNode(node, predicate)) {
node.setChild(node.getChild().accept(this, new RewriteContext()));
node.setPredicate(predicate);
context.inheritedPredicate = TRUE_LITERAL;
@@ -234,6 +234,35 @@ public class PushPredicateIntoTableScan implements
PlanOptimizer {
return node;
}
+ private boolean canNotPushDownBelowProjectNode(FilterNode node, Expression
predicate) {
+ PlanNode child = node.getChild();
+ if (child instanceof ProjectNode) {
+ // if the inheritedPredicate is not in the output of the child of
ProjectNode, we can not
+ // push it down for now.
+ // (predicate will be computed by the ProjectNode, Trino will rewrite
the predicate in
+ // visitProject, but we have not implemented this for now.)
+ Set<Symbol> outputSymbolsOfProjectChild =
+ new HashSet<>(((ProjectNode) child).getChild().getOutputSymbols());
+ return missingTermsInOutputSymbols(predicate,
outputSymbolsOfProjectChild);
+ }
+ return false;
+ }
+
+ private boolean missingTermsInOutputSymbols(Expression expression,
Set<Symbol> outputSymbols) {
+ if (expression instanceof SymbolReference) {
+ return !outputSymbols.contains(Symbol.from(expression));
+ }
+ if (!expression.getChildren().isEmpty()) {
+ for (Node node : expression.getChildren()) {
+ if (missingTermsInOutputSymbols((Expression) node, outputSymbols)) {
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
// private boolean areExpressionsEquivalent(
// Expression leftExpression, Expression rightExpression) {
// return false;
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
new file mode 100644
index 00000000000..00fecc1a69a
--- /dev/null
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;
+
+import org.apache.iotdb.db.queryengine.common.QueryId;
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.function.BoundSignature;
+import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionId;
+import org.apache.iotdb.db.queryengine.plan.relational.function.FunctionKind;
+import
org.apache.iotdb.db.queryengine.plan.relational.metadata.FunctionNullability;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
+import
org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.SimplePlanRewriter;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
+import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SearchedCaseExpression;
+import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.tsfile.read.common.type.LongType;
+import org.apache.tsfile.read.common.type.Type;
+
+import java.util.EnumSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Optional;
+import java.util.function.Function;
+
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.util.Objects.requireNonNull;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.SimplePlanRewriter.rewriteWith;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils.combineConjuncts;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.globalAggregation;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.singleAggregation;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode.Quantifier.ALL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.FALSE_LITERAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral.TRUE_LITERAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.EQUAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.GREATER_THAN;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.LESS_THAN;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.NOT_EQUAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator.toSqlType;
+import static org.apache.tsfile.read.common.type.BooleanType.BOOLEAN;
+
+public class TransformQuantifiedComparisonApplyToCorrelatedJoin implements
PlanOptimizer {
+ private final Metadata metadata;
+
+ public TransformQuantifiedComparisonApplyToCorrelatedJoin(Metadata metadata)
{
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ }
+
+ @Override
+ public PlanNode optimize(PlanNode plan, Context context) {
+ return rewriteWith(
+ new Rewriter(context.idAllocator(), context.getSymbolAllocator(),
metadata), plan, null);
+ }
+
+ private static class Rewriter extends SimplePlanRewriter<PlanNode> {
+ private final QueryId idAllocator;
+ private final SymbolAllocator symbolAllocator;
+ private final Metadata metadata;
+
+ public Rewriter(QueryId idAllocator, SymbolAllocator symbolAllocator,
Metadata metadata) {
+ this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
+ this.symbolAllocator = requireNonNull(symbolAllocator, "symbolAllocator
is null");
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ }
+
+ @Override
+ public PlanNode visitApply(ApplyNode node, RewriteContext<PlanNode>
context) {
+ if (node.getSubqueryAssignments().size() != 1) {
+ return context.defaultRewrite(node);
+ }
+
+ ApplyNode.SetExpression expression =
getOnlyElement(node.getSubqueryAssignments().values());
+ if (expression instanceof ApplyNode.QuantifiedComparison) {
+ return rewriteQuantifiedApplyNode(
+ node, (ApplyNode.QuantifiedComparison) expression, context);
+ }
+
+ return context.defaultRewrite(node);
+ }
+
+ private PlanNode rewriteQuantifiedApplyNode(
+ ApplyNode node,
+ ApplyNode.QuantifiedComparison quantifiedComparison,
+ RewriteContext<PlanNode> context) {
+ PlanNode subqueryPlan = context.rewrite(node.getSubquery());
+
+ Symbol outputColumn = getOnlyElement(subqueryPlan.getOutputSymbols());
+ Type outputColumnType =
symbolAllocator.getTypes().getTableModelType(outputColumn);
+ checkState(outputColumnType.isOrderable(), "Subquery result type must be
orderable");
+
+ Symbol minValue = symbolAllocator.newSymbol("min", outputColumnType);
+ Symbol maxValue = symbolAllocator.newSymbol("max", outputColumnType);
+ Symbol countAllValue = symbolAllocator.newSymbol("count_all",
LongType.getInstance());
+ Symbol countNonNullValue =
+ symbolAllocator.newSymbol("count_non_null", LongType.getInstance());
+
+ List<Expression> outputColumnReferences =
ImmutableList.of(outputColumn.toSymbolReference());
+
+ subqueryPlan =
+ singleAggregation(
+ idAllocator.genPlanNodeId(),
+ subqueryPlan,
+ ImmutableMap.of(
+ minValue,
+ new AggregationNode.Aggregation(
+ getResolvedBuiltInAggregateFunction(
+ "min", ImmutableList.of(outputColumnType)),
+ outputColumnReferences,
+ false,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty()),
+ maxValue,
+ new AggregationNode.Aggregation(
+ getResolvedBuiltInAggregateFunction(
+ "max", ImmutableList.of(outputColumnType)),
+ outputColumnReferences,
+ false,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty()),
+ countAllValue,
+ new AggregationNode.Aggregation(
+ getResolvedBuiltInAggregateFunction(
+ "count_all", ImmutableList.of(outputColumnType)),
+ outputColumnReferences,
+ false,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty()),
+ countNonNullValue,
+ new AggregationNode.Aggregation(
+ getResolvedBuiltInAggregateFunction(
+ "count", ImmutableList.of(outputColumnType)),
+ outputColumnReferences,
+ false,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty())),
+ globalAggregation());
+
+ PlanNode join =
+ new CorrelatedJoinNode(
+ node.getPlanNodeId(),
+ context.rewrite(node.getInput()),
+ subqueryPlan,
+ node.getCorrelation(),
+ JoinNode.JoinType.INNER,
+ TRUE_LITERAL,
+ node.getOriginSubquery());
+
+ Expression valueComparedToSubquery =
+ rewriteUsingBounds(
+ quantifiedComparison, minValue, maxValue, countAllValue,
countNonNullValue);
+
+ Symbol quantifiedComparisonSymbol =
getOnlyElement(node.getSubqueryAssignments().keySet());
+
+ return projectExpressions(
+ join, Assignments.of(quantifiedComparisonSymbol,
valueComparedToSubquery));
+ }
+
+ private ResolvedFunction getResolvedBuiltInAggregateFunction(
+ String functionName, List<Type> argumentTypes) {
+ // The same as the code in ExpressionAnalyzer
+ Type type = metadata.getFunctionReturnType(functionName, argumentTypes);
+ return new ResolvedFunction(
+ new BoundSignature(functionName.toLowerCase(Locale.ENGLISH), type,
argumentTypes),
+ new FunctionId("noop"),
+ FunctionKind.AGGREGATE,
+ true,
+
FunctionNullability.getAggregationFunctionNullability(argumentTypes.size()));
+ }
+
+ public Expression rewriteUsingBounds(
+ ApplyNode.QuantifiedComparison quantifiedComparison,
+ Symbol minValue,
+ Symbol maxValue,
+ Symbol countAllValue,
+ Symbol countNonNullValue) {
+ BooleanLiteral emptySetResult;
+ Function<List<Expression>, Expression> quantifier;
+ if (quantifiedComparison.getQuantifier() == ALL) {
+ emptySetResult = TRUE_LITERAL;
+ quantifier = IrUtils::combineConjuncts;
+ } else {
+ emptySetResult = FALSE_LITERAL;
+ quantifier = IrUtils::combineDisjuncts;
+ }
+ Expression comparisonWithExtremeValue =
+ getBoundComparisons(quantifiedComparison, minValue, maxValue);
+
+ return new SimpleCaseExpression(
+ countAllValue.toSymbolReference(),
+ ImmutableList.of(new WhenClause(new GenericLiteral("INT64", "0"),
emptySetResult)),
+ quantifier.apply(
+ ImmutableList.of(
+ comparisonWithExtremeValue,
+ new SearchedCaseExpression(
+ ImmutableList.of(
+ new WhenClause(
+ new ComparisonExpression(
+ NOT_EQUAL,
+ countAllValue.toSymbolReference(),
+ countNonNullValue.toSymbolReference()),
+ new Cast(new NullLiteral(),
toSqlType(BOOLEAN)))),
+ emptySetResult))));
+ }
+
+ private Expression getBoundComparisons(
+ ApplyNode.QuantifiedComparison quantifiedComparison, Symbol minValue,
Symbol maxValue) {
+ if (mapOperator(quantifiedComparison) == EQUAL
+ && quantifiedComparison.getQuantifier() == ALL) {
+ // A = ALL B <=> min B = max B && A = min B
+ return combineConjuncts(
+ new ComparisonExpression(
+ EQUAL, minValue.toSymbolReference(),
maxValue.toSymbolReference()),
+ new ComparisonExpression(
+ EQUAL,
+ quantifiedComparison.getValue().toSymbolReference(),
+ maxValue.toSymbolReference()));
+ }
+
+ if (EnumSet.of(LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN,
GREATER_THAN_OR_EQUAL)
+ .contains(mapOperator(quantifiedComparison))) {
+ // A < ALL B <=> A < min B
+ // A > ALL B <=> A > max B
+ // A < ANY B <=> A < max B
+ // A > ANY B <=> A > min B
+ Symbol boundValue =
+ shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue
: maxValue;
+ return new ComparisonExpression(
+ mapOperator(quantifiedComparison),
+ quantifiedComparison.getValue().toSymbolReference(),
+ boundValue.toSymbolReference());
+ }
+ throw new IllegalArgumentException(
+ "Unsupported quantified comparison: " + quantifiedComparison);
+ }
+
+ private static ComparisonExpression.Operator mapOperator(
+ ApplyNode.QuantifiedComparison quantifiedComparison) {
+ switch (quantifiedComparison.getOperator()) {
+ case EQUAL:
+ return EQUAL;
+ case NOT_EQUAL:
+ return NOT_EQUAL;
+ case LESS_THAN:
+ return LESS_THAN;
+ case LESS_THAN_OR_EQUAL:
+ return LESS_THAN_OR_EQUAL;
+ case GREATER_THAN:
+ return GREATER_THAN;
+ case GREATER_THAN_OR_EQUAL:
+ return GREATER_THAN_OR_EQUAL;
+ default:
+ throw new IllegalArgumentException(
+ "Unexpected quantifiedComparison: " +
quantifiedComparison.getOperator());
+ }
+ }
+
+ private static boolean shouldCompareValueWithLowerBound(
+ ApplyNode.QuantifiedComparison quantifiedComparison) {
+ ComparisonExpression.Operator operator =
mapOperator(quantifiedComparison);
+ switch (quantifiedComparison.getQuantifier()) {
+ case ALL:
+ switch (operator) {
+ case LESS_THAN:
+ case LESS_THAN_OR_EQUAL:
+ return true;
+ case GREATER_THAN:
+ case GREATER_THAN_OR_EQUAL:
+ return false;
+ default:
+ throw new IllegalArgumentException("Unexpected value: " +
operator);
+ }
+ case ANY:
+ case SOME:
+ switch (operator) {
+ case LESS_THAN:
+ case LESS_THAN_OR_EQUAL:
+ return false;
+ case GREATER_THAN:
+ case GREATER_THAN_OR_EQUAL:
+ return true;
+ default:
+ throw new IllegalArgumentException("Unexpected value: " +
operator);
+ }
+ default:
+ throw new IllegalArgumentException(
+ "Unexpected Quantifier: " +
quantifiedComparison.getQuantifier());
+ }
+ }
+
+ private ProjectNode projectExpressions(PlanNode input, Assignments
subqueryAssignments) {
+ Assignments assignments =
+ Assignments.builder()
+ .putIdentities(input.getOutputSymbols())
+ .putAll(subqueryAssignments)
+ .build();
+ return new ProjectNode(idAllocator.genPlanNodeId(), input, assignments);
+ }
+ }
+}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java
index 43267f9e48e..cad8f229c4e 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java
@@ -62,6 +62,7 @@ public class SqlConstant {
public static final String LAST_AGGREGATION = "last";
public static final String FIRST_AGGREGATION = "first";
public static final String COUNT = "count";
+ public static final String COUNT_ALL = "count_all";
public static final String AVG = "avg";
public static final String SUM = "sum";
public static final String COUNT_IF = "count_if";
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java
index 52321265280..96dbe9d2b14 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/SubqueryTest.java
@@ -56,6 +56,7 @@ import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.FINAL;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.INTERMEDIATE;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.PARTIAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode.Step.SINGLE;
import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression.Operator.EQUAL;
public class SubqueryTest {
@@ -322,4 +323,215 @@ public class SubqueryTest {
filterPredicate,
semiJoin("s1", "s1_6", "expr", sort(tableScan1),
sort(tableScan2))))));
}
+
+ @Test
+ public void testUncorrelatedAnyComparisonSubquery() {
+ PlanTester planTester = new PlanTester();
+
+ String sql = "SELECT s1 FROM table1 where s1 > any (select s1 from
table1)";
+
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+
+ PlanMatchPattern tableScan1 =
+ tableScan("testdb.table1", ImmutableList.of("s1"),
ImmutableSet.of("s1"));
+
+ PlanMatchPattern tableScan2 = tableScan("testdb.table1",
ImmutableMap.of("s1_7", "s1"));
+
+ PlanMatchPattern tableScan3 = tableScan("testdb.table1",
ImmutableMap.of("s1_6", "s1"));
+
+ // Verify full LogicalPlan
+ /*
+ * └──OutputNode
+ * └──ProjectNode
+ * └──FilterNode
+ * └──ProjectNode
+ * └──JoinNode
+ * |──TableScanNode
+ * ├──AggregationNode
+ * │ └──TableScanNode
+
+ */
+ assertPlan(
+ logicalQueryPlan,
+ output(
+ project(
+ anyTree(
+ project(
+ join(
+ JoinNode.JoinType.INNER,
+ builder ->
+ builder
+ .left(tableScan1)
+ .right(
+ aggregation(
+ singleGroupingSet(),
+ ImmutableMap.of(
+ Optional.of("min"),
+ aggregationFunction(
+ "min",
ImmutableList.of("s1_7")),
+ Optional.of("count_all"),
+ aggregationFunction(
+ "count_all",
ImmutableList.of("s1_7")),
+ Optional.of("count_non_null"),
+ aggregationFunction(
+ "count",
ImmutableList.of("s1_7"))),
+ Collections.emptyList(),
+ Optional.empty(),
+ SINGLE,
+ tableScan2))))))));
+
+ // Verify DistributionPlan
+ assertPlan(
+ planTester.getFragmentPlan(0),
+ output(
+ project(
+ anyTree(
+ project(
+ join(
+ JoinNode.JoinType.INNER,
+ builder ->
+ builder
+ .left(collect(exchange(), tableScan1,
exchange()))
+ .right(
+ aggregation(
+ singleGroupingSet(),
+ ImmutableMap.of(
+ Optional.of("min"),
+ aggregationFunction(
+ "min",
ImmutableList.of("min_9")),
+ Optional.of("count_all"),
+ aggregationFunction(
+ "count_all",
ImmutableList.of("count_all_10")),
+ Optional.of("count_non_null"),
+ aggregationFunction(
+ "count",
ImmutableList.of("count"))),
+ Collections.emptyList(),
+ Optional.empty(),
+ FINAL,
+ collect(
+ exchange(),
+ aggregation(
+ singleGroupingSet(),
+ ImmutableMap.of(
+ Optional.of("min_9"),
+ aggregationFunction(
+ "min",
ImmutableList.of("s1_6")),
+
Optional.of("count_all_10"),
+ aggregationFunction(
+ "count_all",
ImmutableList.of("s1_6")),
+ Optional.of("count"),
+ aggregationFunction(
+ "count",
ImmutableList.of("s1_6"))),
+ Collections.emptyList(),
+ Optional.empty(),
+ PARTIAL,
+ tableScan3),
+ exchange())))))))));
+
+ assertPlan(planTester.getFragmentPlan(1), tableScan1);
+
+ assertPlan(planTester.getFragmentPlan(2), tableScan1);
+
+ assertPlan(
+ planTester.getFragmentPlan(3),
+ aggregation(
+ singleGroupingSet(),
+ ImmutableMap.of(
+ Optional.of("min_9"),
+ aggregationFunction("min", ImmutableList.of("s1_6")),
+ Optional.of("count_all_10"),
+ aggregationFunction("count_all", ImmutableList.of("s1_6")),
+ Optional.of("count"),
+ aggregationFunction("count", ImmutableList.of("s1_6"))),
+ Collections.emptyList(),
+ Optional.empty(),
+ PARTIAL,
+ tableScan3));
+
+ assertPlan(
+ planTester.getFragmentPlan(4),
+ aggregation(
+ singleGroupingSet(),
+ ImmutableMap.of(
+ Optional.of("min_9"),
+ aggregationFunction("min", ImmutableList.of("s1_6")),
+ Optional.of("count_all_10"),
+ aggregationFunction("count_all", ImmutableList.of("s1_6")),
+ Optional.of("count"),
+ aggregationFunction("count", ImmutableList.of("s1_6"))),
+ Collections.emptyList(),
+ Optional.empty(),
+ PARTIAL,
+ tableScan3));
+ }
+
+ @Test
+ public void testUncorrelatedEqualsSomeComparisonSubquery() {
+ PlanTester planTester = new PlanTester();
+
+ String sql = "SELECT s1 FROM table1 where s1 = some (select s1 from
table1)";
+
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+
+ Expression filterPredicate = new SymbolReference("expr");
+
+ PlanMatchPattern tableScan1 =
+ tableScan("testdb.table1", ImmutableList.of("s1"),
ImmutableSet.of("s1"));
+
+ PlanMatchPattern tableScan2 = tableScan("testdb.table1",
ImmutableMap.of("s1_6", "s1"));
+
+ // Verify full LogicalPlan
+ /*
+ * └──OutputNode
+ * └──ProjectNode
+ * └──FilterNode
+ * └──SemiJoinNode
+ * |──SortNode
+ * | └──TableScanNode
+ * ├──SortNode
+ * │ └──TableScanNode
+
+ */
+ assertPlan(
+ logicalQueryPlan,
+ output(
+ project(
+ filter(
+ filterPredicate,
+ semiJoin("s1", "s1_6", "expr", sort(tableScan1),
sort(tableScan2))))));
+ }
+
+ @Test
+ public void testUncorrelatedAllComparisonSubquery() {
+ PlanTester planTester = new PlanTester();
+
+ String sql = "SELECT s1 FROM table1 where s1 != all (select s1 from
table1)";
+
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+
+ PlanMatchPattern tableScan1 =
+ tableScan("testdb.table1", ImmutableList.of("s1"),
ImmutableSet.of("s1"));
+
+ PlanMatchPattern tableScan2 = tableScan("testdb.table1",
ImmutableMap.of("s1_6", "s1"));
+
+ // Verify full LogicalPlan
+ /*
+ * └──OutputNode
+ * └──ProjectNode
+ * └──FilterNode
+ * └──ProjectNode
+ * └──SemiJoinNode
+ * |──SortNode
+ * | └──TableScanNode
+ * ├──SortNode
+ * │ └──TableScanNode
+
+ */
+ assertPlan(
+ logicalQueryPlan,
+ output(
+ project(
+ anyTree(
+ project(semiJoin("s1", "s1_6", "expr", sort(tableScan1),
sort(tableScan2)))))));
+ }
}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
index 10aa13ed4ad..3d0510957d1 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java
@@ -38,6 +38,7 @@ import static
org.apache.tsfile.read.common.type.LongType.INT64;
public enum TableBuiltinAggregationFunction {
SUM("sum"),
COUNT("count"),
+ COUNT_ALL("count_all"),
COUNT_IF("count_if"),
AVG("avg"),
EXTREME("extreme"),
@@ -82,6 +83,7 @@ public enum TableBuiltinAggregationFunction {
final String functionName = name.toLowerCase();
switch (functionName) {
case "count":
+ case "count_all":
case "count_if":
return INT64;
case "sum":
diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift
b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift
index c46c4c0a65a..93eafedc118 100644
--- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift
+++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift
@@ -281,7 +281,8 @@ enum TAggregationType {
FIRST_BY,
LAST_BY,
MIN,
- MAX
+ MAX,
+ COUNT_ALL
}
struct TShowConfigurationTemplateResp {