This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-11720
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit fee52236ac3ff8e97a56256d761a324d3c663bce
Author: Daniel Sun <sun...@apache.org>
AuthorDate: Sat Jul 26 09:03:45 2025 +0900

    GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause
---
 .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java |  31 ++-
 .../ginq/provider/collection/GinqAstWalker.groovy  |  23 +-
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 253 +++++++++++++++++++++
 3 files changed, 295 insertions(+), 12 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index f88b690c81..ba60c2a6b7 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -53,11 +53,13 @@ import org.codehaus.groovy.syntax.Types;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+
 /**
  * Build the AST for GINQ
  *
@@ -376,12 +378,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     public void visitBinaryExpression(BinaryExpression expression) {
         super.visitBinaryExpression(expression);
 
-        final int opType = expression.getOperation().getType();
-        if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) {
-            if (null != latestGinqExpression && 
isSelectMethodCallExpression(expression.getRightExpression())) {
+        final Integer opType = expression.getOperation().getType();
+        if (FILTER_BINARY_OP_SET.contains(opType)) {
+            if (null != latestGinqExpression) {
                 // use the nested ginq and clear it
-                expression.setRightExpression(latestGinqExpression);
-                latestGinqExpression = null;
+                if 
(isSelectMethodCallExpression(expression.getRightExpression())) {
+                    expression.setRightExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                } else if 
(isSelectMethodCallExpression(expression.getLeftExpression())) {
+                    expression.setLeftExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                }
             }
         }
     }
@@ -464,8 +471,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
         return sourceUnit;
     }
 
-    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
+    public static final Set<Integer> FILTER_BINARY_OP_SET = 
Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+        Types.KEYWORD_IN, Types.COMPARE_NOT_IN, Types.COMPARE_IDENTICAL, 
Types.COMPARE_NOT_IDENTICAL,
+        Types.COMPARE_EQUAL, Types.COMPARE_NOT_EQUAL, Types.COMPARE_LESS_THAN, 
Types.COMPARE_LESS_THAN_EQUAL,
+        Types.COMPARE_GREATER_THAN, Types.COMPARE_GREATER_THAN_EQUAL, 
Types.MATCH_REGEX)));
 
+    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
     private static final String KW_WITH = "with"; // reserved keyword
     private static final String KW_FROM = "from";
     private static final String KW_IN = "in";
@@ -483,10 +494,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     private static final String KW_OVER = "over";
     private static final String KW_AS = "as";
     private static final String KW_SHUTDOWN = "shutdown";
-    private static final Set<String> KEYWORD_SET = new HashSet<>();
+    private static final Set<String> KEYWORD_SET;
     static {
-        KEYWORD_SET.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, 
KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
+        Set<String> keywordSet = new HashSet<>();
+        keywordSet.addAll(Arrays.asList(KW_WITH, KW_FROM, KW_IN, KW_ON, 
KW_WHERE, KW_EXISTS, KW_GROUPBY, KW_HAVING, KW_ORDERBY,
                                          KW_LIMIT, KW_OFFSET, KW_SELECT, 
KW_DISTINCT, KW_WITHINGROUP, KW_OVER, KW_AS, KW_SHUTDOWN));
-        KEYWORD_SET.addAll(JoinExpression.JOIN_NAME_LIST);
+        keywordSet.addAll(JoinExpression.JOIN_NAME_LIST);
+        KEYWORD_SET = Collections.unmodifiableSet(keywordSet);
     }
 }
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index a32910a6c4..8c5301a7b1 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -542,11 +542,20 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                 }
 
                 if (expression instanceof BinaryExpression) {
-                    if (expression.operation.type in [Types.KEYWORD_IN, 
Types.COMPARE_NOT_IN]) {
+                    if (expression.operation.type in 
GinqAstBuilder.FILTER_BINARY_OP_SET) {
+                        boolean containsGinqExpression = false
+                        if (expression.leftExpression instanceof 
AbstractGinqExpression) {
+                            expression.leftExpression = 
callSingleValue((AbstractGinqExpression) expression.leftExpression)
+                            containsGinqExpression = true
+                        }
                         if (expression.rightExpression instanceof 
AbstractGinqExpression) {
-                            expression.rightExpression = 
callX(visit((AbstractGinqExpression) expression.rightExpression), "toList")
-                            return expression
+                            expression.rightExpression = 
expression.operation.type in [Types.KEYWORD_IN, Types.COMPARE_NOT_IN]
+                                                            ? 
callToList((AbstractGinqExpression) expression.rightExpression)
+                                                            : 
callSingleValue((AbstractGinqExpression) expression.rightExpression)
+                            containsGinqExpression = true
                         }
+
+                        if (containsGinqExpression) return expression
                     }
                 }
 
@@ -560,6 +569,14 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return whereMethodCallExpression
     }
 
+    private MethodCallExpression callSingleValue(AbstractGinqExpression 
expression) {
+        return callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", 
visit(expression))
+    }
+
+    private MethodCallExpression callToList(AbstractGinqExpression expression) 
{
+        return callX(visit(expression), "toList")
+    }
+
     @Override
     MethodCallExpression visitGroupExpression(GroupExpression groupExpression) 
{
         DataSourceExpression dataSourceExpression = 
groupExpression.dataSourceExpression
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index ea382d959b..662d4ae042 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -769,6 +769,259 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - nested from where select - 0"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [2] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n != (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) != n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where n > (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3 - swap operand"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) < n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where n >= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) <= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where n < (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5 - swap operand"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) > n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where n <= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) >= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n === (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) === n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n !== (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) !== n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 9"() {
+        assertGinqScript '''
+            assert ['123'] == GQ {
+                from n in ['abc', '123', 'a1b2c3']
+                where n ==~ (from m in [/[0-9]+/] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 10"() {
+        assertGinqScript '''
+            assert [/[a-z]+/] == GQ {
+                from n in [/[0-9]+/, /[a-z]+/]
+                where (from m in ['abc'] select m) ==~ n
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where 2 > (from m in [1, 2, 3] where m < n select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m < n select max(m)) < 2)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where n == (from m in [1, 2, 3] where m >= n select min(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m >= n select min(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
     @Test
     void "testGinq - nested from select - 0"() {
         assertGinqScript '''

Reply via email to