This is an automated email from the ASF dual-hosted git repository. sunlan pushed a commit to branch GROOVY-11720 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit fee52236ac3ff8e97a56256d761a324d3c663bce Author: Daniel Sun <sun...@apache.org> AuthorDate: Sat Jul 26 09:03:45 2025 +0900 GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause --- .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java | 31 ++- .../ginq/provider/collection/GinqAstWalker.groovy | 23 +- .../test/org/apache/groovy/ginq/GinqTest.groovy | 253 +++++++++++++++++++++ 3 files changed, 295 insertions(+), 12 deletions(-) diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java index f88b690c81..ba60c2a6b7 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java @@ -53,11 +53,13 @@ import org.codehaus.groovy.syntax.Types; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.HashSet; import java.util.List; import java.util.Set; + /** * Build the AST for GINQ * @@ -376,12 +378,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep public void visitBinaryExpression(BinaryExpression expression) { super.visitBinaryExpression(expression); - final int opType = expression.getOperation().getType(); - if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) { - if (null != latestGinqExpression && isSelectMethodCallExpression(expression.getRightExpression())) { + final Integer opType = expression.getOperation().getType(); + if (FILTER_BINARY_OP_SET.contains(opType)) { + if (null != latestGinqExpression) { // use the nested ginq and clear it - expression.setRightExpression(latestGinqExpression); - latestGinqExpression = null; + if (isSelectMethodCallExpression(expression.getRightExpression())) { + expression.setRightExpression(latestGinqExpression); + latestGinqExpression = null; + } else if (isSelectMethodCallExpression(expression.getLeftExpression())) { + expression.setLeftExpression(latestGinqExpression); + latestGinqExpression = null; + } } } } @@ -464,8 +471,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep return sourceUnit; } - private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; + public static final Set<Integer> FILTER_BINARY_OP_SET = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + Types.KEYWORD_IN, Types.COMPARE_NOT_IN, Types.COMPARE_IDENTICAL, Types.COMPARE_NOT_IDENTICAL, + Types.COMPARE_EQUAL, Types.COMPARE_NOT_EQUAL, Types.COMPARE_LESS_THAN, Types.COMPARE_LESS_THAN_EQUAL, + Types.COMPARE_GREATER_THAN, Types.COMPARE_GREATER_THAN_EQUAL, Types.MATCH_REGEX))); + private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; private static final String KW_WITH = "with"; // reserved keyword private static final String KW_FROM = "from"; private static final String KW_IN = "in"; @@ -483,10 +494,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep private static final String KW_OVER = "over"; private static final String KW_AS = "as"; private static final String KW_SHUTDOWN = "shutdown"; - private static final Set<String> KEYWORD_SET = new HashSet<>(); + private static final Set<String> KEYWORD_SET; static { - KEYWORD_SET.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY, + Set<String> keywordSet = new HashSet<>(); + keywordSet.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY, KW_LIMIT, KW_OFFSET, KW_SELECT, KW_DISTINCT, KW_WITHINGROUP, KW_OVER, KW_AS, KW_SHUTDOWN)); - KEYWORD_SET.addAll(JoinExpression.JOIN_NAME_LIST); + keywordSet.addAll(JoinExpression.JOIN_NAME_LIST); + KEYWORD_SET = Collections.unmodifiableSet(keywordSet); } } diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy index a32910a6c4..8c5301a7b1 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy @@ -542,11 +542,20 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable } if (expression instanceof BinaryExpression) { - if (expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) { + if (expression.operation.type in GinqAstBuilder.FILTER_BINARY_OP_SET) { + boolean containsGinqExpression = false + if (expression.leftExpression instanceof AbstractGinqExpression) { + expression.leftExpression = callSingleValue((AbstractGinqExpression) expression.leftExpression) + containsGinqExpression = true + } if (expression.rightExpression instanceof AbstractGinqExpression) { - expression.rightExpression = callX(visit((AbstractGinqExpression) expression.rightExpression), "toList") - return expression + expression.rightExpression = expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN] + ? callToList((AbstractGinqExpression) expression.rightExpression) + : callSingleValue((AbstractGinqExpression) expression.rightExpression) + containsGinqExpression = true } + + if (containsGinqExpression) return expression } } @@ -560,6 +569,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable return whereMethodCallExpression } + private MethodCallExpression callSingleValue(AbstractGinqExpression expression) { + return callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", visit(expression)) + } + + private MethodCallExpression callToList(AbstractGinqExpression expression) { + return callX(visit(expression), "toList") + } + @Override MethodCallExpression visitGroupExpression(GroupExpression groupExpression) { DataSourceExpression dataSourceExpression = groupExpression.dataSourceExpression diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy index ea382d959b..662d4ae042 100644 --- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy +++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy @@ -769,6 +769,259 @@ class GinqTest { ''' } + @Test + void "testGinq - nested from where select - 0"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [2] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) == n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n != (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) != n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 3"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where n > (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 3 - swap operand"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) < n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where n >= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) <= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where n < (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5 - swap operand"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) > n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where n <= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6 - swap operand"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) >= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n === (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) === n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n !== (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) !== n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 9"() { + assertGinqScript ''' + assert ['123'] == GQ { + from n in ['abc', '123', 'a1b2c3'] + where n ==~ (from m in [/[0-9]+/] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 10"() { + assertGinqScript ''' + assert [/[a-z]+/] == GQ { + from n in [/[0-9]+/, /[a-z]+/] + where (from m in ['abc'] select m) ==~ n + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where 2 > (from m in [1, 2, 3] where m < n select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m < n select max(m)) < 2) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where n == (from m in [1, 2, 3] where m >= n select min(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m >= n select min(m)) == n) + select n + }.toList() + ''' + } + @Test void "testGinq - nested from select - 0"() { assertGinqScript '''