This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-11720
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 59526b9196fe77a8fe23dd1a104d0bd2d7eb9f73
Author: Daniel Sun <sun...@apache.org>
AuthorDate: Sat Jul 26 09:03:45 2025 +0900

    GROOVY-11720: [GINQ] Failed to recognize sub-query in where clause
---
 .../org/apache/groovy/ginq/dsl/GinqAstBuilder.java |  37 ++-
 .../ginq/provider/collection/GinqAstWalker.groovy  |  17 +-
 .../test/org/apache/groovy/ginq/GinqTest.groovy    | 253 +++++++++++++++++++++
 3 files changed, 296 insertions(+), 11 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
index f88b690c81..41e8240119 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/dsl/GinqAstBuilder.java
@@ -48,16 +48,28 @@ import org.codehaus.groovy.ast.stmt.BlockStatement;
 import org.codehaus.groovy.ast.stmt.ExpressionStatement;
 import org.codehaus.groovy.ast.stmt.Statement;
 import org.codehaus.groovy.control.SourceUnit;
-import org.codehaus.groovy.syntax.Types;
 
 import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
+import static org.codehaus.groovy.syntax.Types.COMPARE_EQUAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_GREATER_THAN;
+import static org.codehaus.groovy.syntax.Types.COMPARE_GREATER_THAN_EQUAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_IDENTICAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_LESS_THAN;
+import static org.codehaus.groovy.syntax.Types.COMPARE_LESS_THAN_EQUAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_EQUAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_IDENTICAL;
+import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_IN;
+import static org.codehaus.groovy.syntax.Types.KEYWORD_IN;
+import static org.codehaus.groovy.syntax.Types.MATCH_REGEX;
+
 /**
  * Build the AST for GINQ
  *
@@ -187,7 +199,7 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
             }
             final Expression expression = arguments.getExpression(0);
             if (!(expression instanceof BinaryExpression
-                    && ((BinaryExpression) 
expression).getOperation().getType() == Types.KEYWORD_IN)) {
+                    && ((BinaryExpression) 
expression).getOperation().getType() == KEYWORD_IN)) {
                 this.collectSyntaxError(
                         new GinqSyntaxError(
                                 "`in` is expected for `" + methodName + "`, 
e.g. `" + methodName + " n in nums`",
@@ -376,12 +388,17 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
     public void visitBinaryExpression(BinaryExpression expression) {
         super.visitBinaryExpression(expression);
 
-        final int opType = expression.getOperation().getType();
-        if (opType == Types.KEYWORD_IN || opType == Types.COMPARE_NOT_IN) {
-            if (null != latestGinqExpression && 
isSelectMethodCallExpression(expression.getRightExpression())) {
+        final Integer opType = expression.getOperation().getType();
+        if (FILTER_BINARY_OP_SET.contains(opType)) {
+            if (null != latestGinqExpression) {
                 // use the nested ginq and clear it
-                expression.setRightExpression(latestGinqExpression);
-                latestGinqExpression = null;
+                if 
(isSelectMethodCallExpression(expression.getRightExpression())) {
+                    expression.setRightExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                } else if 
(isSelectMethodCallExpression(expression.getLeftExpression())) {
+                    expression.setLeftExpression(latestGinqExpression);
+                    latestGinqExpression = null;
+                }
             }
         }
     }
@@ -464,8 +481,12 @@ public class GinqAstBuilder extends CodeVisitorSupport implements SyntaxErrorRep
         return sourceUnit;
     }
 
-    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
+    public static final Set<Integer> FILTER_BINARY_OP_SET = 
Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+        KEYWORD_IN, COMPARE_NOT_IN, COMPARE_IDENTICAL, COMPARE_NOT_IDENTICAL,
+        COMPARE_EQUAL, COMPARE_NOT_EQUAL, COMPARE_LESS_THAN, 
COMPARE_LESS_THAN_EQUAL,
+        COMPARE_GREATER_THAN, COMPARE_GREATER_THAN_EQUAL, MATCH_REGEX)));
 
+    private static final String __LATEST_GINQ_EXPRESSION_CLAUSE = 
"__latestGinqExpressionClause";
     private static final String KW_WITH = "with"; // reserved keyword
     private static final String KW_FROM = "from";
     private static final String KW_IN = "in";
diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index a32910a6c4..28a1fe60a4 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -542,11 +542,22 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                 }
 
                 if (expression instanceof BinaryExpression) {
-                    if (expression.operation.type in [Types.KEYWORD_IN, 
Types.COMPARE_NOT_IN]) {
+                    if (expression.operation.type in 
GinqAstBuilder.FILTER_BINARY_OP_SET) {
+                        boolean containsGinqExpression = false
+                        if (expression.leftExpression instanceof 
AbstractGinqExpression) {
+                            expression.leftExpression = 
callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", 
visit((AbstractGinqExpression) expression.leftExpression))
+                            containsGinqExpression = true
+                        }
                         if (expression.rightExpression instanceof 
AbstractGinqExpression) {
-                            expression.rightExpression = 
callX(visit((AbstractGinqExpression) expression.rightExpression), "toList")
-                            return expression
+                            if (expression.operation.type in 
[Types.KEYWORD_IN, Types.COMPARE_NOT_IN]) {
+                                expression.rightExpression = 
callX(visit((AbstractGinqExpression) expression.rightExpression), "toList")
+                            } else {
+                                expression.rightExpression = 
callX(classX(QUERYABLE_HELPER_TYPE), "singleValue", 
visit((AbstractGinqExpression) expression.rightExpression))
+                            }
+                            containsGinqExpression = true
                         }
+
+                        if (containsGinqExpression) return expression
                     }
                 }
 
diff --git a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
index ea382d959b..662d4ae042 100644
--- a/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
+++ b/subprojects/groovy-ginq/src/spec/test/org/apache/groovy/ginq/GinqTest.groovy
@@ -769,6 +769,259 @@ class GinqTest {
         '''
     }
 
+    @Test
+    void "testGinq - nested from where select - 0"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [2] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n == (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 1 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n != (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 2 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) != n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where n > (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 3 - swap operand"() {
+        assertGinqScript '''
+            assert [3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) < n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where n >= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 4 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) <= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where n < (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 5 - swap operand"() {
+        assertGinqScript '''
+            assert [1] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) > n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where n <= (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 6 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) >= n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where n === (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 7 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) === n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where n !== (from m in [1, 2] select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 8 - swap operand"() {
+        assertGinqScript '''
+            assert [1, 3] == GQ {
+                from n in [1, 2, 3]
+                where ((from m in [1, 2] select max(m)) !== n)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 9"() {
+        assertGinqScript '''
+            assert ['123'] == GQ {
+                from n in ['abc', '123', 'a1b2c3']
+                where n ==~ (from m in [/[0-9]+/] select m)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 10"() {
+        assertGinqScript '''
+            assert [/[a-z]+/] == GQ {
+                from n in [/[0-9]+/, /[a-z]+/]
+                where (from m in ['abc'] select m) ==~ n
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where 2 > (from m in [1, 2, 3] where m < n select max(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 11 - swap operand"() {
+        assertGinqScript '''
+            assert [2] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m < n select max(m)) < 2)
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where n == (from m in [1, 2, 3] where m >= n select min(m))
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - nested from where select - 12 - swap operand"() {
+        assertGinqScript '''
+            assert [2, 3] == GQ {
+                from n in [2, 3, 4]
+                where ((from m in [1, 2, 3] where m >= n select min(m)) == n)
+                select n
+            }.toList()
+        '''
+    }
+
     @Test
     void "testGinq - nested from select - 0"() {
         assertGinqScript '''

Reply via email to