This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new e58a9e1712 GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause (#2273)
e58a9e1712 is described below

commit e58a9e1712725016dd2035bc9916ddeedef8d8e6
Author: Daniel Sun <sun...@apache.org>
AuthorDate: Mon Jul 28 04:29:58 2025 +0900

    GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause (#2273)
---
 .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java |  31 ++-
 .../ginq/provider/collection/GinqAstWalker.groovy  |  38 +++-
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 253 +++++++++++++++++++++
 3 files changed, 302 insertions(+), 20 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index f88b690c81..beb5e39220 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -53,11 +53,13 @@ import org.codehaus.groovy.syntax.Types;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+
 /**
  * Build the AST for GINQ
  *
@@ -376,12 +378,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     public void visitBinaryExpression(BinaryExpression expression) {
         super.visitBinaryExpression(expression);
 
-        final int opType = expression.getOperation().getType();
-        if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) {
-            if (null != latestGinqExpression && 
isSelectMethodCallExpression(expression.getRightExpression())) {
+        final Integer opType = expression.getOperation().getType();
+        if (FILTER_BINARY_OP_SET.contains(opType)) {
+            if (null != latestGinqExpression) {
                 // use the nested ginq and clear it
-                expression.setRightExpression(latestGinqExpression);
-                latestGinqExpression = null;
+                if 
(isSelectMethodCallExpression(expression.getRightExpression())) {
+                    expression.setRightExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                } else if 
(isSelectMethodCallExpression(expression.getLeftExpression())) {
+                    expression.setLeftExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                }
             }
         }
     }
@@ -464,8 +471,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
         return sourceUnit;
     }
 
-    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
+    private static final Set<Integer> FILTER_BINARY_OP_SET = 
Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+        Types.KEYWORD_IN, Types.COMPARE_NOT_IN, Types.COMPARE_IDENTICAL, 
Types.COMPARE_NOT_IDENTICAL,
+        Types.COMPARE_EQUAL, Types.COMPARE_NOT_EQUAL, Types.COMPARE_LESS_THAN, 
Types.COMPARE_LESS_THAN_EQUAL,
+        Types.COMPARE_GREATER_THAN, Types.COMPARE_GREATER_THAN_EQUAL, 
Types.MATCH_REGEX)));
 
+    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
     private static final String KW_WITH = "with"; // reserved keyword
     private static final String KW_FROM = "from";
     private static final String KW_IN = "in";
@@ -483,10 +494,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     private static final String KW_OVER = "over";
     private static final String KW_AS = "as";
     private static final String KW_SHUTDOWN = "shutdown";
-    private static final Set<String> KEYWORD_SET = new HashSet<>();
+    private static final Set<String> KEYWORD_SET;
     static {
-        KEYWORD_SET.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, 
KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
+        Set<String> keywordSet = new HashSet<>();
+        keywordSet.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, 
KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
                                          KW_LIMIT, KW_OFFSET, KW_SELECT, 
KW_DISTINCT, KW_WITHINGROUP, KW_OVER, KW_AS, KW_SHUTDOWN));
-        KEYWORD_SET.addAll(JoinExpression.JOIN_NAME_LIST);
+        keywordSet.addAll(JoinExpression.JOIN_NAME_LIST);
+        KEYWORD_SET = Collections.unmodifiableSet(keywordSet);
     }
 }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index a32910a6c4..a0690e13f2 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -263,7 +263,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
 
         if (expr instanceof MethodCallExpression) {
             MethodCallExpression call = (MethodCallExpression) expr
-            if (call.implicitThis && 
AGG_FUNCTION_NAME_LIST.contains(call.methodAsString)) {
+            if (call.implicitThis && 
AGG_FUNCTION_NAME_SET.contains(call.methodAsString)) {
                 def argumentCnt = ((ArgumentListExpression) 
call.getArguments()).getExpressions().size()
                 if (1 == argumentCnt || (FUNCTION_COUNT == call.methodAsString 
&& 0 == argumentCnt)) {
                     return true
@@ -542,12 +542,19 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                 }
 
                 if (expression instanceof BinaryExpression) {
-                    if (expression.operation.type in [Types.KEYWORD_IN, 
Types.COMPARE_NOT_IN]) {
-                        if (expression.rightExpression instanceof 
AbstractGinqExpression) {
-                            expression.rightExpression = 
callX(visit((AbstractGinqExpression) expression.rightExpression), "toList")
-                            return expression
-                        }
+                    boolean containsGinqExpression = false
+                    if (expression.leftExpression instanceof 
AbstractGinqExpression) {
+                        expression.leftExpression = 
callSingleValue((AbstractGinqExpression) expression.leftExpression)
+                        containsGinqExpression = true
+                    }
+                    if (expression.rightExpression instanceof 
AbstractGinqExpression) {
+                        expression.rightExpression = expression.operation.type 
in IN_OP_SET
+                            ? callToList((AbstractGinqExpression) 
expression.rightExpression)
+                            : callSingleValue((AbstractGinqExpression) 
expression.rightExpression)
+                        containsGinqExpression = true
                     }
+
+                    if (containsGinqExpression) return expression
                 }
 
                 return expression.transformExpression(this)
@@ -560,6 +567,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return whereMethodCallExpression
     }
 
+    private MethodCallExpression callSingleValue(AbstractGinqExpression 
expression) {
+        return callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", 
visit(expression))
+    }
+
+    private MethodCallExpression callToList(AbstractGinqExpression expression) 
{
+        return callX(visit(expression), "toList")
+    }
+
     @Override
     MethodCallExpression visitGroupExpression(GroupExpression groupExpression) 
{
         DataSourceExpression dataSourceExpression = 
groupExpression.dataSourceExpression
@@ -743,7 +758,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                             def windowFunctionMethodCallExpression = 
(MethodCallExpression) expression.objectExpression
 
                             Expression result = null
-                            if 
(windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_LIST) {
+                            if 
(windowFunctionMethodCallExpression.methodAsString in WINDOW_FUNCTION_SET) {
                                 def argumentListExpression = 
(ArgumentListExpression) windowFunctionMethodCallExpression.arguments
                                 List<Expression> argumentExpressionList = []
                                 if 
(windowFunctionMethodCallExpression.methodAsString !in [FUNCTION_ROW_NUMBER, 
FUNCTION_RANK, FUNCTION_DENSE_RANK, FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST] 
&& argumentListExpression.expressions) {
@@ -1304,7 +1319,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                     if (FUNCTION_COUNT == methodName && ((TupleExpression) 
expression.arguments).getExpressions().isEmpty()) { // Similar to count(*) in 
SQL
                         expression.objectExpression = varX(__GROUP)
                         transformedExpression = expression
-                    } else if (methodName in AGG_FUNCTION_NAME_LIST) {
+                    } else if (methodName in AGG_FUNCTION_NAME_SET) {
                         Expression lambdaCode = ((TupleExpression) 
expression.arguments).getExpression(0)
                         lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, 
findRootObjectExpression(lambdaCode).text)
                         transformedExpression =
@@ -1563,7 +1578,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final String FUNCTION_VAR = 'var'
     private static final String FUNCTION_VARP = 'varp'
     private static final String FUNCTION_AGG = 'agg'
-    private static final List<String> AGG_FUNCTION_NAME_LIST = 
[FUNCTION_COUNT, FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, 
FUNCTION_MEDIAN, FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, 
FUNCTION_LIST, FUNCTION_AGG]
+    private static final Set<String> AGG_FUNCTION_NAME_SET = [FUNCTION_COUNT, 
FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, 
FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_LIST, 
FUNCTION_AGG] as HashSet
 
     private static final String FUNCTION_ROW_NUMBER = 'rowNumber'
     private static final String FUNCTION_LEAD = 'lead'
@@ -1576,9 +1591,10 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
     private static final String FUNCTION_PERCENT_RANK = 'percentRank'
     private static final String FUNCTION_CUME_DIST = 'cumeDist'
     private static final String FUNCTION_NTILE = 'ntile'
-    private static final List<String> WINDOW_FUNCTION_LIST = [FUNCTION_COUNT, 
FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, 
FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG,
-                                                              
FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, 
FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, 
FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE]
+    private static final Set<String> WINDOW_FUNCTION_SET = [FUNCTION_COUNT, 
FUNCTION_MIN, FUNCTION_MAX, FUNCTION_SUM, FUNCTION_AVG, FUNCTION_MEDIAN, 
FUNCTION_STDEV, FUNCTION_STDEVP, FUNCTION_VAR, FUNCTION_VARP, FUNCTION_AGG,
+                                                            
FUNCTION_ROW_NUMBER, FUNCTION_LEAD, FUNCTION_LAG, FUNCTION_FIRST_VALUE, 
FUNCTION_LAST_VALUE, FUNCTION_NTH_VALUE, FUNCTION_RANK, FUNCTION_DENSE_RANK, 
FUNCTION_PERCENT_RANK, FUNCTION_CUME_DIST, FUNCTION_NTILE] as HashSet
 
+    private static final Set<Integer> IN_OP_SET = [Types.KEYWORD_IN, 
Types.COMPARE_NOT_IN] as HashSet
     private static final String NAMEDRECORD_CLASS_NAME = NamedRecord.class.name
 
     private static final String USE_WINDOW_FUNCTION = 'useWindowFunction'
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index ea382d959b..662d4ae042 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -769,6 +769,259 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - nested from where select - 0"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [2] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n != (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) != n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where n > (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3 - swap operand"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) < n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where n >= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) <= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where n < (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5 - swap operand"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) > n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where n <= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) >= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n === (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) === n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n !== (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) !== n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 9"() {
+        assertGinqScript '''
+            assert ['123'] == GQ {
+                from n in ['abc', '123', 'a1b2c3']
+                where n ==~ (from m in [/[0-9]+/] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 10"() {
+        assertGinqScript '''
+            assert [/[a-z]+/] == GQ {
+                from n in [/[0-9]+/, /[a-z]+/]
+                where (from m in ['abc'] select m) ==~ n
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where 2 > (from m in [1, 2, 3] where m < n select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m < n select max(m)) < 2)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where n == (from m in [1, 2, 3] where m >= n select min(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m >= n select min(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
     @Test
     void "testGinq - nested from select - 0"() {
         assertGinqScript '''

Reply via email to