This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-8258
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/GROOVY-8258 by this push:
     new 5aeb3cf  GROOVY-8258: add more test cases
5aeb3cf is described below

commit 5aeb3cfa62645b00d49833463b55ee2bc0d90123
Author: Daniel Sun <[email protected]>
AuthorDate: Thu Oct 8 02:47:20 2020 +0800

    GROOVY-8258: add more test cases
---
 .../org/apache/groovy/linq/dsl/GinqAstBuilder.java | 12 ++++---
 .../linq/provider/collection/GinqAstWalker.groovy  | 38 ++++++++++++++--------
 .../groovy/org/apache/groovy/linq/GinqTest.groovy  | 24 ++++++++++++++
 3 files changed, 56 insertions(+), 18 deletions(-)

diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
index 1a7c79e..fa58543 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/dsl/GinqAstBuilder.java
@@ -36,13 +36,16 @@ import org.codehaus.groovy.ast.expr.MethodCallExpression;
 import org.codehaus.groovy.control.SourceUnit;
 import org.codehaus.groovy.syntax.Types;
 
+import java.util.ArrayDeque;
+import java.util.Deque;
+
 /**
  * Build the AST for GINQ
  *
  * @since 4.0.0
  */
 public class GinqAstBuilder extends CodeVisitorSupport implements 
SyntaxErrorReportable {
-    private SimpleGinqExpression currentSimpleGinqExpression;
+    private Deque<SimpleGinqExpression> simpleGinqExpressionStack = new 
ArrayDeque<>();
     private SimpleGinqExpression latestSimpleGinqExpression;
     private GinqExpression ginqExpression; // store the return value
     private final SourceUnit sourceUnit;
@@ -61,9 +64,11 @@ public class GinqAstBuilder extends CodeVisitorSupport 
implements SyntaxErrorRep
         final String methodName = call.getMethodAsString();
 
         if ("from".equals(methodName)) {
-            currentSimpleGinqExpression = new SimpleGinqExpression(); // store 
the result
+            simpleGinqExpressionStack.push(new SimpleGinqExpression()); // 
store the result
         }
 
+        SimpleGinqExpression currentSimpleGinqExpression = 
simpleGinqExpressionStack.peek();
+
         if ("from".equals(methodName)  || 
JoinExpression.isJoinExpression(methodName)) {
             ArgumentListExpression arguments = (ArgumentListExpression) 
call.getArguments();
             if (arguments.getExpressions().size() != 1) {
@@ -141,8 +146,7 @@ public class GinqAstBuilder extends CodeVisitorSupport 
implements SyntaxErrorRep
             currentSimpleGinqExpression.setSelectExpression(selectExpression);
             ginqExpression = selectExpression;
 
-            latestSimpleGinqExpression = currentSimpleGinqExpression;
-            currentSimpleGinqExpression = null;
+            latestSimpleGinqExpression = simpleGinqExpressionStack.pop();
 
             return;
         }
diff --git a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
index 63328c3..779c215 100644
--- a/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-linq/src/main/groovy/org/apache/groovy/linq/provider/collection/GinqAstWalker.groovy
@@ -266,31 +266,41 @@ class GinqAstWalker implements GinqVisitor<Object>, 
SyntaxErrorReportable {
         final Expression firstAliasExpr = dataSourceExpression.aliasExpr
         final Expression secondAliasExpr = joinExpression.aliasExpr
 
+        def correctVar = { Expression expression ->
+            if (expression instanceof VariableExpression) {
+                Expression transformedExpression = null
+                if (firstAliasExpr.text == expression.text) {
+                    // replace `n1` with `__t.v1`
+                    transformedExpression = constructFirstAliasVariableAccess()
+                } else if (secondAliasExpr.text == expression.text) {
+                    // replace `n2` with `__t.v2`
+                    transformedExpression = 
constructSecondAliasVariableAccess()
+                }
+
+                if (null != transformedExpression) {
+                    return transformedExpression
+                }
+            }
+
+            return expression
+        }
+
         // The synthetic lambda parameter `__t` represents the element from 
the result datasource of joining, e.g. `n1` innerJoin `n2`
         // The element from first datasource(`n1`) is referenced via `_t.v1`
         // and the element from second datasource(`n2`) is referenced via 
`_t.v2`
         expr = expr.transformExpression(new ExpressionTransformer() {
             @Override
             Expression transform(Expression expression) {
-                if (expression instanceof VariableExpression) {
-                    Expression transformedExpression = null
-                    if (firstAliasExpr.text == expression.text) {
-                        // replace `n1` with `__t.v1`
-                        transformedExpression = 
constructFirstAliasVariableAccess()
-                    } else if (secondAliasExpr.text == expression.text) {
-                        // replace `n2` with `__t.v2`
-                        transformedExpression = 
constructSecondAliasVariableAccess()
-                    }
-
-                    if (null != transformedExpression) {
-                        return transformedExpression
-                    }
+                Expression transformedExpression = correctVar(expression)
+                if (transformedExpression !== expression) {
+                    return transformedExpression
                 }
 
                 return expression.transformExpression(this)
             }
         })
-        return expr
+
+        return correctVar(expr)
     }
 
     @Override
diff --git a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
index a4310a3..e38943b 100644
--- a/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
+++ b/subprojects/groovy-linq/src/test/groovy/org/apache/groovy/linq/GinqTest.groovy
@@ -296,6 +296,30 @@ class GinqTest {
     }
 
     @Test
+    void "testGinq - from innerJoin select - 9"() {
+        assertScript '''
+            assert [1, 2, 3] == GINQ {
+                from n in [1, 2, 3]
+                innerJoin k in [2, 3, 4]
+                on n + 1 == k
+                select n
+            }.toList()
+        '''
+    }
+
+    @Test
+    void "testGinq - from innerJoin select - 10"() {
+        assertScript '''
+            assert [2, 3, 4] == GINQ {
+                from n in [1, 2, 3]
+                innerJoin k in [2, 3, 4]
+                on n + 1 == k
+                select k
+            }.toList()
+        '''
+    }
+
+    @Test
     void "testGinq - from innerJoin where select - 1"() {
         assertScript '''
             def nums1 = [1, 2, 3]

Reply via email to