Repository: groovy
Updated Branches:
  refs/heads/master cba72a949 -> 8e535c261


GROOVY-8920: Fails to infer parameter and return type of SAM on RHS(closes #838)


Project: http://git-wip-us.apache.org/repos/asf/groovy/repo
Commit: http://git-wip-us.apache.org/repos/asf/groovy/commit/8e535c26
Tree: http://git-wip-us.apache.org/repos/asf/groovy/tree/8e535c26
Diff: http://git-wip-us.apache.org/repos/asf/groovy/diff/8e535c26

Branch: refs/heads/master
Commit: 8e535c2616f760166f416c73ea462d53ca2656f7
Parents: cba72a9
Author: Daniel Sun <sun...@apache.org>
Authored: Fri Dec 14 01:01:29 2018 +0800
Committer: Daniel Sun <sun...@apache.org>
Committed: Fri Dec 14 01:01:29 2018 +0800

----------------------------------------------------------------------
 .../groovy/ast/tools/GenericsUtils.java         |  82 +++++++++++---
 .../stc/StaticTypeCheckingVisitor.java          |  45 +++++++-
 src/test/groovy/transform/stc/LambdaTest.groovy | 113 ++++++++++++++++++-
 .../groovy/ast/tools/GenericsUtilsTest.groovy   |  21 ++++
 4 files changed, 240 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java 
b/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java
index 8519642..2f90641 100644
--- a/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java
+++ b/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java
@@ -46,6 +46,7 @@ import org.codehaus.groovy.syntax.Reduction;
 import java.io.StringReader;
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
@@ -805,6 +806,16 @@ public class GenericsUtils {
         return doMakeDeclaringAndActualGenericsTypeMap(declaringClass, 
actualReceiver, false).getV1();
     }
 
+    /**
+     * This method is similar to {@link 
GenericsUtils#makeDeclaringAndActualGenericsTypeMap(ClassNode, ClassNode)}.
+     * The main difference is that this method tries to map every placeholder 
it finds to the relevant exact type,
+     * whereas the other does not, even if the parameterized type has 
placeholders.
+     *
+     * @param declaringClass the generics class node declaring the generics 
types
+     * @param actualReceiver the sub-class class node
+     * @return the placeholder-to-actualtype mapping
+     * @since 3.0.0
+     */
     public static Map<GenericsType, GenericsType> 
makeDeclaringAndActualGenericsTypeMapOfExactType(ClassNode declaringClass, 
ClassNode actualReceiver) {
         List<ClassNode> parameterizedTypeList = new LinkedList<>();
 
@@ -887,10 +898,23 @@ public class GenericsUtils {
         return result;
     }
 
+    /**
+     * Check whether the ClassNode has generics types that are not 
placeholders
+     *
+     * @param parameterizedType the class node
+     * @return the result
+     * @since 3.0.0
+     */
     public static boolean hasNonPlaceHolders(ClassNode parameterizedType) {
         return checkPlaceHolders(parameterizedType, genericsType -> 
!genericsType.isPlaceholder());
     }
 
+    /**
+     * Check whether the ClassNode has generics placeholders
+     * @param parameterizedType the class node
+     * @return the result
+     * @since 3.0.0
+     */
     public static boolean hasPlaceHolders(ClassNode parameterizedType) {
         return checkPlaceHolders(parameterizedType, genericsType -> 
genericsType.isPlaceholder());
     }
@@ -911,19 +935,51 @@ public class GenericsUtils {
         return false;
     }
 
-    public static boolean isGenericsPlaceHolder(ClassNode cn) {
-        if (null == cn) return false;
-
-        GenericsType[] genericsTypes = cn.getGenericsTypes();
-
-        if (null == genericsTypes) return false;
-        if (genericsTypes.length != 1) return false;
-
-        GenericsType genericsType = genericsTypes[0];
-
-        if (!genericsType.isPlaceholder()) return false;
-
-        return genericsType.getName().equals(cn.getUnresolvedName());
+    /**
+     * Get the parameter and return types of the abstract method of SAM
+     *
+     * If the abstract method is not parameterized, we will get generics 
placeholders, e.g. T, U
+     * For example, the abstract method of {@link java.util.function.Function} 
is
+     * <pre>
+     *      R apply(T t);
+     * </pre>
+     *
+     * We parameterize the above interface as {@code Function<String, 
Integer>}, then the abstract method will be
+     * <pre>
+     *      Integer apply(String t);
+     * </pre>
+     *
+     * When we call {@code parameterizeSAM} on the ClassNode {@code 
Function<String, Integer>},
+     * we can get parameter types and return type of the above abstract method,
+     * i.e. ClassNode {@code ClassHelper.STRING_TYPE} and {@code 
ClassHelper.Integer_TYPE}
+     *
+     * @param sam the class node which contains only one abstract method
+     * @return the parameter and return types
+     * @since 3.0.0
+     *
+     */
+    public static Tuple2<ClassNode[], ClassNode> parameterizeSAM(ClassNode 
sam) {
+        final Map<GenericsType, GenericsType> map = 
makePlaceholderAndParameterizedTypeMap(sam);
+
+        MethodNode methodNode = ClassHelper.findSAM(sam);
+
+        ClassNode[] parameterTypes =
+                Arrays.stream(methodNode.getParameters())
+                    .map(e -> {
+                        ClassNode originalParameterType = e.getType();
+                        return originalParameterType.isGenericsPlaceHolder()
+                                ? 
findActualTypeByGenericsPlaceholderName(originalParameterType.getUnresolvedName(),
 map)
+                                : originalParameterType;
+                    })
+                    .toArray(ClassNode[]::new);
+
+        ClassNode originalReturnType = methodNode.getReturnType();
+        ClassNode returnType =
+                originalReturnType.isGenericsPlaceHolder()
+                        ? 
findActualTypeByGenericsPlaceholderName(originalReturnType.getUnresolvedName(), 
map)
+                        : originalReturnType;
+
+        return tuple(parameterTypes, returnType);
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
 
b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index d5b7688..be7425b 100644
--- 
a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ 
b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -22,6 +22,7 @@ import groovy.lang.Closure;
 import groovy.lang.DelegatesTo;
 import groovy.lang.IntRange;
 import groovy.lang.Range;
+import groovy.lang.Tuple2;
 import groovy.transform.NamedParam;
 import groovy.transform.NamedParams;
 import groovy.transform.TypeChecked;
@@ -179,7 +180,6 @@ import static 
org.codehaus.groovy.ast.tools.GeneralUtils.callX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.castX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.varX;
 import static 
org.codehaus.groovy.ast.tools.GenericsUtils.findActualTypeByGenericsPlaceholderName;
-import static 
org.codehaus.groovy.ast.tools.GenericsUtils.isGenericsPlaceHolder;
 import static 
org.codehaus.groovy.ast.tools.GenericsUtils.makeDeclaringAndActualGenericsTypeMap;
 import static org.codehaus.groovy.ast.tools.GenericsUtils.toGenericTypesString;
 import static 
org.codehaus.groovy.ast.tools.WideningCategories.LowestUpperBoundClassNode;
@@ -793,15 +793,20 @@ public class StaticTypeCheckingVisitor extends 
ClassCodeVisitorSupport {
             int op = expression.getOperation().getType();
             leftExpression.visit(this);
             SetterInfo setterInfo = removeSetterInfo(leftExpression);
+            ClassNode lType = null;
             if (setterInfo != null) {
                 if (ensureValidSetter(expression, leftExpression, 
rightExpression, setterInfo)) {
                     return;
                 }
 
             } else {
+                lType = getType(leftExpression);
+                inferParameterAndReturnTypesOfClosureOnRHS(lType, 
rightExpression, op);
+
                 rightExpression.visit(this);
             }
-            ClassNode lType = getType(leftExpression);
+
+            if (null == lType) lType = getType(leftExpression);
             ClassNode rType = getType(rightExpression);
             if (isNullConstant(rightExpression)) {
                 if (!isPrimitiveType(lType))
@@ -945,6 +950,31 @@ public class StaticTypeCheckingVisitor extends 
ClassCodeVisitorSupport {
         }
     }
 
+    private void inferParameterAndReturnTypesOfClosureOnRHS(ClassNode lType, 
Expression rightExpression, int op) {
+        if (ASSIGN == op) {
+            if (rightExpression instanceof ClosureExpression && 
ClassHelper.isFunctionalInterface(lType)) {
+                Tuple2<ClassNode[], ClassNode> typeInfo = 
GenericsUtils.parameterizeSAM(lType);
+                ClassNode[] paramTypes = typeInfo.getV1();
+                ClosureExpression closureExpression = ((ClosureExpression) 
rightExpression);
+                Parameter[] closureParameters = 
closureExpression.getParameters();
+
+                if (paramTypes.length == closureParameters.length) {
+                    for (int i = 0, n = closureParameters.length; i < n; i++) {
+                        Parameter parameter = closureParameters[i];
+                        if (parameter.isDynamicTyped()) {
+                            parameter.setType(paramTypes[i]);
+                            parameter.setOriginType(paramTypes[i]);
+                        }
+                    }
+                } else {
+                    addStaticTypeError("Wrong number of parameters: ", 
closureExpression);
+                }
+
+                storeInferredReturnType(rightExpression, typeInfo.getV2());
+            }
+        }
+    }
+
     /**
      * Given a binary expression corresponding to an assignment, will check 
that the type of the RHS matches one
      * of the possible setters and if not, throw a type checking error.
@@ -2354,6 +2384,13 @@ public class StaticTypeCheckingVisitor extends 
ClassCodeVisitorSupport {
         TypeCheckingContext.EnclosingClosure enclosingClosure = 
typeCheckingContext.getEnclosingClosure();
         if (!enclosingClosure.getReturnTypes().isEmpty()) {
             ClassNode returnType = 
lowestUpperBound(enclosingClosure.getReturnTypes());
+
+            ClassNode expectedReturnType = getInferredReturnType(expression);
+            // a type argument cannot be a primitive type, so convert 
it to the wrapper type
+            if (null != expectedReturnType && 
ClassHelper.isPrimitiveType(returnType) && 
expectedReturnType.equals(ClassHelper.getWrapper(returnType))) {
+                returnType = expectedReturnType;
+            }
+
             storeInferredReturnType(expression, returnType);
             ClassNode inferredType = wrapClosureType(returnType);
             storeType(enclosingClosure.getClosureExpression(), inferredType);
@@ -2859,7 +2896,7 @@ public class StaticTypeCheckingVisitor extends 
ClassCodeVisitorSupport {
         List<Integer> indexList = new LinkedList<>();
         for (int i = 0, n = blockParameterTypes.length; i < n; i++) {
             ClassNode blockParameterType = blockParameterTypes[i];
-            if (isGenericsPlaceHolder(blockParameterType)) {
+            if (null != blockParameterType && 
blockParameterType.isGenericsPlaceHolder()) {
                 indexList.add(i);
             }
         }
@@ -2873,7 +2910,7 @@ public class StaticTypeCheckingVisitor extends 
ClassCodeVisitorSupport {
                     if 
(entry.getKey().getName().equals(blockParameterTypes[index].getUnresolvedName()))
 {
                         ClassNode type = entry.getValue().getType();
 
-                        if (!isGenericsPlaceHolder(type)) {
+                        if (null != type && !type.isGenericsPlaceHolder()) {
                             blockParameterTypes[index] = type;
                         }
 

http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/test/groovy/transform/stc/LambdaTest.groovy
----------------------------------------------------------------------
diff --git a/src/test/groovy/transform/stc/LambdaTest.groovy 
b/src/test/groovy/transform/stc/LambdaTest.groovy
index 5aae4ee..48d518f 100644
--- a/src/test/groovy/transform/stc/LambdaTest.groovy
+++ b/src/test/groovy/transform/stc/LambdaTest.groovy
@@ -169,6 +169,39 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testPredicateWithoutExplicitTypeDef() {
+        assertScript '''
+        import groovy.transform.CompileStatic
+        import java.util.stream.Collectors
+        import java.util.stream.Stream
+        import java.util.function.Function
+        import java.util.function.Predicate
+        
+        @CompileStatic
+        public class Test1 {
+            public static void main(String[] args) {
+                p()
+            }
+        
+            public static void p() {
+                List<String> myList = Arrays.asList("a1", "a2", "b2", "b1", 
"c2", "c1")
+                Predicate<String> predicate = s -> s.startsWith("b")
+                Function<String, String> mapper = s -> s.toUpperCase()
+                
+                List<String> result =
+                        myList
+                            .stream()
+                            .filter(predicate)
+                            .map(mapper)
+                            .sorted()
+                            .collect(Collectors.toList())
+                
+                assert ['B1', 'B2'] == result
+            }
+        }
+        '''
+    }
+
     void testUnaryOperator() {
         assertScript '''
         import groovy.transform.CompileStatic
@@ -429,7 +462,28 @@ class LambdaTest extends GroovyTestCase {
             }
         
             public static void p() {
-                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 
1) // Casting is required...  [Static type checking] - Incompatible generic 
argument types. Cannot assign java.util.function.Function <java.lang.Integer, 
int> to: java.util.function.Function <Integer, Integer>
+                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1)
+                assert 2 == f(1)
+            }
+        }
+        '''
+    }
+
+    void testFunctionCallWithoutExplicitTypeDef() {
+        assertScript '''
+        import groovy.transform.CompileStatic
+        import java.util.stream.Collectors
+        import java.util.stream.Stream
+        import java.util.function.Function
+        
+        @CompileStatic
+        public class Test1 {
+            public static void main(String[] args) {
+                p();
+            }
+        
+            public static void p() {
+                Function<Integer, Integer> f = e -> e + 1
                 assert 2 == f(1)
             }
         }
@@ -450,7 +504,7 @@ class LambdaTest extends GroovyTestCase {
             }
         
             public void p() {
-                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 
1) // Casting is required...  [Static type checking] - Incompatible generic 
argument types. Cannot assign java.util.function.Function <java.lang.Integer, 
int> to: java.util.function.Function <Integer, Integer>
+                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1)
                 assert 2 == f(1)
             }
         }
@@ -471,7 +525,7 @@ class LambdaTest extends GroovyTestCase {
             }
         
             public static void p() {
-                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 
1) // Casting is required...  [Static type checking] - Incompatible generic 
argument types. Cannot assign java.util.function.Function <java.lang.Integer, 
int> to: java.util.function.Function <Integer, Integer>
+                Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1)
                 assert 2 == f.apply(1)
             }
         }
@@ -501,6 +555,29 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testConsumerCallWithoutExplicitTypeDef() {
+        assertScript '''
+        import groovy.transform.CompileStatic
+        import java.util.stream.Collectors
+        import java.util.stream.Stream
+        import java.util.function.Consumer
+        
+        @CompileStatic
+        public class Test1 {
+            public static void main(String[] args) {
+                p();
+            }
+        
+            public static void p() {
+                int r = 1
+                Consumer<Integer> c = e -> { r += e }
+                c(2)
+                assert 3 == r
+            }
+        }
+        '''
+    }
+
     void testConsumerCall2() {
         assertScript '''
         import groovy.transform.CompileStatic
@@ -572,6 +649,32 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+
+    void testSamCallWithoutExplicitTypeDef() {
+        assertScript '''
+        import groovy.transform.CompileStatic
+        import java.util.stream.Collectors
+        import java.util.stream.Stream
+        
+        @CompileStatic
+        public class Test1 {
+            public static void main(String[] args) {
+                p();
+            }
+            
+            public static void p() {
+                SamCallable c = e -> e
+                assert 1 == c(1)
+            }
+        }
+        
+        @CompileStatic
+        interface SamCallable {
+            int call(int p);
+        }
+        '''
+    }
+
     void testSamCall2() {
         assertScript '''
         import groovy.transform.CompileStatic
@@ -653,7 +756,7 @@ class LambdaTest extends GroovyTestCase {
             }
         
             public static void p() {
-                Function<Integer, String> f = (Integer e) -> 'a' + e // STC 
can not infer the type of `e`, so we have to specify the type `Integer` by 
ourselves
+                Function<Integer, String> f = (Integer e) -> 'a' + e
                 assert ['a1', 'a2', 'a3'] == [1, 2, 
3].stream().map(f).collect(Collectors.toList())
             }
         }
@@ -767,4 +870,6 @@ class LambdaTest extends GroovyTestCase {
         }
         '''
     }
+
+
 }

http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy
----------------------------------------------------------------------
diff --git a/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy 
b/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy
index 27fcff7..dd5220e 100644
--- a/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy
+++ b/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy
@@ -27,6 +27,8 @@ import org.codehaus.groovy.control.Phases
 
 import java.util.function.BiFunction
 
+import static groovy.lang.Tuple.tuple
+
 class GenericsUtilsTest extends GroovyTestCase {
     void testFindParameterizedType1() {
         def code = '''
@@ -280,4 +282,23 @@ class GenericsUtilsTest extends GroovyTestCase {
     static ClassNode findClassNode(String name, List<ClassNode> classNodeList) 
{
         return classNodeList.find { it.name == name }
     }
+
+    void testParameterizeSAM() {
+        def code = '''
+        import java.util.function.*
+        interface T extends Function<String, Integer> {}
+        '''
+        def ast = new CompilationUnit().tap {
+            addSource 'hello.groovy', code
+            compile Phases.SEMANTIC_ANALYSIS
+        }.ast
+
+        def classNodeList = ast.getModules()[0].getClasses()
+        ClassNode parameterizedClassNode = findClassNode('T', 
classNodeList).getAllInterfaces().find { 
it.name.equals('java.util.function.Function') }
+
+        Tuple2<ClassNode[], ClassNode> typeInfo = 
GenericsUtils.parameterizeSAM(parameterizedClassNode)
+        assert 1 == typeInfo.getV1().length
+        assert ClassHelper.STRING_TYPE == typeInfo.getV1()[0]
+        assert ClassHelper.Integer_TYPE == typeInfo.getV2()
+    }
 }

Reply via email to