This is an automated email from the ASF dual-hosted git repository. sunlan pushed a commit to branch GROOVY-11720 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit 59526b9196fe77a8fe23dd1a104d0bd2d7eb9f73 Author: Daniel Sun <sun...@apache.org> AuthorDate: Sat Jul 26 09:03:45 2025 +0900 GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause --- .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java | 37 ++- .../ginq/provider/collection/GinqAstWalker.groovy | 17 +- .../test/org/apache/groovy/ginq/GinqTest.groovy | 253 +++++++++++++++++++++ 3 files changed, 296 insertions(+), 11 deletions(-) diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java index f88b690c81..41e8240119 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java @@ -48,16 +48,28 @@ import org.codehaus.groovy.ast.stmt.BlockStatement; import org.codehaus.groovy.ast.stmt.ExpressionStatement; import org.codehaus.groovy.ast.stmt.Statement; import org.codehaus.groovy.control.SourceUnit; -import org.codehaus.groovy.syntax.Types; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.HashSet; import java.util.List; import java.util.Set; +import static org.codehaus.groovy.syntax.Types.COMPARE_EQUAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_GREATER_THAN; +import static org.codehaus.groovy.syntax.Types.COMPARE_GREATER_THAN_EQUAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_IDENTICAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_LESS_THAN; +import static org.codehaus.groovy.syntax.Types.COMPARE_LESS_THAN_EQUAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_EQUAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_IDENTICAL; +import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_IN; +import static org.codehaus.groovy.syntax.Types.KEYWORD_IN; +import static org.codehaus.groovy.syntax.Types.MATCH_REGEX; + /** * Build the AST for GINQ * @@ -187,7 +199,7 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep } final Expression expression = arguments.getExpression(0); if (!(expression instanceof BinaryExpression - && ((BinaryExpression) expression).getOperation().getType() == Types.KEYWORD_IN)) { + && ((BinaryExpression) expression).getOperation().getType() == KEYWORD_IN)) { this.collectSyntaxError( new GinqSyntaxError( "`in` is expected for `" + methodName + "`, e.g. `" + methodName + " n in nums`", @@ -376,12 +388,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep public void visitBinaryExpression(BinaryExpression expression) { super.visitBinaryExpression(expression); - final int opType = expression.getOperation().getType(); - if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) { - if (null != latestGinqExpression && isSelectMethodCallExpression(expression.getRightExpression())) { + final Integer opType = expression.getOperation().getType(); + if (FILTER_BINARY_OP_SET.contains(opType)) { + if (null != latestGinqExpression) { // use the nested ginq and clear it - expression.setRightExpression(latestGinqExpression); - latestGinqExpression = null; + if (isSelectMethodCallExpression(expression.getRightExpression())) { + expression.setRightExpression(latestGinqExpression); + latestGinqExpression = null; + } else if (isSelectMethodCallExpression(expression.getLeftExpression())) { + expression.setLeftExpression(latestGinqExpression); + latestGinqExpression = null; + } } } } @@ -464,8 +481,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep return sourceUnit; } - private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; + public static final Set<Integer> FILTER_BINARY_OP_SET = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + KEYWORD_IN, COMPARE_NOT_IN, COMPARE_IDENTICAL, COMPARE_NOT_IDENTICAL, + COMPARE_EQUAL, COMPARE_NOT_EQUAL, COMPARE_LESS_THAN, COMPARE_LESS_THAN_EQUAL, + COMPARE_GREATER_THAN, COMPARE_GREATER_THAN_EQUAL, MATCH_REGEX))); + private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; private static final String KW_WITH = "with"; // reserved keyword private static final String KW_FROM = "from"; private static final String KW_IN = "in"; diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy index a32910a6c4..28a1fe60a4 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy @@ -542,11 +542,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable } if (expression instanceof BinaryExpression) { - if (expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) { + if (expression.operation.type in GinqAstBuilder.FILTER_BINARY_OP_SET) { + boolean containsGinqExpression = false + if (expression.leftExpression instanceof AbstractGinqExpression) { + expression.leftExpression = callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", visit((AbstractGinqExpression) expression.leftExpression)) + containsGinqExpression = true + } if (expression.rightExpression instanceof AbstractGinqExpression) { - expression.rightExpression = callX(visit((AbstractGinqExpression) expression.rightExpression), "toList") - return expression + if (expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) { + expression.rightExpression = callX(visit((AbstractGinqExpression) expression.rightExpression), "toList") + } else { + expression.rightExpression = callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", visit((AbstractGinqExpression) expression.rightExpression)) + } + containsGinqExpression = true } + + if (containsGinqExpression) return expression } } diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy index ea382d959b..662d4ae042 100644 --- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy +++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy @@ -769,6 +769,259 @@ class GinqTest { ''' } + @Test + void "testGinq - nested from where select - 0"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [2] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) == n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n != (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) != n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 3"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where n > (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 3 - swap operand"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) < n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where n >= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) <= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where n < (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5 - swap operand"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) > n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where n <= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6 - swap operand"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) >= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n === (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) === n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n !== (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) !== n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 9"() { + assertGinqScript ''' + assert ['123'] == GQ { + from n in ['abc', '123', 'a1b2c3'] + where n ==~ (from m in [/[0-9]+/] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 10"() { + assertGinqScript ''' + assert [/[a-z]+/] == GQ { + from n in [/[0-9]+/, /[a-z]+/] + where (from m in ['abc'] select m) ==~ n + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where 2 > (from m in [1, 2, 3] where m < n select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m < n select max(m)) < 2) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where n == (from m in [1, 2, 3] where m >= n select min(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m >= n select min(m)) == n) + select n + }.toList() + ''' + } + @Test void "testGinq - nested from select - 0"() { assertGinqScript '''