This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY_3_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit c917668243d74746bb0ddfef96de61da0bcf2b75
Author: Eric Milles <eric.mil...@thomsonreuters.com>
AuthorDate: Fri Mar 8 17:37:53 2024 -0600

    GROOVY-11335: STC: loop item type from `UnionTypeClassNode`
    
    3_0_X backport
---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 47 ++++++-----
 .../groovy/transform/stc/UnionTypeClassNode.java   | 98 ++++++++++------------
 src/test/groovy/transform/stc/LoopsSTCTest.groovy  | 14 ++++
 3 files changed, 83 insertions(+), 76 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 04a7a4be28..9da79a987a 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -1985,26 +1985,27 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
      * @see #inferComponentType
      */
     public static ClassNode inferLoopElementType(final ClassNode collectionType) {
-        ClassNode componentType = collectionType.getComponentType();
-        if (componentType == null) {
-            if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) {
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
-                componentType = col.getGenericsTypes()[0].getType();
-
-            } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) { // GROOVY-6240
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
-                componentType = MAP_ENTRY_TYPE.getPlainNodeReference();
-                componentType.setGenericsTypes(col.getGenericsTypes());
-
-            } else if (implementsInterfaceOrIsSubclassOf(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
-                ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
-                componentType = col.getGenericsTypes()[0].getType();
-
-            } else if (collectionType.equals(STRING_TYPE)) {
-                componentType = STRING_TYPE;
-            } else {
-                componentType = OBJECT_TYPE;
-            }
+        ClassNode componentType;
+        if (collectionType.isArray()) { // GROOVY-11335
+            componentType = collectionType.getComponentType();
+
+        } else if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) {
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
+            componentType = col.getGenericsTypes()[0].getType();
+
+        } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) { // GROOVY-6240
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
+            componentType = MAP_ENTRY_TYPE.getPlainNodeReference();
+            componentType.setGenericsTypes(col.getGenericsTypes());
+
+        } else if (implementsInterfaceOrIsSubclassOf(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
+            ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
+            componentType = col.getGenericsTypes()[0].getType();
+
+        } else if (collectionType.equals(STRING_TYPE)) {
+            componentType = STRING_TYPE;
+        } else {
+            componentType = OBJECT_TYPE;
         }
         return componentType;
     }
@@ -4692,8 +4693,10 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
     }
 
     protected ClassNode inferComponentType(final ClassNode containerType, final ClassNode indexType) {
-        ClassNode componentType = containerType.getComponentType();
-        if (componentType == null) {
+        ClassNode componentType = null;
+        if (containerType.isArray()) { // GROOVY-11335
+            componentType = containerType.getComponentType();
+        } else {
             // GROOVY-5521: check for "getAt" method
             typeCheckingContext.pushErrorCollector();
             MethodCallExpression vcall = callX(localVarX("_hash_", containerType), "getAt", varX("_index_", indexType));
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
index b4aa4154e0..ab402af70a 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/UnionTypeClassNode.java
@@ -35,7 +35,6 @@ import org.codehaus.groovy.ast.stmt.Statement;
 import org.codehaus.groovy.transform.ASTTransformation;
 
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedHashSet;
@@ -172,59 +171,51 @@ class UnionTypeClassNode extends ClassNode {
         throw new UnsupportedOperationException();
     }
 
-    @Override
-    public boolean declaresInterface(final ClassNode classNode) {
-        for (ClassNode delegate : delegates) {
-            if (delegate.declaresInterface(classNode)) return true;
-        }
-        return false;
-    }
-
     @Override
     public List<MethodNode> getAbstractMethods() {
-        List<MethodNode> allMethods = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAbstractMethods());
+            answer.addAll(delegate.getAbstractMethods());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public List<MethodNode> getAllDeclaredMethods() {
-        List<MethodNode> allMethods = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAllDeclaredMethods());
+            answer.addAll(delegate.getAllDeclaredMethods());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public Set<ClassNode> getAllInterfaces() {
-        Set<ClassNode> allMethods = new HashSet<ClassNode>();
+        Set<ClassNode> answer = new HashSet<>();
         for (ClassNode delegate : delegates) {
-            allMethods.addAll(delegate.getAllInterfaces());
+            answer.addAll(delegate.getAllInterfaces());
         }
-        return allMethods;
+        return answer;
     }
 
     @Override
     public List<AnnotationNode> getAnnotations() {
-        List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
+        List<AnnotationNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<AnnotationNode> annotations = delegate.getAnnotations();
-            if (annotations != null) nodes.addAll(annotations);
+            if (annotations != null) answer.addAll(annotations);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
     public List<AnnotationNode> getAnnotations(final ClassNode type) {
-        List<AnnotationNode> nodes = new LinkedList<AnnotationNode>();
+        List<AnnotationNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<AnnotationNode> annotations = delegate.getAnnotations(type);
-            if (annotations != null) nodes.addAll(annotations);
+            if (annotations != null) answer.addAll(annotations);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -234,11 +225,11 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<ConstructorNode> getDeclaredConstructors() {
-        List<ConstructorNode> nodes = new LinkedList<ConstructorNode>();
+        List<ConstructorNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
-            nodes.addAll(delegate.getDeclaredConstructors());
+            answer.addAll(delegate.getDeclaredConstructors());
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -261,12 +252,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<MethodNode> getDeclaredMethods(final String name) {
-        List<MethodNode> nodes = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<MethodNode> methods = delegate.getDeclaredMethods(name);
-            if (methods != null) nodes.addAll(methods);
+            if (methods != null) answer.addAll(methods);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -290,12 +281,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<FieldNode> getFields() {
-        List<FieldNode> nodes = new LinkedList<FieldNode>();
+        List<FieldNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<FieldNode> fields = delegate.getFields();
-            if (fields != null) nodes.addAll(fields);
+            if (fields != null) answer.addAll(fields);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -305,22 +296,25 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public ClassNode[] getInterfaces() {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
+        Set<ClassNode> answer = new LinkedHashSet<>();
         for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getInterfaces();
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
+            if (delegate.isInterface()) {
+                answer.remove(delegate); answer.add(delegate);
+            } else {
+                answer.addAll(Arrays.asList(delegate.getInterfaces()));
+            }
         }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return answer.toArray(ClassNode.EMPTY_ARRAY);
     }
 
     @Override
     public List<MethodNode> getMethods() {
-        List<MethodNode> nodes = new LinkedList<MethodNode>();
+        List<MethodNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<MethodNode> methods = delegate.getMethods();
-            if (methods != null) nodes.addAll(methods);
+            if (methods != null) answer.addAll(methods);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -334,12 +328,12 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public List<PropertyNode> getProperties() {
-        List<PropertyNode> nodes = new LinkedList<PropertyNode>();
+        List<PropertyNode> answer = new LinkedList<>();
         for (ClassNode delegate : delegates) {
             List<PropertyNode> properties = delegate.getProperties();
-            if (properties != null) nodes.addAll(properties);
+            if (properties != null) answer.addAll(properties);
         }
-        return nodes;
+        return answer;
     }
 
     @Override
@@ -349,22 +343,18 @@ class UnionTypeClassNode extends ClassNode {
 
     @Override
     public ClassNode[] getUnresolvedInterfaces() {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
-        for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getUnresolvedInterfaces();
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
-        }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return getUnresolvedInterfaces(false);
     }
 
     @Override
     public ClassNode[] getUnresolvedInterfaces(final boolean useRedirect) {
-        Set<ClassNode> nodes = new LinkedHashSet<ClassNode>();
-        for (ClassNode delegate : delegates) {
-            ClassNode[] interfaces = delegate.getUnresolvedInterfaces(useRedirect);
-            if (interfaces != null) Collections.addAll(nodes, interfaces);
+        ClassNode[] interfaces = getInterfaces();
+        if (useRedirect) {
+            for (int i = 0; i < interfaces.length; ++i) {
+                interfaces[i] = interfaces[i].redirect();
+            }
         }
-        return nodes.toArray(ClassNode.EMPTY_ARRAY);
+        return interfaces;
     }
 
     @Override
diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
index dd159a4a2a..4dd61a129f 100644
--- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
@@ -225,6 +225,20 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-11335
+    void testForInLoopOnCollection() {
+        assertScript '''
+            def whatever(Collection<String> coll) {
+                if (coll instanceof Serializable) {
+                    for (item in coll) {
+                        return item.toLowerCase()
+                    }
+                }
+            }
+            assert whatever(['Works']) == 'works'
+        '''
+    }
+
     // GROOVY-6123
     void testForInLoopOnEnumeration() {
         assertScript '''

Reply via email to