This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-11720
in repository https://gitbox.apache.org/repos/asf/groovy.git
commit 38a2680aa282456264dd88af941c4db910cd1d64 Author: Daniel Sun <sun...@apache.org> AuthorDate: Sat Jul 26 09:03:45 2025 +0900 GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause --- .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java | 31 ++- .../ginq/provider/collection/GinqAstWalker.groovy | 36 ++- .../test/org/apache/groovy/ginq/GinqTest.groovy | 253 +++++++++++++++++++++ 3 files changed, 302 insertions(+), 18 deletions(-) diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java index f88b690c81..ba60c2a6b7 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java @@ -53,11 +53,13 @@ import org.codehaus.groovy.syntax.Types; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Deque; import java.util.HashSet; import java.util.List; import java.util.Set; + /** * Build the AST for GINQ * @@ -376,12 +378,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep public void visitBinaryExpression(BinaryExpression expression) { super.visitBinaryExpression(expression); - final int opType = expression.getOperation().getType(); - if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) { - if (null != latestGinqExpression && isSelectMethodCallExpression(expression.getRightExpression())) { + final Integer opType = expression.getOperation().getType(); + if (FILTER_BINARY_OP_SET.contains(opType)) { + if (null != latestGinqExpression) { // use the nested ginq and clear it - expression.setRightExpression(latestGinqExpression); - latestGinqExpression = null; + if (isSelectMethodCallExpression(expression.getRightExpression())) { + expression.setRightExpression(latestGinqExpression); + 
latestGinqExpression = null; + } else if (isSelectMethodCallExpression(expression.getLeftExpression())) { + expression.setLeftExpression(latestGinqExpression); + latestGinqExpression = null; + } } } } @@ -464,8 +471,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep return sourceUnit; } - private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; + public static final Set<Integer> FILTER_BINARY_OP_SET = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( + Types.KEYWORD_IN, Types.COMPARE_NOT_IN, Types.COMPARE_IDENTICAL, Types.COMPARE_NOT_IDENTICAL, + Types.COMPARE_EQUAL, Types.COMPARE_NOT_EQUAL, Types.COMPARE_LESS_THAN, Types.COMPARE_LESS_THAN_EQUAL, + Types.COMPARE_GREATER_THAN, Types.COMPARE_GREATER_THAN_EQUAL, Types.MATCH_REGEX))); + private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = "__latestGinqExpressionClause"; private static final String KW_WITH = "with"; // reserved keyword private static final String KW_FROM = "from"; private static final String KW_IN = "in"; @@ -483,10 +494,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep private static final String KW_OVER = "over"; private static final String KW_AS = "as"; private static final String KW_SHUTDOWN = "shutdown"; - private static final Set<String> KEYWORD_SET = new HashSet<>(); + private static final Set<String> KEYWORD_SET; static { - KEYWORD_SET.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY, + Set<String> keywordSet = new HashSet<>(); + keywordSet.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY, KW_LIMIT, KW_OFFSET, KW_SELECT, KW_DISTINCT, KW_WITHINGROUP, KW_OVER, KW_AS, KW_SHUTDOWN)); - KEYWORD_SET.addAll(JoinExpression.JOIN_NAME_LIST); + keywordSet.addAll(JoinExpression.JOIN_NAME_LIST); + KEYWORD_SET = Collections.unmodifiableSet(keywordSet); } } diff --git 
a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy index a32910a6c4..1c74eb907e 100644 --- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy +++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy @@ -263,7 +263,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable if (expr instanceof MethodCallExpression) { MethodCallExpression call = (MethodCallExpression) expr - if (call.implicitThis && AGG_FUNCTION_NAME_LIST.contains(call.methodAsString)) { + if (call.implicitThis && AGG_FUNCTION_NAME_SET.contains(call.methodAsString)) { def argumentCnt = ((ArgumentListExpression) call.getArguments()).getExpressions().size() if (1 == argumentCnt || (FUNCTION_COUNT == call.methodAsString && 0 == argumentCnt)) { return true @@ -542,11 +542,20 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable } if (expression instanceof BinaryExpression) { - if (expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) { + if (expression.operation.type in GinqAstBuilder.FILTER_BINARY_OP_SET) { + boolean containsGinqExpression = false + if (expression.leftExpression instanceof AbstractGinqExpression) { + expression.leftExpression = callSingleValue((AbstractGinqExpression) expression.leftExpression) + containsGinqExpression = true + } if (expression.rightExpression instanceof AbstractGinqExpression) { - expression.rightExpression = callX(visit((AbstractGinqExpression) expression.rightExpression), "toList") - return expression + expression.rightExpression = expression.operation.type in IN_OP_SET + ? 
callToList((AbstractGinqExpression) expression.rightExpression) + : callSingleValue((AbstractGinqExpression) expression.rightExpression) + containsGinqExpression = true } + + if (containsGinqExpression) return expression } } @@ -560,6 +569,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable return whereMethodCallExpression } + private MethodCallExpression callSingleValue(AbstractGinqExpression expression) { + return callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", visit(expression)) + } + + private MethodCallExpression callToList(AbstractGinqExpression expression) { + return callX(visit(expression), "toList") + } + @Override MethodCallExpression visitGroupExpression(GroupExpression groupExpression) { DataSourceExpression dataSourceExpression = groupExpression.dataSourceExpression @@ -743,7 +760,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable def windowFunctionMethodCallExpression = (MethodCallExpression) expression.objectExpression Expression result = null - if (windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_LIST) { + if (windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_SET) { def argumentListExpression = (ArgumentListExpression) windowFunctionMethodCallExpression.arguments List<Expression> argumentExpressionList = [] if (windowFunctionMethodCallExpression.methodAsString !in [FUNCTION_ROW_NUMBER, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST] && argumentListExpression.expressions) { @@ -1304,7 +1321,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable if (FUNCTION_COUNT == methodName && ((TupleExpression) expression.arguments).getExpressions().isEmpty()) { // Similar to count(*) in SQL expression.objectExpression = varX(__GROUP) transformedExpression = expression - } else if (methodName in AGG_FUNCTION_NAME_LIST) { + } else if (methodName in AGG_FUNCTION_NAME_SET) { Expression 
lambdaCode = ((TupleExpression) expression.arguments).getExpression(0) lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, findRootObjectExpression(lambdaCode).text) transformedExpression = @@ -1563,7 +1580,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable private static final String FUNCTION_VAR = 'var' private static final String FUNCTION_VARP = 'varp' private static final String FUNCTION_AGG = 'agg' - private static final List<String> AGG_FUNCTION_NAME_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_LIST, FUNCTION_AGG] + private static final Set<String> AGG_FUNCTION_NAME_SET = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_LIST, FUNCTION_AGG] as HashSet private static final String FUNCTION_ROW_NUMBER = 'rowNumber' private static final String FUNCTION_LEAD = 'lead' @@ -1576,9 +1593,10 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable private static final String FUNCTION_PERCENT_RANK = 'percentRank' private static final String FUNCTION_CUME_DIST = 'cumeDist' private static final String FUNCTION_NTILE = 'ntile' - private static final List<String> WINDOW_FUNCTION_LIST = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG, - FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE] + private static final Set<String> WINDOW_FUNCTION_SET = [FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG, + 
FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE] as HashSet + private static final Set<Integer> IN_OP_SET = [Types.KEYWORD_IN, Types.COMPARE_NOT_IN] as HashSet private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name private static final String USE_WINDOW_FUNCTION = 'useWindowFunction' diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy index ea382d959b..662d4ae042 100644 --- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy +++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy @@ -769,6 +769,259 @@ class GinqTest { ''' } + @Test + void "testGinq - nested from where select - 0"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [2] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n == (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 1 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) == n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n != (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 2 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) != n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from 
where select - 3"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where n > (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 3 - swap operand"() { + assertGinqScript ''' + assert [3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) < n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where n >= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 4 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) <= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where n < (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 5 - swap operand"() { + assertGinqScript ''' + assert [1] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) > n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where n <= (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 6 - swap operand"() { + assertGinqScript ''' + assert [1, 2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) >= n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where n === (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 7 - swap operand"() { + 
assertGinqScript ''' + assert [2] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) === n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where n !== (from m in [1, 2] select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 8 - swap operand"() { + assertGinqScript ''' + assert [1, 3] == GQ { + from n in [1, 2, 3] + where ((from m in [1, 2] select max(m)) !== n) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 9"() { + assertGinqScript ''' + assert ['123'] == GQ { + from n in ['abc', '123', 'a1b2c3'] + where n ==~ (from m in [/[0-9]+/] select m) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 10"() { + assertGinqScript ''' + assert [/[a-z]+/] == GQ { + from n in [/[0-9]+/, /[a-z]+/] + where (from m in ['abc'] select m) ==~ n + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where 2 > (from m in [1, 2, 3] where m < n select max(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 11 - swap operand"() { + assertGinqScript ''' + assert [2] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m < n select max(m)) < 2) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where n == (from m in [1, 2, 3] where m >= n select min(m)) + select n + }.toList() + ''' + } + + @Test + void "testGinq - nested from where select - 12 - swap operand"() { + assertGinqScript ''' + assert [2, 3] == GQ { + from n in [2, 3, 4] + where ((from m in [1, 2, 3] where m >= n select min(m)) == n) + select n + }.toList() + ''' + } + @Test void 
"testGinq - nested from select - 0"() { assertGinqScript '''