This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new f46d61e  Tweak resolving variables in GINQ
f46d61e is described below

commit f46d61eb5ab33a026a8c0f0f23b3facc47513d44
Author: Daniel Sun <[email protected]>
AuthorDate: Fri Nov 27 23:55:47 2020 +0800

    Tweak resolving variables in GINQ
---
 .../ginq/provider/collection/GinqAstWalker.groovy  | 48 ++++++++++++----------
 1 file changed, 26 insertions(+), 22 deletions(-)

diff --git a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
index ec7336a..01ded49 100644
--- a/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
+++ b/subprojects/groovy-ginq/src/main/groovy/org/apache/groovy/ginq/provider/collection/GinqAstWalker.groovy
@@ -68,6 +68,7 @@ import org.objectweb.asm.Opcodes
 
 import java.util.stream.Collectors
 
+import static groovy.lang.Tuple.tuple
 import static org.codehaus.groovy.ast.ClassHelper.makeWithoutCaching
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args
 import static org.codehaus.groovy.ast.tools.GeneralUtils.block
@@ -534,12 +535,9 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         }
 
         boolean isJoin = dataSourceExpression instanceof JoinExpression
-
-        List<DeclarationExpression> declarationExpressionList
-        if (isJoin) {
-            def lambdaParam = new VariableExpression(lambdaParamName)
-            Map<String, Expression> aliasToAccessPathMap = 
findAliasAccessPathForJoin(dataSourceExpression, lambdaParam)
-
+        boolean isGroup = groupByVisited
+        List<DeclarationExpression> declarationExpressionList = 
Collections.emptyList()
+        if (isJoin || isGroup) {
             def variableNameSet = new HashSet<String>()
             expr.visit(new CodeVisitorSupport() {
                 @Override
@@ -549,20 +547,24 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                 }
             })
 
+            def lambdaParam = new VariableExpression(lambdaParamName)
+            Map<String, Expression> aliasToAccessPathMap = 
findAliasAccessPath(dataSourceExpression, lambdaParam)
             declarationExpressionList =
                     aliasToAccessPathMap.entrySet().stream()
-                    .filter(e -> variableNameSet.contains(e.key))
-                    .map(e -> {
-                        def v = localVarX(e.key)
-                        v.modifiers = v.modifiers | Opcodes.ACC_FINAL
-                        return declX(v, e.value)
-                    })
-                    .collect(Collectors.toList())
-        } else {
-            declarationExpressionList = Collections.emptyList()
+                            .filter(e -> variableNameSet.contains(e.key))
+                            .map(e -> {
+                                def v = localVarX(e.key)
+                                v.modifiers = v.modifiers | Opcodes.ACC_FINAL
+
+                                if (isGroup) {
+                                    return declX(v, propX(propX(new 
VariableExpression(lambdaParamName), 'v1'), e.key))
+                                } else {
+                                    return declX(v, e.value)
+                                }
+                            })
+                            .collect(Collectors.toList())
         }
 
-
         // (1) correct itself
         expr = correctVars(dataSourceExpression, lambdaParamName, expr)
 
@@ -582,7 +584,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             }
         })
 
-        return Tuple.tuple(declarationExpressionList, expr)
+        return tuple(declarationExpressionList, expr)
     }
 
     private Expression correctVars(DataSourceExpression dataSourceExpression, 
String lambdaParamName, Expression expression) {
@@ -611,7 +613,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
                         }
                     } else {
                         if 
(groupNameListExpression.getExpressions().stream().map(e -> e.text).anyMatch(e 
-> e == expression.text)
-                            || 
aliasNameListExpression.getExpressions().stream().map(e -> e.text).anyMatch(e 
-> e == expression.text)
+                                && 
aliasNameListExpression.getExpressions().stream().map(e -> e.text).allMatch(e 
-> e != expression.text)
                         ) {
                             // replace `gk` in the groupby with `__t.v1.gk`, 
note: __t.v1 stores the group key
                             transformedExpression = propX(propX(new 
VariableExpression(lambdaParamName), 'v1'), expression.text)
@@ -648,9 +650,11 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
         return expression
     }
 
-    private static Map<String, Expression> 
findAliasAccessPathForJoin(DataSourceExpression dataSourceExpression, 
Expression prop) {
+    private static Map<String, Expression> 
findAliasAccessPath(DataSourceExpression dataSourceExpression, Expression prop) 
{
         boolean isJoin = dataSourceExpression instanceof JoinExpression
-        if (!isJoin) return Collections.emptyMap()
+        if (!isJoin) {
+            return Maps.of(dataSourceExpression.aliasExpr.text, prop)
+        }
 
         /*
                  * `n1`(`from` node) join `n2` join `n3`  will construct a 
join tree:
@@ -741,7 +745,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
 
             lambdaCode.putNodeMetaData(__LAMBDA_PARAM_NAME, lambdaParamName)
             Tuple2<List<DeclarationExpression>, Expression> 
declarationAndLambdaCode = 
correctVariablesOfGinqExpression(dataSourceExpression, lambdaCode)
-            if (!(visitingAggregateFunction || (groupByVisited && 
visitingSelect))) {
+            if (!visitingAggregateFunction) {
                 declarationExpressionList = declarationAndLambdaCode.v1
             }
             lambdaCode = declarationAndLambdaCode.v2
@@ -757,7 +761,7 @@ class GinqAstWalker implements GinqAstVisitor<Expression>, SyntaxErrorReportable
             }
         }
 
-        return Tuple.tuple(lambdaParamName, declarationExpressionList, 
lambdaCode)
+        return tuple(lambdaParamName, declarationExpressionList, lambdaCode)
     }
 
     private boolean isGroupByVisited() {

Reply via email to