This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 5e03289d257ed2870225706797f3522df1b6662a
Author: Eric Milles <eric.mil...@thomsonreuters.com>
AuthorDate: Fri Mar 8 17:37:53 2024 -0600

    GROOVY-11335: STC: loop item type from `UnionTypeClassNode`
---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 61 +++++++-------
 .../groovy/transform/stc/UnionTypeClassNode.java   | 98 ++++++++++------------
 src/test/groovy/transform/stc/LoopsSTCTest.groovy  | 14 ++++
 3 files changed, 90 insertions(+), 83 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 2be11c7c1f..9a8066f8e4 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -2067,33 +2067,34 @@ out:    if ((samParameterTypes.length == 1 && isOrImplements(samParameterTypes[0
      * @see #inferComponentType
      */
     public static ClassNode inferLoopElementType(final ClassNode collectionType) {
-        ClassNode componentType = collectionType.getComponentType();
-        if (componentType == null) {
-            if (isOrImplements(collectionType, ITERABLE_TYPE)) {
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
-                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
-
-            } else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
-                componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());
-
-            } else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
-                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
-
-            } else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
-                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
-
-            } else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
-                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
-
-            } else if (isStringType(collectionType)) {
-                componentType = STRING_TYPE;
-            } else {
-                componentType = OBJECT_TYPE;
-            }
+        ClassNode componentType;
+        if (collectionType.isArray()) { // GROOVY-11335
+            componentType = collectionType.getComponentType();
+
+        } else if (isOrImplements(collectionType, ITERABLE_TYPE)) {
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
+            componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+        } else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
+            componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());
+
+        } else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
+            componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+        } else if (isOrImplements(collectionType, Iterator_TYPE)) { // GROOVY-10712
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
+            componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+        } else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
+            componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+        } else if (isStringType(collectionType)) {
+            componentType = STRING_TYPE;
+        } else {
+            componentType = OBJECT_TYPE;
         }
         return componentType;
     }
@@ -4678,8 +4679,10 @@ out:                if (mn.size() != 1) {
     }
 
     protected ClassNode inferComponentType(final ClassNode receiverType, final ClassNode subscriptType) {
-        ClassNode componentType = receiverType.getComponentType();
-        if (componentType == null) {
+        ClassNode componentType = null;
+        if (receiverType.isArray()) { // GROOVY-11335
+            componentType = receiverType.getComponentType();
+        } else {
             MethodCallExpression mce;
             if (subscriptType != null) { // GROOVY-5521: check for a suitable "getAt(T)" method
                 mce = callX(varX("#", receiverType), "getAt", varX("selector", subscriptType));
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
index 90b6560687..cb1377b9b5 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
@@ -35,7 +35,6 @@ import org.codehaus.groovy.ast.stmt.Statement;
 import org.codehaus.groovy.transform.ASTTransformation;
 
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
@@ -172,59 +171,51 @@ class UnionTypeClassNode extends ClassNode {
         throw new UnsupportedOperationException();
     }
 
-    @Override
-    public boolean declaresInterface(final ClassNode classNode) {
-        for (ClassNode delegate : delegates) {
-            if (delegate.declaresInterface(classNode)) return true;
-        }
-        return false;
-    }
-
     @Override
     public List<MethodNode> getAbstractMethods() {
-        List<MethodNode> allMethods = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAbstractMethods());
+            answer.addAll(delegate.getAbstractMethods());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public List<MethodNode> getAllDeclaredMethods() {
-        List<MethodNode> allMethods = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAllDeclaredMethods());
+            answer.addAll(delegate.getAllDeclaredMethods());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public Set<ClassNode> getAllInterfaces() {
-        Set<ClassNode> allMethods = new HashSet<ClassNode>();
+        Set<ClassNode> answer = new HashSet<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAllInterfaces());
+            answer.addAll(delegate.getAllInterfaces());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public List<AnnotationNode> getAnnotations() {
-        List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
+        List<AnnotationNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<AnnotationNode> annotations = delegate.getAnnotations();
-            if (annotations != null) nodes.addAll(annotations);
+            if (annotations != null) answer.addAll(annotations);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
     public List<AnnotationNode> getAnnotations(final ClassNode type) {
-        List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
+        List<AnnotationNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<AnnotationNode> annotations = delegate.getAnnotations(type);
-            if (annotations != null) nodes.addAll(annotations);
+            if (annotations != null) answer.addAll(annotations);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -234,11 +225,11 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<ConstructorNode> getDeclaredConstructors() {
-        List<ConstructorNode> nodes = new LinkedList<ConstructorNode>();
+        List<ConstructorNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            nodes.addAll(delegate.getDeclaredConstructors());
+            answer.addAll(delegate.getDeclaredConstructors());
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -261,12 +252,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<MethodNode> getDeclaredMethods(final String name) {
-        List<MethodNode> nodes = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<MethodNode> methods = delegate.getDeclaredMethods(name);
-            if (methods != null) nodes.addAll(methods);
+            if (methods != null) answer.addAll(methods);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -290,12 +281,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<FieldNode> getFields() {
-        List<FieldNode> nodes = new LinkedList<FieldNode>();
+        List<FieldNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<FieldNode> fields = delegate.getFields();
-            if (fields != null) nodes.addAll(fields);
+            if (fields != null) answer.addAll(fields);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -305,22 +296,25 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public ClassNode[] getInterfaces() {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
+        Set<ClassNode> answer = new LinkedHashSet<>();
         for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getInterfaces();
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
+            if (delegate.isInterface()) {
+                answer.remove(delegate); answer.add(delegate);
+            } else {
+                answer.addAll(Arrays.asList(delegate.getInterfaces()));
+            }
         }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return answer.toArray(ClassNode.EMPTY_ARRAY);
     }
 
     @Override
     public List<MethodNode> getMethods() {
-        List<MethodNode> nodes = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<MethodNode> methods = delegate.getMethods();
-            if (methods != null) nodes.addAll(methods);
+            if (methods != null) answer.addAll(methods);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -334,12 +328,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<PropertyNode> getProperties() {
-        List<PropertyNode> nodes = new LinkedList<PropertyNode>();
+        List<PropertyNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<PropertyNode> properties = delegate.getProperties();
-            if (properties != null) nodes.addAll(properties);
+            if (properties != null) answer.addAll(properties);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -349,22 +343,18 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public ClassNode[] getUnresolvedInterfaces() {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
-        for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getUnresolvedInterfaces();
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
-        }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return getUnresolvedInterfaces(false);
     }
 
     @Override
     public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
-        for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect);
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
+        ClassNode[] interfaces = getInterfaces();
+        if (useRedirect) {
+            for (int i = 0; i < interfaces.length; ++i) {
+                interfaces[i] = interfaces[i].redirect();
+            }
         }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return interfaces;
     }
 
     @Override
diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
index 11bc0a91d3..e0562a6140 100644
--- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
@@ -252,6 +252,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-11335
+    void testForInLoopOnCollection() {
+        assertScript '''
+            def whatever(Collection<String> coll) {
+                if (coll instanceof Serializable) {
+                    for (item in coll) {
+                        return item.toLowerCase()
+                    }
+                }
+            }
+            assert whatever(['Works']) == 'works'
+        '''
+    }
+
     // GROOVY-6123
     void testForInLoopOnEnumeration() {
         assertScript '''

Reply via email to