This is an automated email from the ASF dual-hosted git repository. chenglei pushed a commit to branch 4.x in repository https://gitbox.apache.org/repos/asf/phoenix.git
The following commit(s) were added to refs/heads/4.x by this push: new a99f692 PHOENIX-6498 Fix incorrect Correlated Exists Subquery rewrite when Subquery is aggregate a99f692 is described below commit a99f692be6c6584c3ae07c21d824c47efc845b18 Author: chenglei <cheng...@apache.org> AuthorDate: Wed Jul 28 11:57:12 2021 +0800 PHOENIX-6498 Fix incorrect Correlated Exists Subquery rewrite when Subquery is aggregate --- .../apache/phoenix/end2end/join/SubqueryIT.java | 102 +++++++++ .../end2end/join/SubqueryUsingSortMergeJoinIT.java | 104 ++++++++- .../apache/phoenix/compile/SubqueryRewriter.java | 248 ++++++++++++++++----- .../apache/phoenix/compile/QueryCompilerTest.java | 118 +++++++++- 4 files changed, 510 insertions(+), 62 deletions(-) diff --git a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java index e563ca6..010e569 100644 --- a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java +++ b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryIT.java @@ -809,6 +809,108 @@ public class SubqueryIT extends BaseJoinIT { } @Test + public void testCorrelatedExistsSubqueryBug6498() throws Exception { + Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES); + final Connection conn = DriverManager.getConnection(getUrl(), props); + String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME); + String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME); + try { + String query = "SELECT \"order_id\", name FROM " + tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " group by q.\"customer_id\" having count(\"order_id\") > 1)"; + PreparedStatement statement = conn.prepareStatement(query); + ResultSet rs = statement.executeQuery(); + assertFalse(rs.next()); + + query = "SELECT \"order_id\", name FROM " + 
tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " group by q.\"customer_id\" having count(\"order_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertEquals(rs.getString(2), "T1"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertEquals(rs.getString(2), "T6"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000003"); + assertEquals(rs.getString(2), "T2"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertEquals(rs.getString(2), "T6"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertEquals(rs.getString(2), "T3"); + assertFalse(rs.next()); + + query = "SELECT \"order_id\", name FROM " + tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " and q.price <= 150 group by q.\"customer_id\" having count(\"order_id\") >= 1)"+ + " or o.quantity = 5000 order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertEquals(rs.getString(2), "T1"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertEquals(rs.getString(2), "T3"); + assertFalse(rs.next()); + + query = "SELECT \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004' GROUP BY \"order_id\"" + + " having count(\"customer_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + 
assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000003"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertFalse(rs.next()); + + query = "SELECT \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\"" + + " having count(\"customer_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertFalse(rs.next()); + + query = "SELECT \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\"" + + " having count(\"customer_id\") > 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertFalse(rs.next()); + } finally { + conn.close(); + } + } + + @Test public void testAnyAllComparisonSubquery() throws Exception { Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES); Connection conn = DriverManager.getConnection(getUrl(), props); diff --git a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java index 9d98a04..67270f2 100644 --- 
a/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java +++ b/phoenix-core/src/it/java/org/apache/phoenix/end2end/join/SubqueryUsingSortMergeJoinIT.java @@ -353,7 +353,7 @@ public class SubqueryUsingSortMergeJoinIT extends BaseJoinIT { conn.close(); } } - + @Test public void testExistsSubquery() throws Exception { Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES); @@ -599,6 +599,108 @@ public class SubqueryUsingSortMergeJoinIT extends BaseJoinIT { } @Test + public void testCorrelatedExistsSubqueryBug6498() throws Exception { + Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES); + final Connection conn = DriverManager.getConnection(getUrl(), props); + String tableName1 = getTableName(conn, JOIN_ITEM_TABLE_FULL_NAME); + String tableName4 = getTableName(conn, JOIN_ORDER_TABLE_FULL_NAME); + try { + String query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM " + tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " group by q.\"customer_id\" having count(\"order_id\") > 1)"; + PreparedStatement statement = conn.prepareStatement(query); + ResultSet rs = statement.executeQuery(); + assertFalse(rs.next()); + + query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM " + tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " group by q.\"customer_id\" having count(\"order_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertEquals(rs.getString(2), "T1"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertEquals(rs.getString(2), "T6"); + assertTrue (rs.next()); + 
assertEquals(rs.getString(1), "000000000000003"); + assertEquals(rs.getString(2), "T2"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertEquals(rs.getString(2), "T6"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertEquals(rs.getString(2), "T3"); + assertFalse(rs.next()); + + query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\", name FROM " + tableName4 + + " o JOIN " + tableName1 + + " i ON o.\"item_id\" = i.\"item_id\" WHERE exists " + + "(SELECT 1 FROM " + tableName4 + " q WHERE o.\"item_id\" = q.\"item_id\"" + + " and q.price <= 150 group by q.\"customer_id\" having count(\"order_id\") >= 1)"+ + " or o.quantity = 5000 order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertEquals(rs.getString(2), "T1"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertEquals(rs.getString(2), "T3"); + assertFalse(rs.next()); + + query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000004' GROUP BY \"order_id\"" + + " having count(\"customer_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000003"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertFalse(rs.next()); + + query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " 
WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\"" + + " having count(\"customer_id\") >= 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000001"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000002"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000004"); + assertTrue (rs.next()); + assertEquals(rs.getString(1), "000000000000005"); + assertFalse(rs.next()); + + query = "SELECT /*+ USE_SORT_MERGE_JOIN*/ \"order_id\" FROM " + tableName4 + + " o WHERE exists (SELECT 1 FROM " + tableName4 + + " WHERE o.\"item_id\" = \"item_id\" AND \"order_id\" != '000000000000003' GROUP BY \"order_id\"" + + " having count(\"customer_id\") > 1) order by \"order_id\""; + statement = conn.prepareStatement(query); + rs = statement.executeQuery(); + assertFalse(rs.next()); + } finally { + conn.close(); + } + } + + @Test public void testAnyAllComparisonSubquery() throws Exception { Properties props = PropertiesUtil.deepCopy(TEST_PROPERTIES); Connection conn = DriverManager.getConnection(getUrl(), props); diff --git a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java index 9710108..59ce92c 100644 --- a/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java +++ b/phoenix-core/src/main/java/org/apache/phoenix/compile/SubqueryRewriter.java @@ -35,6 +35,7 @@ import org.apache.phoenix.parse.ArrayAnyComparisonNode; import org.apache.phoenix.parse.ColumnParseNode; import org.apache.phoenix.parse.ComparisonParseNode; import org.apache.phoenix.parse.CompoundParseNode; +import org.apache.phoenix.parse.DerivedTableNode; import org.apache.phoenix.parse.ExistsParseNode; import org.apache.phoenix.parse.HintNode; import org.apache.phoenix.parse.InParseNode; @@ -56,7 +57,7 
@@ import org.apache.phoenix.schema.TableNotFoundException; import com.google.common.collect.Lists; -/* +/** * Class for rewriting where-clause sub-queries into join queries. * * If the where-clause sub-query is one of those top-node conditions (being @@ -70,7 +71,7 @@ import com.google.common.collect.Lists; public class SubqueryRewriter extends ParseNodeRewriter { private static final ParseNodeFactory NODE_FACTORY = new ParseNodeFactory(); - private final ColumnResolver resolver; + private final ColumnResolver columnResolver; private final PhoenixConnection connection; private TableNode tableNode; private ParseNode topNode; @@ -89,7 +90,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { } protected SubqueryRewriter(SelectStatement select, ColumnResolver resolver, PhoenixConnection connection) { - this.resolver = resolver; + this.columnResolver = resolver; this.connection = connection; this.tableNode = select.getFrom(); this.topNode = null; @@ -194,7 +195,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { JoinConditionExtractor joinConditionExtractor = new JoinConditionExtractor( subquerySelectStatementToUse, - resolver, + columnResolver, connection, subqueryTableTempAlias); @@ -228,7 +229,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { * It is an Correlated subquery. 
*/ List<AliasedNode> extractedAdditionalSelectAliasNodes = - joinConditionExtractor.getAdditionalSelectNodes(); + joinConditionExtractor.getAdditionalSubselectSelectAliasedNodes(); extractedSelectAliasNodeCount = extractedAdditionalSelectAliasNodes.size(); newSubquerySelectAliasedNodes = Lists.<AliasedNode> newArrayListWithExpectedSize( oldSubqueryAliasedNodes.size() + 1 + @@ -239,10 +240,11 @@ public class SubqueryRewriter extends ParseNodeRewriter { LiteralParseNode.ONE)); this.addNewAliasedNodes(newSubquerySelectAliasedNodes, oldSubqueryAliasedNodes); newSubquerySelectAliasedNodes.addAll(extractedAdditionalSelectAliasNodes); - extractedJoinConditionParseNode = joinConditionExtractor.getJoinCondition(); + extractedJoinConditionParseNode = + joinConditionExtractor.getJoinConditionParseNode(); boolean isAggregate = subquerySelectStatementToUse.isAggregate(); - if(!isAggregate) { + if (!isAggregate) { subquerySelectStatementToUse = NODE_FACTORY.select( subquerySelectStatementToUse, @@ -274,7 +276,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { subqueryTableTempAlias, extractedJoinConditionParseNode, extractedSelectAliasNodeCount); - TableNode rhsTableNode = NODE_FACTORY.derivedTable( + DerivedTableNode subqueryDerivedTableNode = NODE_FACTORY.derivedTable( subqueryTableTempAlias, subquerySelectStatementToUse); JoinType joinType = isTopNode ? @@ -291,45 +293,167 @@ public class SubqueryRewriter extends ParseNodeRewriter { tableNode = NODE_FACTORY.join( joinType, tableNode, - rhsTableNode, + subqueryDerivedTableNode, joinOnConditionParseNode, false); return resultWhereParseNode; } + /** + * <pre> + * {@code + * Rewrite the Exists Subquery to semi/anti/left join for both NonCorrelated and Correlated subquery. + * + * 1.If the {@link ExistsParseNode} is NonCorrelated subquery, then just add LIMIT 1. 
+ * an example is: + * SELECT item_id, name FROM item i WHERE exists + * (SELECT 1 FROM order o where o.price > 8) + * + * The above sql would be rewritten as: + * SELECT ITEM_ID,NAME FROM item I WHERE EXISTS + * (SELECT 1 FROM ORDER_TABLE O WHERE O.PRICE > 8 LIMIT 1) + * + * another example is: + * SELECT item_id, name FROM item i WHERE exists + * (SELECT 1 FROM order o where o.price > 8 group by o.customer_id,o.item_id having count(order_id) > 1) + * or i.discount1 > 10 + * + * The above sql would be rewritten as: + * SELECT ITEM_ID,NAME FROM item I WHERE + * ( EXISTS (SELECT 1 FROM ORDER_TABLE O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID HAVING COUNT(ORDER_ID) > 1 LIMIT 1) + * OR I.DISCOUNT1 > 10) + * + * 2.If the {@link ExistsParseNode} is Correlated subquery and is the only node in where clause or + * is the ANDed part of the where clause, then we would rewrite the Exists Subquery to semi/anti join: + * an example is: + * SELECT item_id, name FROM item i WHERE exists + * (SELECT 1 FROM order o where o.price = i.price and o.quantity = 5 ) + * + * The above sql would be rewritten as: + * SELECT ITEM_ID,NAME FROM item I Semi JOIN + * (SELECT DISTINCT 1 $3,O.PRICE $2 FROM ORDER_TABLE O WHERE O.QUANTITY = 5) $1 + * ON ($1.$2 = I.PRICE) + * + * another example with AggregateFunction and groupBy is + * SELECT item_id, name FROM item i WHERE exists + * (SELECT 1 FROM order o where o.item_id = i.item_id group by customer_id having count(order_id) > 1) + * + * The above sql would be rewritten as: + * SELECT ITEM_ID,NAME FROM item I Semi JOIN + * (SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM order O GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING COUNT(ORDER_ID) > 1) $1 + * ON ($1.$2 = I.ITEM_ID) + * + * 3.If the {@link ExistsParseNode} is Correlated subquery and is the ORed part of the where clause, + * then we would rewrite the Exists Subquery to Left Join. 
+ * an example is: + * SELECT item_id, name FROM item i WHERE exists + * (SELECT 1 FROM order o where o.item_id = i.item_id group by customer_id having count(order_id) > 1) + * or i.discount1 > 10 + * + * The above sql would be rewritten as: + * SELECT ITEM_ID,NAME FROM item I Left JOIN + * (SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM order O GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING COUNT(ORDER_ID) > 1) $1 + * ON ($1.$2 = I.ITEM_ID) WHERE ($1.$3 IS NOT NULL OR I.DISCOUNT1 > 10) + * } + * </pre> + */ @Override - public ParseNode visitLeave(ExistsParseNode node, List<ParseNode> l) throws SQLException { - boolean isTopNode = topNode == node; + public ParseNode visitLeave( + ExistsParseNode existsParseNode, + List<ParseNode> childParseNodes) throws SQLException { + + boolean isTopNode = topNode == existsParseNode; if (isTopNode) { topNode = null; } - SubqueryParseNode subqueryNode = (SubqueryParseNode) l.get(0); - SelectStatement subquery = fixSubqueryStatement(subqueryNode.getSelectNode()); - String rhsTableAlias = ParseNodeFactory.createTempAlias(); - JoinConditionExtractor conditionExtractor = new JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias); - ParseNode where = subquery.getWhere() == null ? 
null : subquery.getWhere().accept(conditionExtractor); - if (where == subquery.getWhere()) { // non-correlated EXISTS subquery, add LIMIT 1 - subquery = NODE_FACTORY.select(subquery, NODE_FACTORY.limit(NODE_FACTORY.literal(1))); - subqueryNode = NODE_FACTORY.subquery(subquery, false); - node = NODE_FACTORY.exists(subqueryNode, node.isNegate()); - return super.visitLeave(node, Collections.<ParseNode> singletonList(subqueryNode)); - } - - List<AliasedNode> additionalSelectNodes = conditionExtractor.getAdditionalSelectNodes(); - List<AliasedNode> selectNodes = Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1); - selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), LiteralParseNode.ONE)); - selectNodes.addAll(additionalSelectNodes); - - subquery = NODE_FACTORY.select(subquery, true, selectNodes, where); - ParseNode onNode = conditionExtractor.getJoinCondition(); - TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, subquery); - JoinType joinType = isTopNode ? (node.isNegate() ? JoinType.Anti : JoinType.Semi) : JoinType.Left; - ParseNode ret = isTopNode ? null : NODE_FACTORY.isNull(NODE_FACTORY.column(NODE_FACTORY.table(null, rhsTableAlias), selectNodes.get(0).getAlias(), null), !node.isNegate()); - tableNode = NODE_FACTORY.join(joinType, tableNode, rhsTable, onNode, false); + SubqueryParseNode subqueryParseNode = (SubqueryParseNode) childParseNodes.get(0); + SelectStatement subquerySelectStatementToUse = + fixSubqueryStatement(subqueryParseNode.getSelectNode()); + String subqueryTableTempAlias = ParseNodeFactory.createTempAlias(); + JoinConditionExtractor joinConditionExtractor = + new JoinConditionExtractor( + subquerySelectStatementToUse, + columnResolver, + connection, + subqueryTableTempAlias); + ParseNode whereParseNodeAfterExtract = + subquerySelectStatementToUse.getWhere() == null ? 
+ null : + subquerySelectStatementToUse.getWhere().accept(joinConditionExtractor); + if (whereParseNodeAfterExtract == subquerySelectStatementToUse.getWhere()) { + /** + * It is non-correlated EXISTS subquery, add LIMIT 1 + */ + subquerySelectStatementToUse = + NODE_FACTORY.select( + subquerySelectStatementToUse, + NODE_FACTORY.limit(NODE_FACTORY.literal(1))); + subqueryParseNode = NODE_FACTORY.subquery(subquerySelectStatementToUse, false); + existsParseNode = NODE_FACTORY.exists(subqueryParseNode, existsParseNode.isNegate()); + return super.visitLeave( + existsParseNode, + Collections.<ParseNode>singletonList(subqueryParseNode)); + } + + List<AliasedNode> extractedAdditionalSelectAliasNodes = + joinConditionExtractor.getAdditionalSubselectSelectAliasedNodes(); + List<AliasedNode> newSubquerySelectAliasedNodes = Lists.newArrayListWithExpectedSize( + extractedAdditionalSelectAliasNodes.size() + 1); + /** + * Just overwrite original subquery selectAliasNodes. + */ + newSubquerySelectAliasedNodes.add( + NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), LiteralParseNode.ONE)); + newSubquerySelectAliasedNodes.addAll(extractedAdditionalSelectAliasNodes); - return ret; + boolean isAggregate = subquerySelectStatementToUse.isAggregate(); + if (!isAggregate) { + subquerySelectStatementToUse = NODE_FACTORY.select( + subquerySelectStatementToUse, + true, + newSubquerySelectAliasedNodes, + whereParseNodeAfterExtract); + } else { + /** + * If exists AggregateFunction,we must add the correlated join condition to both the + * groupBy clause and select lists of the subquery. 
+ */ + List<ParseNode> newGroupByParseNodes = this.createNewGroupByParseNodes( + extractedAdditionalSelectAliasNodes, + subquerySelectStatementToUse); + + subquerySelectStatementToUse = NODE_FACTORY.select( + subquerySelectStatementToUse, + true, + newSubquerySelectAliasedNodes, + whereParseNodeAfterExtract, + newGroupByParseNodes, + true); + } + ParseNode joinOnConditionParseNode = joinConditionExtractor.getJoinConditionParseNode(); + DerivedTableNode subqueryDerivedTableNode = NODE_FACTORY.derivedTable( + subqueryTableTempAlias, + subquerySelectStatementToUse); + JoinType joinType = isTopNode ? + (existsParseNode.isNegate() ? JoinType.Anti : JoinType.Semi) : + JoinType.Left; + ParseNode resultWhereParseNode = isTopNode ? + null : + NODE_FACTORY.isNull( + NODE_FACTORY.column( + NODE_FACTORY.table(null, subqueryTableTempAlias), + newSubquerySelectAliasedNodes.get(0).getAlias(), + null), + !existsParseNode.isNegate()); + tableNode = NODE_FACTORY.join( + joinType, + tableNode, + subqueryDerivedTableNode, + joinOnConditionParseNode, + false); + return resultWhereParseNode; } @Override @@ -347,7 +471,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { SubqueryParseNode subqueryNode = (SubqueryParseNode) secondChild; SelectStatement subquery = fixSubqueryStatement(subqueryNode.getSelectNode()); String rhsTableAlias = ParseNodeFactory.createTempAlias(); - JoinConditionExtractor conditionExtractor = new JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias); + JoinConditionExtractor conditionExtractor = new JoinConditionExtractor(subquery, columnResolver, connection, rhsTableAlias); ParseNode where = subquery.getWhere() == null ? 
null : subquery.getWhere().accept(conditionExtractor); if (where == subquery.getWhere()) { // non-correlated comparison subquery, add LIMIT 2, expectSingleRow = true subquery = NODE_FACTORY.select(subquery, NODE_FACTORY.limit(NODE_FACTORY.literal(2))); @@ -371,8 +495,10 @@ public class SubqueryRewriter extends ParseNodeRewriter { rhsNode = NODE_FACTORY.rowValueConstructor(nodes); } - List<AliasedNode> additionalSelectNodes = conditionExtractor.getAdditionalSelectNodes(); - List<AliasedNode> selectNodes = Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1); + List<AliasedNode> additionalSelectNodes = + conditionExtractor.getAdditionalSubselectSelectAliasedNodes(); + List<AliasedNode> selectNodes = + Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1); selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), rhsNode)); selectNodes.addAll(additionalSelectNodes); @@ -385,7 +511,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { subquery = NODE_FACTORY.select(subquery, subquery.isDistinct(), selectNodes, where, groupbyNodes, true); } - ParseNode onNode = conditionExtractor.getJoinCondition(); + ParseNode onNode = conditionExtractor.getJoinConditionParseNode(); TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, subquery); JoinType joinType = isTopNode ? 
JoinType.Inner : JoinType.Left; ParseNode ret = NODE_FACTORY.comparison(node.getFilterOp(), l.get(0), NODE_FACTORY.column(NODE_FACTORY.table(null, rhsTableAlias), selectNodes.get(0).getAlias(), null)); @@ -428,7 +554,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { SubqueryParseNode subqueryNode = (SubqueryParseNode) firstChild; SelectStatement subquery = fixSubqueryStatement(subqueryNode.getSelectNode()); String rhsTableAlias = ParseNodeFactory.createTempAlias(); - JoinConditionExtractor conditionExtractor = new JoinConditionExtractor(subquery, resolver, connection, rhsTableAlias); + JoinConditionExtractor conditionExtractor = new JoinConditionExtractor(subquery, columnResolver, connection, rhsTableAlias); ParseNode where = subquery.getWhere() == null ? null : subquery.getWhere().accept(conditionExtractor); if (where == subquery.getWhere()) { // non-correlated any/all comparison subquery return l; @@ -457,7 +583,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { rhsNode = NODE_FACTORY.function(DistinctValueAggregateFunction.NAME, Collections.singletonList(rhsNode)); } - List<AliasedNode> additionalSelectNodes = conditionExtractor.getAdditionalSelectNodes(); + List<AliasedNode> additionalSelectNodes = conditionExtractor.getAdditionalSubselectSelectAliasedNodes(); List<AliasedNode> selectNodes = Lists.newArrayListWithExpectedSize(additionalSelectNodes.size() + 1); selectNodes.add(NODE_FACTORY.aliasedNode(ParseNodeFactory.createTempAlias(), rhsNode)); selectNodes.addAll(additionalSelectNodes); @@ -489,7 +615,7 @@ public class SubqueryRewriter extends ParseNodeRewriter { Collections.<SelectStatement> emptyList(), subquery.getUdfParseNodes()); } - ParseNode onNode = conditionExtractor.getJoinCondition(); + ParseNode onNode = conditionExtractor.getJoinConditionParseNode(); TableNode rhsTable = NODE_FACTORY.derivedTable(rhsTableAlias, subquery); JoinType joinType = isTopNode ? 
JoinType.Inner : JoinType.Left; tableNode = NODE_FACTORY.join(joinType, tableNode, rhsTable, onNode, false); @@ -623,8 +749,8 @@ public class SubqueryRewriter extends ParseNodeRewriter { private static class JoinConditionExtractor extends AndRewriterBooleanParseNodeVisitor { private final TableName tableName; private ColumnResolveVisitor columnResolveVisitor; - private List<AliasedNode> additionalSelectNodes; - private List<ParseNode> joinConditions; + private List<AliasedNode> additionalSubselectSelectAliasedNodes; + private List<ParseNode> joinConditionParseNodes; public JoinConditionExtractor(SelectStatement subquery, ColumnResolver outerResolver, PhoenixConnection connection, String tableAlias) throws SQLException { @@ -632,22 +758,24 @@ public class SubqueryRewriter extends ParseNodeRewriter { this.tableName = NODE_FACTORY.table(null, tableAlias); ColumnResolver localResolver = FromCompiler.getResolverForQuery(subquery, connection); this.columnResolveVisitor = new ColumnResolveVisitor(localResolver, outerResolver); - this.additionalSelectNodes = Lists.<AliasedNode> newArrayList(); - this.joinConditions = Lists.<ParseNode> newArrayList(); + this.additionalSubselectSelectAliasedNodes = Lists.<AliasedNode>newArrayList(); + this.joinConditionParseNodes = Lists.<ParseNode>newArrayList(); } - public List<AliasedNode> getAdditionalSelectNodes() { - return this.additionalSelectNodes; + public List<AliasedNode> getAdditionalSubselectSelectAliasedNodes() { + return this.additionalSubselectSelectAliasedNodes; } - public ParseNode getJoinCondition() { - if (this.joinConditions.isEmpty()) + public ParseNode getJoinConditionParseNode() { + if (this.joinConditionParseNodes.isEmpty()) { return null; - - if (this.joinConditions.size() == 1) - return this.joinConditions.get(0); - - return NODE_FACTORY.and(this.joinConditions); + } + + if (this.joinConditionParseNodes.size() == 1) { + return this.joinConditionParseNodes.get(0); + } + + return 
NODE_FACTORY.and(this.joinConditionParseNodes); } @Override @@ -680,16 +808,18 @@ public class SubqueryRewriter extends ParseNodeRewriter { } if (lhsType == ColumnResolveVisitor.ColumnResolveType.LOCAL && rhsType == ColumnResolveVisitor.ColumnResolveType.OUTER) { String alias = ParseNodeFactory.createTempAlias(); - this.additionalSelectNodes.add(NODE_FACTORY.aliasedNode(alias, node.getLHS())); + this.additionalSubselectSelectAliasedNodes.add( + NODE_FACTORY.aliasedNode(alias, node.getLHS())); ParseNode lhsNode = NODE_FACTORY.column(tableName, alias, null); - this.joinConditions.add(NODE_FACTORY.equal(lhsNode, node.getRHS())); + this.joinConditionParseNodes.add(NODE_FACTORY.equal(lhsNode, node.getRHS())); return null; } if (lhsType == ColumnResolveVisitor.ColumnResolveType.OUTER && rhsType == ColumnResolveVisitor.ColumnResolveType.LOCAL) { String alias = ParseNodeFactory.createTempAlias(); - this.additionalSelectNodes.add(NODE_FACTORY.aliasedNode(alias, node.getRHS())); + this.additionalSubselectSelectAliasedNodes.add( + NODE_FACTORY.aliasedNode(alias, node.getRHS())); ParseNode rhsNode = NODE_FACTORY.column(tableName, alias, null); - this.joinConditions.add(NODE_FACTORY.equal(node.getLHS(), rhsNode)); + this.joinConditionParseNodes.add(NODE_FACTORY.equal(node.getLHS(), rhsNode)); return null; } diff --git a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java index 4500a18..0af5ce7 100644 --- a/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java +++ b/phoenix-core/src/test/java/org/apache/phoenix/compile/QueryCompilerTest.java @@ -6527,7 +6527,7 @@ public class QueryCompilerTest extends BaseConnectionlessQueryTest { } - //test Correlated subquery with AggregateFunction with groupBy and is ORed part of the where clause. + //test Correlated subquery with AggregateFunction with groupBy and is ORed part of the where clause. 
             sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE i.item_id IN "+
                     "(SELECT max(item_id) FROM " + orderTableName + " o where o.price = i.price group by o.customer_id) or i.discount1 > 10 ORDER BY name";
             queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
@@ -6737,7 +6737,8 @@ public class QueryCompilerTest extends BaseConnectionlessQueryTest {
             scanPlan=(ScanPlan)(hashJoinPlan.getDelegate());
             TestUtil.assertSelectStatement(
                     scanPlan.getStatement(),
-                    "SELECT A.AID FROM " + tableName1 + " WHERE (AGE > (SELECT MAX(CODE) FROM " + tableName2 + " C WHERE C.BID >= 1 LIMIT 2) AND (AGE >= 11 AND AGE <= 33)) ORDER BY A.AID");
+                    "SELECT A.AID FROM " + tableName1 +
+                    " WHERE (AGE > (SELECT MAX(CODE) FROM " + tableName2 + " C WHERE C.BID >= 1 LIMIT 2) AND (AGE >= 11 AND AGE <= 33)) ORDER BY A.AID");
             subPlans = hashJoinPlan.getSubPlans();
             assertTrue(subPlans.length == 2);
             assertTrue(subPlans[0] instanceof WhereClauseSubPlan);
@@ -6766,4 +6767,117 @@ public class QueryCompilerTest extends BaseConnectionlessQueryTest {
             conn.close();
         }
     }
+
+    @Test
+    public void testExistsSubqueryBug6498() throws Exception {
+        Connection conn = null;
+        try {
+            conn = DriverManager.getConnection(getUrl());
+            String itemTableName = "item_table";
+            String sql ="create table " + itemTableName +
+                    " (item_id varchar not null primary key, " +
+                    " name varchar, " +
+                    " price integer, " +
+                    " discount1 integer, " +
+                    " discount2 integer, " +
+                    " supplier_id varchar, " +
+                    " description varchar)";
+            conn.createStatement().execute(sql);
+
+            String orderTableName = "order_table";
+            sql = "create table " + orderTableName +
+                    " (order_id varchar not null primary key, " +
+                    " customer_id varchar, " +
+                    " item_id varchar, " +
+                    " price integer, " +
+                    " quantity integer, " +
+                    " date timestamp)";
+            conn.createStatement().execute(sql);
+
+            //test simple Correlated subquery
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.price = i.price and o.quantity = 5 ) ORDER BY name";
+            QueryPlan queryPlan = TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            System.out.println(queryPlan.getStatement());
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.PRICE $2 FROM ORDER_TABLE O WHERE O.QUANTITY = 5) $1 "+
+                    "ON ($1.$2 = I.PRICE) ORDER BY NAME");
+
+            //test Correlated subquery with AggregateFunction and groupBy
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.item_id = i.item_id group by customer_id having count(order_id) > 1) " +
+                    "ORDER BY name";
+            queryPlan = TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Semi JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM ORDER_TABLE O GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING COUNT(ORDER_ID) > 1) $1 " +
+                    "ON ($1.$2 = I.ITEM_ID) ORDER BY NAME");
+
+            //for Correlated subquery, the extracted join condition must be equal expression.
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.price = i.price or o.quantity > 1 group by o.customer_id) ORDER BY name";
+            try {
+                queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+                fail();
+            } catch(SQLFeatureNotSupportedException exception) {
+
+            }
+
+            //test Correlated subquery with AggregateFunction with groupBy and is ORed part of the where clause.
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.item_id = i.item_id group by customer_id having count(order_id) > 1) "+
+                    " or i.discount1 > 10 ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I Left JOIN " +
+                    "(SELECT DISTINCT 1 $3,O.ITEM_ID $2 FROM ORDER_TABLE O GROUP BY O.ITEM_ID,CUSTOMER_ID HAVING COUNT(ORDER_ID) > 1) $1 " +
+                    "ON ($1.$2 = I.ITEM_ID) WHERE ($1.$3 IS NOT NULL OR I.DISCOUNT1 > 10) ORDER BY NAME");
+
+            // test NonCorrelated subquery
+            ParseNodeFactory.setTempAliasCounterValue(0);
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.price > 8) ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            System.out.println(queryPlan.getStatement());
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I WHERE EXISTS (SELECT 1 FROM ORDER_TABLE O WHERE O.PRICE > 8 LIMIT 1) ORDER BY NAME");
+
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.price > 8 group by o.customer_id,o.item_id having count(order_id) > 1)" +
+                    " ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I WHERE EXISTS "+
+                    "(SELECT 1 FROM ORDER_TABLE O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID HAVING COUNT(ORDER_ID) > 1 LIMIT 1)" +
+                    " ORDER BY NAME");
+
+            sql= "SELECT item_id, name FROM " + itemTableName + " i WHERE exists "+
+                    "(SELECT 1 FROM " + orderTableName + " o where o.price > 8 group by o.customer_id,o.item_id having count(order_id) > 1)" +
+                    " or i.discount1 > 10 ORDER BY name";
+            queryPlan= TestUtil.getOptimizeQueryPlanNoIterator(conn, sql);
+            assertTrue(queryPlan instanceof HashJoinPlan);
+            TestUtil.assertSelectStatement(
+                    queryPlan.getStatement(),
+                    "SELECT ITEM_ID,NAME FROM ITEM_TABLE I WHERE " +
+                    "( EXISTS (SELECT 1 FROM ORDER_TABLE O WHERE O.PRICE > 8 GROUP BY O.CUSTOMER_ID,O.ITEM_ID HAVING COUNT(ORDER_ID) > 1 LIMIT 1)" +
+                    " OR I.DISCOUNT1 > 10) ORDER BY NAME");
+        } finally {
+            conn.close();
+        }
+    }
+
 }