Repository: groovy Updated Branches: refs/heads/master cba72a949 -> 8e535c261
GROOVY-8920: Fails to infer parameter and return type of SAM on RHS(closes #838) Project: http://git-wip-us.apache.org/repos/asf/groovy/repo Commit: http://git-wip-us.apache.org/repos/asf/groovy/commit/8e535c26 Tree: http://git-wip-us.apache.org/repos/asf/groovy/tree/8e535c26 Diff: http://git-wip-us.apache.org/repos/asf/groovy/diff/8e535c26 Branch: refs/heads/master Commit: 8e535c2616f760166f416c73ea462d53ca2656f7 Parents: cba72a9 Author: Daniel Sun <sun...@apache.org> Authored: Fri Dec 14 01:01:29 2018 +0800 Committer: Daniel Sun <sun...@apache.org> Committed: Fri Dec 14 01:01:29 2018 +0800 ---------------------------------------------------------------------- .../groovy/ast/tools/GenericsUtils.java | 82 +++++++++++--- .../stc/StaticTypeCheckingVisitor.java | 45 +++++++- src/test/groovy/transform/stc/LambdaTest.groovy | 113 ++++++++++++++++++- .../groovy/ast/tools/GenericsUtilsTest.groovy | 21 ++++ 4 files changed, 240 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java b/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java index 8519642..2f90641 100644 --- a/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java +++ b/src/main/java/org/codehaus/groovy/ast/tools/GenericsUtils.java @@ -46,6 +46,7 @@ import org.codehaus.groovy.syntax.Reduction; import java.io.StringReader; import java.lang.ref.SoftReference; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -805,6 +806,16 @@ public class GenericsUtils { return doMakeDeclaringAndActualGenericsTypeMap(declaringClass, actualReceiver, false).getV1(); } + /** + * The method is similar with {@link GenericsUtils#makeDeclaringAndActualGenericsTypeMap(ClassNode, ClassNode)}, + * The main difference is that the method will try to map all placeholders found to the relevant exact types, + * but the other will not try even if the parameterized type has placeholders + * + * @param declaringClass the generics class node declaring the generics types + * @param actualReceiver the sub-class class node + * @return the placeholder-to-actualtype mapping + * @since 3.0.0 + */ public static Map<GenericsType, GenericsType> makeDeclaringAndActualGenericsTypeMapOfExactType(ClassNode declaringClass, ClassNode actualReceiver) { List<ClassNode> parameterizedTypeList = new LinkedList<>(); @@ -887,10 +898,23 @@ public class GenericsUtils { return result; } + /** + * Check whether the ClassNode has non generics placeholders, aka not placeholder + * + * @param parameterizedType the class node + * @return the result + * @since 3.0.0 + */ public static boolean hasNonPlaceHolders(ClassNode parameterizedType) { return checkPlaceHolders(parameterizedType, genericsType -> !genericsType.isPlaceholder()); } + /** + * Check whether the ClassNode has generics placeholders + * @param parameterizedType the class node + * @return the result + * @since 3.0.0 + */ public static boolean hasPlaceHolders(ClassNode parameterizedType) { return checkPlaceHolders(parameterizedType, genericsType -> genericsType.isPlaceholder()); } @@ -911,19 +935,51 @@ public class GenericsUtils { return false; } - public static boolean isGenericsPlaceHolder(ClassNode cn) { - if (null == cn) return false; - - GenericsType[] genericsTypes = cn.getGenericsTypes(); - - if (null == genericsTypes) return false; - if (genericsTypes.length != 1) return false; - - GenericsType genericsType = genericsTypes[0]; - - if (!genericsType.isPlaceholder()) return false; - - return genericsType.getName().equals(cn.getUnresolvedName()); + /** + * Get the parameter and return types of the abstract method of SAM + * + * If the abstract method is not parameterized, we will get generics placeholders, e.g. T, U + * For example, the abstract method of {@link java.util.function.Function} is + * <pre> + * R apply(T t); + * </pre> + * + * We parameterize the above interface as {@code Function<String, Integer>}, then the abstract method will be + * <pre> + * Integer apply(String t); + * </pre> + * + * When we call {@code parameterizeSAM} on the ClassNode {@code Function<String, Integer>}, + * we can get parameter types and return type of the above abstract method, + * i.e. ClassNode {@code ClassHelper.STRING_TYPE} and {@code ClassHelper.Integer_TYPE} + * + * @param sam the class node which contains only one abstract method + * @return the parameter and return types + * @since 3.0.0 + * + */ + public static Tuple2<ClassNode[], ClassNode> parameterizeSAM(ClassNode sam) { + final Map<GenericsType, GenericsType> map = makePlaceholderAndParameterizedTypeMap(sam); + + MethodNode methodNode = ClassHelper.findSAM(sam); + + ClassNode[] parameterTypes = + Arrays.stream(methodNode.getParameters()) + .map(e -> { + ClassNode originalParameterType = e.getType(); + return originalParameterType.isGenericsPlaceHolder() + ? findActualTypeByGenericsPlaceholderName(originalParameterType.getUnresolvedName(), map) + : originalParameterType; + }) + .toArray(ClassNode[]::new); + + ClassNode originalReturnType = methodNode.getReturnType(); + ClassNode returnType = + originalReturnType.isGenericsPlaceHolder() + ? findActualTypeByGenericsPlaceholderName(originalReturnType.getUnresolvedName(), map) + : originalReturnType; + + return tuple(parameterTypes, returnType); } /** http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java index d5b7688..be7425b 100644 --- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java +++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java @@ -22,6 +22,7 @@ import groovy.lang.Closure; import groovy.lang.DelegatesTo; import groovy.lang.IntRange; import groovy.lang.Range; +import groovy.lang.Tuple2; import groovy.transform.NamedParam; import groovy.transform.NamedParams; import groovy.transform.TypeChecked; @@ -179,7 +180,6 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.callX; import static org.codehaus.groovy.ast.tools.GeneralUtils.castX; import static org.codehaus.groovy.ast.tools.GeneralUtils.varX; import static org.codehaus.groovy.ast.tools.GenericsUtils.findActualTypeByGenericsPlaceholderName; -import static org.codehaus.groovy.ast.tools.GenericsUtils.isGenericsPlaceHolder; import static org.codehaus.groovy.ast.tools.GenericsUtils.makeDeclaringAndActualGenericsTypeMap; import static org.codehaus.groovy.ast.tools.GenericsUtils.toGenericTypesString; import static org.codehaus.groovy.ast.tools.WideningCategories.LowestUpperBoundClassNode; @@ -793,15 +793,20 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport { int op = expression.getOperation().getType(); leftExpression.visit(this); SetterInfo setterInfo = removeSetterInfo(leftExpression); + ClassNode lType = null; if (setterInfo != null) { if (ensureValidSetter(expression, leftExpression, rightExpression, setterInfo)) { return; } } else { + lType = getType(leftExpression); + inferParameterAndReturnTypesOfClosureOnRHS(lType, rightExpression, op); + rightExpression.visit(this); } - ClassNode lType = getType(leftExpression); + + if (null == lType) lType = getType(leftExpression); ClassNode rType = getType(rightExpression); if (isNullConstant(rightExpression)) { if (!isPrimitiveType(lType)) @@ -945,6 +950,31 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport { } } + private void inferParameterAndReturnTypesOfClosureOnRHS(ClassNode lType, Expression rightExpression, int op) { + if (ASSIGN == op) { + if (rightExpression instanceof ClosureExpression && ClassHelper.isFunctionalInterface(lType)) { + Tuple2<ClassNode[], ClassNode> typeInfo = GenericsUtils.parameterizeSAM(lType); + ClassNode[] paramTypes = typeInfo.getV1(); + ClosureExpression closureExpression = ((ClosureExpression) rightExpression); + Parameter[] closureParameters = closureExpression.getParameters(); + + if (paramTypes.length == closureParameters.length) { + for (int i = 0, n = closureParameters.length; i < n; i++) { + Parameter parameter = closureParameters[i]; + if (parameter.isDynamicTyped()) { + parameter.setType(paramTypes[i]); + parameter.setOriginType(paramTypes[i]); + } + } + } else { + addStaticTypeError("Wrong number of parameters: ", closureExpression); + } + + storeInferredReturnType(rightExpression, typeInfo.getV2()); + } + } + } + /** * Given a binary expression corresponding to an assignment, will check that the type of the RHS matches one * of the possible setters and if not, throw a type checking error. @@ -2354,6 +2384,13 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport { TypeCheckingContext.EnclosingClosure enclosingClosure = typeCheckingContext.getEnclosingClosure(); if (!enclosingClosure.getReturnTypes().isEmpty()) { ClassNode returnType = lowestUpperBound(enclosingClosure.getReturnTypes()); + + ClassNode expectedReturnType = getInferredReturnType(expression); + // type argument can not be of primitive type, we should convert it to the wrapper type + if (null != expectedReturnType && ClassHelper.isPrimitiveType(returnType) && expectedReturnType.equals(ClassHelper.getWrapper(returnType))) { + returnType = expectedReturnType; + } + storeInferredReturnType(expression, returnType); ClassNode inferredType = wrapClosureType(returnType); storeType(enclosingClosure.getClosureExpression(), inferredType); @@ -2859,7 +2896,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport { List<Integer> indexList = new LinkedList<>(); for (int i = 0, n = blockParameterTypes.length; i < n; i++) { ClassNode blockParameterType = blockParameterTypes[i]; - if (isGenericsPlaceHolder(blockParameterType)) { + if (null != blockParameterType && blockParameterType.isGenericsPlaceHolder()) { indexList.add(i); } } @@ -2873,7 +2910,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport { if (entry.getKey().getName().equals(blockParameterTypes[index].getUnresolvedName())) { ClassNode type = entry.getValue().getType(); - if (!isGenericsPlaceHolder(type)) { + if (null != type && !type.isGenericsPlaceHolder()) { blockParameterTypes[index] = type; } http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/test/groovy/transform/stc/LambdaTest.groovy ---------------------------------------------------------------------- diff --git a/src/test/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/transform/stc/LambdaTest.groovy index 5aae4ee..48d518f 100644 --- a/src/test/groovy/transform/stc/LambdaTest.groovy +++ b/src/test/groovy/transform/stc/LambdaTest.groovy @@ -169,6 +169,39 @@ class LambdaTest extends GroovyTestCase { ''' } + void testPredicateWithoutExplicitTypeDef() { + assertScript ''' + import groovy.transform.CompileStatic + import java.util.stream.Collectors + import java.util.stream.Stream + import java.util.function.Function + import java.util.function.Predicate + + @CompileStatic + public class Test1 { + public static void main(String[] args) { + p() + } + + public static void p() { + List<String> myList = Arrays.asList("a1", "a2", "b2", "b1", "c2", "c1") + Predicate<String> predicate = s -> s.startsWith("b") + Function<String, String> mapper = s -> s.toUpperCase() + + List<String> result = + myList + .stream() + .filter(predicate) + .map(mapper) + .sorted() + .collect(Collectors.toList()) + + assert ['B1', 'B2'] == result + } + } + ''' + } + void testUnaryOperator() { assertScript ''' import groovy.transform.CompileStatic @@ -429,7 +462,28 @@ class LambdaTest extends GroovyTestCase { } public static void p() { - Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) // Casting is required... [Static type checking] - Incompatible generic argument types. Cannot assign java.util.function.Function <java.lang.Integer, int> to: java.util.function.Function <Integer, Integer> + Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) + assert 2 == f(1) + } + } + ''' + } + + void testFunctionCallWithoutExplicitTypeDef() { + assertScript ''' + import groovy.transform.CompileStatic + import java.util.stream.Collectors + import java.util.stream.Stream + import java.util.function.Function + + @CompileStatic + public class Test1 { + public static void main(String[] args) { + p(); + } + + public static void p() { + Function<Integer, Integer> f = e -> e + 1 assert 2 == f(1) } } @@ -450,7 +504,7 @@ class LambdaTest extends GroovyTestCase { } public void p() { - Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) // Casting is required... [Static type checking] - Incompatible generic argument types. Cannot assign java.util.function.Function <java.lang.Integer, int> to: java.util.function.Function <Integer, Integer> + Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) assert 2 == f(1) } } @@ -471,7 +525,7 @@ class LambdaTest extends GroovyTestCase { } public static void p() { - Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) // Casting is required... [Static type checking] - Incompatible generic argument types. Cannot assign java.util.function.Function <java.lang.Integer, int> to: java.util.function.Function <Integer, Integer> + Function<Integer, Integer> f = (Integer e) -> (Integer) (e + 1) assert 2 == f.apply(1) } } @@ -501,6 +555,29 @@ class LambdaTest extends GroovyTestCase { ''' } + void testConsumerCallWithoutExplicitTypeDef() { + assertScript ''' + import groovy.transform.CompileStatic + import java.util.stream.Collectors + import java.util.stream.Stream + import java.util.function.Consumer + + @CompileStatic + public class Test1 { + public static void main(String[] args) { + p(); + } + + public static void p() { + int r = 1 + Consumer<Integer> c = e -> { r += e } + c(2) + assert 3 == r + } + } + ''' + } + void testConsumerCall2() { assertScript ''' import groovy.transform.CompileStatic @@ -572,6 +649,32 @@ class LambdaTest extends GroovyTestCase { ''' } + + void testSamCallWithoutExplicitTypeDef() { + assertScript ''' + import groovy.transform.CompileStatic + import java.util.stream.Collectors + import java.util.stream.Stream + + @CompileStatic + public class Test1 { + public static void main(String[] args) { + p(); + } + + public static void p() { + SamCallable c = e -> e + assert 1 == c(1) + } + } + + @CompileStatic + interface SamCallable { + int call(int p); + } + ''' + } + void testSamCall2() { assertScript ''' import groovy.transform.CompileStatic @@ -653,7 +756,7 @@ class LambdaTest extends GroovyTestCase { } public static void p() { - Function<Integer, String> f = (Integer e) -> 'a' + e // STC can not infer the type of `e`, so we have to specify the type `Integer` by ourselves + Function<Integer, String> f = (Integer e) -> 'a' + e assert ['a1', 'a2', 'a3'] == [1, 2, 3].stream().map(f).collect(Collectors.toList()) } } @@ -767,4 +870,6 @@ class LambdaTest extends GroovyTestCase { } ''' } + + } http://git-wip-us.apache.org/repos/asf/groovy/blob/8e535c26/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy ---------------------------------------------------------------------- diff --git a/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy b/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy index 27fcff7..dd5220e 100644 --- a/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy +++ b/src/test/org/codehaus/groovy/ast/tools/GenericsUtilsTest.groovy @@ -27,6 +27,8 @@ import org.codehaus.groovy.control.Phases import java.util.function.BiFunction +import static groovy.lang.Tuple.tuple + class GenericsUtilsTest extends GroovyTestCase { void testFindParameterizedType1() { def code = ''' @@ -280,4 +282,23 @@ class GenericsUtilsTest extends GroovyTestCase { static ClassNode findClassNode(String name, List<ClassNode> classNodeList) { return classNodeList.find { it.name == name } } + + void testParameterizeSAM() { + def code = ''' + import java.util.function.* + interface T extends Function<String, Integer> {} + ''' + def ast = new CompilationUnit().tap { + addSource 'hello.groovy', code + compile Phases.SEMANTIC_ANALYSIS + }.ast + + def classNodeList = ast.getModules()[0].getClasses() + ClassNode parameterizedClassNode = findClassNode('T', classNodeList).getAllInterfaces().find { it.name.equals('java.util.function.Function') } + + Tuple2<ClassNode[], ClassNode> typeInfo = GenericsUtils.parameterizeSAM(parameterizedClassNode) + assert 1 == typeInfo.getV1().length + assert ClassHelper.STRING_TYPE == typeInfo.getV1()[0] + assert ClassHelper.Integer_TYPE == typeInfo.getV2() + } }