You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by em...@apache.org on 2022/07/30 15:24:36 UTC

[groovy] branch GROOVY_3_0_X updated (a01446dbe2 -> d250ef17bc)

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a change to branch GROOVY_3_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git


    from a01446dbe2 GROOVY-10618, GROOVY-10711: SC: `BooleanExpression` and `NotExpression`
     new 5a231d6f9e GROOVY-8487: SC: for-in loop over iterator
     new d250ef17bc GROOVY-10712: STC: for-in loop over iterator

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../asm/sc/StaticTypesStatementWriter.java         | 158 ++++++++-----------
 .../groovy/runtime/DefaultGroovyMethods.java       |   3 +-
 .../transform/stc/StaticTypeCheckingVisitor.java   |  35 +++--
 src/test/groovy/transform/stc/LoopsSTCTest.groovy  | 175 +++++++++++----------
 4 files changed, 183 insertions(+), 188 deletions(-)


[groovy] 01/02: GROOVY-8487: SC: for-in loop over iterator

Posted by em...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY_3_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 5a231d6f9e5e7141f675234fe5ff362bf734bee2
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Thu Jul 28 16:35:49 2022 -0500

    GROOVY-8487: SC: for-in loop over iterator
    
    3_0_X backport
---
 .../asm/sc/StaticTypesStatementWriter.java         | 158 ++++++++-----------
 .../groovy/runtime/DefaultGroovyMethods.java       |   3 +-
 src/test/groovy/transform/stc/LoopsSTCTest.groovy  | 175 +++++++++++----------
 3 files changed, 164 insertions(+), 172 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
index 30cfbc48b8..fb1d11365e 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
@@ -20,20 +20,19 @@ package org.codehaus.groovy.classgen.asm.sc;
 
 import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
+import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.Parameter;
-import org.codehaus.groovy.ast.expr.ArgumentListExpression;
 import org.codehaus.groovy.ast.expr.Expression;
 import org.codehaus.groovy.ast.expr.MethodCallExpression;
 import org.codehaus.groovy.ast.stmt.BlockStatement;
 import org.codehaus.groovy.ast.stmt.ForStatement;
+import org.codehaus.groovy.ast.tools.GeneralUtils;
 import org.codehaus.groovy.classgen.AsmClassGenerator;
 import org.codehaus.groovy.classgen.asm.BytecodeVariable;
 import org.codehaus.groovy.classgen.asm.CompileStack;
 import org.codehaus.groovy.classgen.asm.MethodCaller;
 import org.codehaus.groovy.classgen.asm.OperandStack;
 import org.codehaus.groovy.classgen.asm.StatementWriter;
-import org.codehaus.groovy.classgen.asm.TypeChooser;
-import org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport;
 import org.objectweb.asm.Label;
 import org.objectweb.asm.MethodVisitor;
 
@@ -63,64 +62,60 @@ import static org.objectweb.asm.Opcodes.SALOAD;
  */
 public class StaticTypesStatementWriter extends StatementWriter {
 
-    private static final ClassNode ITERABLE_CLASSNODE = ClassHelper.make(Iterable.class);
     private static final ClassNode ENUMERATION_CLASSNODE = ClassHelper.make(Enumeration.class);
     private static final MethodCaller ENUMERATION_NEXT_METHOD = MethodCaller.newInterface(Enumeration.class, "nextElement");
     private static final MethodCaller ENUMERATION_HASMORE_METHOD = MethodCaller.newInterface(Enumeration.class, "hasMoreElements");
 
-    public StaticTypesStatementWriter(StaticTypesWriterController controller) {
+    public StaticTypesStatementWriter(final StaticTypesWriterController controller) {
         super(controller);
     }
 
     @Override
-    public void writeBlockStatement(BlockStatement statement) {
+    public void writeBlockStatement(final BlockStatement statement) {
         controller.switchToFastPath();
         super.writeBlockStatement(statement);
         controller.switchToSlowPath();
     }
 
+    //--------------------------------------------------------------------------
+
     @Override
     protected void writeForInLoop(final ForStatement loop) {
-        controller.getAcg().onLineNumber(loop,"visitForLoop");
+        controller.getAcg().onLineNumber(loop, "visitForLoop");
         writeStatementLabel(loop);
 
         CompileStack compileStack = controller.getCompileStack();
-        MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
 
         compileStack.pushLoop(loop.getVariableScope(), loop.getStatementLabels());
 
-        // Identify type of collection
-        TypeChooser typeChooser = controller.getTypeChooser();
+        // identify type of collection
         Expression collectionExpression = loop.getCollectionExpression();
-        ClassNode collectionType = typeChooser.resolveType(collectionExpression, controller.getClassNode());
+        ClassNode collectionType = controller.getTypeChooser().resolveType(collectionExpression, controller.getClassNode());
+
+        int mark = operandStack.getStackLength();
         Parameter loopVariable = loop.getVariable();
-        int size = operandStack.getStackLength();
         if (collectionType.isArray() && loopVariable.getType().equals(collectionType.getComponentType())) {
-            writeOptimizedForEachLoop(compileStack, operandStack, mv, loop, collectionExpression, collectionType, loopVariable);
-        } else if (ENUMERATION_CLASSNODE.equals(collectionType)) {
+            writeOptimizedForEachLoop(loop, loopVariable, collectionExpression, collectionType);
+        } else if (GeneralUtils.isOrImplements(collectionType, ENUMERATION_CLASSNODE)) {
             writeEnumerationBasedForEachLoop(loop, collectionExpression, collectionType);
         } else {
             writeIteratorBasedForEachLoop(loop, collectionExpression, collectionType);
         }
-        operandStack.popDownTo(size);
+        operandStack.popDownTo(mark);
         compileStack.pop();
     }
 
-    private void writeOptimizedForEachLoop(
-            CompileStack compileStack,
-            OperandStack operandStack,
-            MethodVisitor mv,
-            ForStatement loop,
-            Expression arrayExpression,
-            ClassNode arrayType,
-            Parameter loopVariable) {
+    private void writeOptimizedForEachLoop(final ForStatement loop, final Parameter loopVariable, final Expression arrayExpression, final ClassNode arrayType) {
+        CompileStack compileStack = controller.getCompileStack();
+        OperandStack operandStack = controller.getOperandStack();
+        MethodVisitor mv = controller.getMethodVisitor();
+        AsmClassGenerator acg = controller.getAcg();
+
         BytecodeVariable variable = compileStack.defineVariable(loopVariable, arrayType.getComponentType(), false);
         Label continueLabel = compileStack.getContinueLabel();
         Label breakLabel = compileStack.getBreakLabel();
 
-        AsmClassGenerator acg = controller.getAcg();
-
         // load array on stack
         arrayExpression.visit(acg);
         mv.visitInsn(DUP);
@@ -145,9 +140,9 @@ public class StaticTypesStatementWriter extends StatementWriter {
         mv.visitJumpInsn(IF_ICMPGE, breakLabel);
 
         // get array element
-        loadFromArray(mv, variable, array, loopIdx);
+        loadFromArray(mv, operandStack, variable, array, loopIdx);
 
-        // $idx++
+        // $idx += 1
         mv.visitIincInsn(loopIdx, 1);
 
         // loop body
@@ -162,39 +157,24 @@ public class StaticTypesStatementWriter extends StatementWriter {
         compileStack.removeVar(array);
     }
 
-    private void loadFromArray(MethodVisitor mv, BytecodeVariable variable, int array, int iteratorIdx) {
-        OperandStack os = controller.getOperandStack();
+    private static void loadFromArray(final MethodVisitor mv, final OperandStack os, final BytecodeVariable variable, final int array, final int index) {
         mv.visitVarInsn(ALOAD, array);
-        mv.visitVarInsn(ILOAD, iteratorIdx);
-
+        mv.visitVarInsn(ILOAD, index);
         ClassNode varType = variable.getType();
-        boolean primitiveType = ClassHelper.isPrimitiveType(varType);
-        boolean isByte = ClassHelper.byte_TYPE.equals(varType);
-        boolean isShort = ClassHelper.short_TYPE.equals(varType);
-        boolean isInt = ClassHelper.int_TYPE.equals(varType);
-        boolean isLong = ClassHelper.long_TYPE.equals(varType);
-        boolean isFloat = ClassHelper.float_TYPE.equals(varType);
-        boolean isDouble = ClassHelper.double_TYPE.equals(varType);
-        boolean isChar = ClassHelper.char_TYPE.equals(varType);
-        boolean isBoolean = ClassHelper.boolean_TYPE.equals(varType);
-
-        if (primitiveType) {
-            if (isByte) {
+        if (ClassHelper.isPrimitiveType(varType)) {
+            if (varType.equals(ClassHelper.int_TYPE)) {
+                mv.visitInsn(IALOAD);
+            } else if (varType.equals(ClassHelper.long_TYPE)) {
+                mv.visitInsn(LALOAD);
+            } else if (varType.equals(ClassHelper.byte_TYPE) || varType.equals(ClassHelper.boolean_TYPE)) {
                 mv.visitInsn(BALOAD);
-            }
-            if (isShort) {
+            } else if (varType.equals(ClassHelper.char_TYPE)) {
+                mv.visitInsn(CALOAD);
+            } else if (varType.equals(ClassHelper.short_TYPE)) {
                 mv.visitInsn(SALOAD);
-            }
-            if (isInt || isChar || isBoolean) {
-                mv.visitInsn(isChar ? CALOAD : isBoolean ? BALOAD : IALOAD);
-            }
-            if (isLong) {
-                mv.visitInsn(LALOAD);
-            }
-            if (isFloat) {
+            } else if (varType.equals(ClassHelper.float_TYPE)) {
                 mv.visitInsn(FALOAD);
-            }
-            if (isDouble) {
+            } else if (varType.equals(ClassHelper.double_TYPE)) {
                 mv.visitInsn(DALOAD);
             }
         } else {
@@ -204,61 +184,55 @@ public class StaticTypesStatementWriter extends StatementWriter {
         os.storeVar(variable);
     }
 
-    private void writeIteratorBasedForEachLoop(
-            ForStatement loop,
-            Expression collectionExpression,
-            ClassNode collectionType) {
-
-        if (StaticTypeCheckingSupport.implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_CLASSNODE)) {
-            MethodCallExpression iterator = new MethodCallExpression(collectionExpression, "iterator", new ArgumentListExpression());
-            iterator.setMethodTarget(collectionType.getMethod("iterator", Parameter.EMPTY_ARRAY));
-            iterator.setImplicitThis(false);
-            iterator.visit(controller.getAcg());
-        } else {
-            collectionExpression.visit(controller.getAcg());
-            controller.getMethodVisitor().visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "iterator", "(Ljava/lang/Object;)Ljava/util/Iterator;", false);
-            controller.getOperandStack().replace(ClassHelper.Iterator_TYPE);
-        }
-
-        writeForInLoopControlAndBlock(loop);
-    }
-
-    private void writeEnumerationBasedForEachLoop(
-            ForStatement loop,
-            Expression collectionExpression,
-            ClassNode collectionType) {
-
+    private void writeEnumerationBasedForEachLoop(final ForStatement loop, final Expression collectionExpression, final ClassNode collectionType) {
         CompileStack compileStack = controller.getCompileStack();
-        MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
+        MethodVisitor mv = controller.getMethodVisitor();
 
-        // Declare the loop counter.
         BytecodeVariable variable = compileStack.defineVariable(loop.getVariable(), false);
+        Label continueLabel = compileStack.getContinueLabel();
+        Label breakLabel = compileStack.getBreakLabel();
 
         collectionExpression.visit(controller.getAcg());
 
-        // Then get the iterator and generate the loop control
-
-        int enumIdx = compileStack.defineTemporaryVariable("$enum", ENUMERATION_CLASSNODE, true);
+        int enumeration = compileStack.defineTemporaryVariable("$enum", ENUMERATION_CLASSNODE, true);
 
-        Label continueLabel = compileStack.getContinueLabel();
-        Label breakLabel = compileStack.getBreakLabel();
+        mv.visitVarInsn(ALOAD, enumeration);
+        mv.visitJumpInsn(IFNULL, breakLabel);
 
         mv.visitLabel(continueLabel);
-        mv.visitVarInsn(ALOAD, enumIdx);
+
+        mv.visitVarInsn(ALOAD, enumeration);
         ENUMERATION_HASMORE_METHOD.call(mv);
-        // note: ifeq tests for ==0, a boolean is 0 if it is false
-        mv.visitJumpInsn(IFEQ, breakLabel);
+        mv.visitJumpInsn(IFEQ, breakLabel); // jump if zero (aka false)
 
-        mv.visitVarInsn(ALOAD, enumIdx);
+        mv.visitVarInsn(ALOAD, enumeration);
         ENUMERATION_NEXT_METHOD.call(mv);
         operandStack.push(ClassHelper.OBJECT_TYPE);
         operandStack.storeVar(variable);
 
-        // Generate the loop body
         loop.getLoopBlock().visit(controller.getAcg());
-
         mv.visitJumpInsn(GOTO, continueLabel);
+
         mv.visitLabel(breakLabel);
     }
+
+    private void writeIteratorBasedForEachLoop(final ForStatement loop, final Expression collectionExpression, final ClassNode collectionType) {
+        if (GeneralUtils.isOrImplements(collectionType, ClassHelper.Iterator_TYPE)) {
+            collectionExpression.visit(controller.getAcg()); // GROOVY-8487: iterator supplied
+        } else {
+            MethodNode iterator = collectionType.getMethod("iterator", Parameter.EMPTY_ARRAY);
+            if (iterator != null && iterator.getReturnType().equals(ClassHelper.Iterator_TYPE)) {
+                MethodCallExpression call = GeneralUtils.callX(collectionExpression, "iterator");
+                call.setImplicitThis(false);
+                call.setMethodTarget(iterator);
+                call.visit(controller.getAcg());
+            } else {
+                collectionExpression.visit(controller.getAcg());
+                controller.getMethodVisitor().visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "iterator", "(Ljava/lang/Object;)Ljava/util/Iterator;", false);
+                controller.getOperandStack().replace(ClassHelper.Iterator_TYPE);
+            }
+        }
+        writeForInLoopControlAndBlock(loop);
+    }
 }
diff --git a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
index 69e6c3c3e8..3df39e4f46 100644
--- a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
+++ b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
@@ -18278,7 +18278,8 @@ public class DefaultGroovyMethods extends DefaultGroovyMethodsSupport {
      * @see org.codehaus.groovy.runtime.typehandling.DefaultTypeTransformation#asCollection(java.lang.Object)
      * @since 1.0
      */
-    public static Iterator iterator(Object o) {
+    public static Iterator iterator(final Object o) {
+        if (o instanceof Iterator) return (Iterator)o;
         return DefaultTypeTransformation.asCollection(o).iterator();
     }
 
diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
index d4824ce9b2..db2d219bb9 100644
--- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
@@ -24,23 +24,15 @@ package groovy.transform.stc
 class LoopsSTCTest extends StaticTypeCheckingTestCase {
 
     void testMethodCallInLoop() {
-        assertScript '''
-            int foo(int x) { x+1 }
-            int x = 0
-            for (int i=0;i<10;i++) {
-                x = foo(x)
-            }
-        '''
-    }
-
-    void testMethodCallInLoopAndDef() {
-        assertScript '''
-            int foo(int x) { x+1 }
-            def x = 0
-            for (int i=0;i<10;i++) {
-                x = foo(x)
-            }
-        '''
+        for (type in ['int', 'def']) {
+            assertScript """
+                int foo(int x) { x+1 }
+                $type x = 0
+                for (int i=0;i<10;i++) {
+                    x = foo(x)
+                }
+            """
+        }
     }
 
     void testMethodCallWithEachAndDefAndTwoFooMethods() {
@@ -52,7 +44,8 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                  // there are two possible target methods. This is not a problem for STC, but it is for static compilation
                 x = foo(x)
             }
-        ''', 'Cannot find matching method'
+        ''',
+        'Cannot find matching method'
     }
 
     void testMethodCallInLoopAndDefAndTwoFooMethods() {
@@ -64,7 +57,8 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                  // there are two possible target methods. This is not a problem for STC, but it is for static compilation
                 x = foo(x)
             }
-        ''', 'Cannot find matching method'
+        ''',
+        'Cannot find matching method'
     }
 
     void testMethodCallInLoopAndDefAndTwoFooMethodsAndOneWithBadType() {
@@ -77,7 +71,8 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                 // then called in turn as a parameter of foo(). There's no #foo(Date)
                 x = foo(x)
             }
-        ''', 'Cannot find matching method'
+        ''',
+        'Cannot find matching method'
     }
 
     void testMethodCallInLoopAndDefAndTwoFooMethodsAndOneWithBadTypeAndIndirection() {
@@ -91,7 +86,8 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                 // then called in turn as a parameter of foo(). There's no #foo(Date)
                 x = y
             }
-        ''', 'Cannot find matching method'
+        ''',
+        'Cannot find matching method'
     }
 
     void testMethodCallWithEachAndDefAndTwoFooMethodsAndOneWithBadTypeAndIndirection() {
@@ -105,7 +101,8 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                 // then called in turn as a parameter of foo(). There's no #foo(Date)
                 x = y
             }
-        ''', 'Cannot find matching method'
+        ''',
+        'Cannot find matching method'
     }
 
     // GROOVY-5587
@@ -175,7 +172,17 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-8643
     void testForInLoopOnArray() {
+        assertScript '''
+            String[] strings = null
+            for (string in strings) {
+                string.toUpperCase()
+            }
+        '''
+    }
+
+    void testForInLoopOnArray2() {
         assertScript '''
             String[] strings = ['a','b','c']
             for (string in strings) {
@@ -185,7 +192,7 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
     }
 
     // GROOVY-10579
-    void testForInLoopOnArray2() {
+    void testForInLoopOnArray3() {
         assertScript '''
             int[] numbers = [1,2,3,4,5]
             int sum = 0
@@ -208,12 +215,23 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-8487
+    void testForInLoopOnIterator() {
+        assertScript '''
+            def list = []
+            for (item in ['a','b','c'].iterator()) {
+                list.add(item)
+            }
+            assert list.join('') == 'abc'
+        '''
+    }
+
     // GROOVY-6123
     void testForInLoopOnEnumeration() {
         assertScript '''
             Vector<String> v = new Vector<>()
             v.add('ooo')
-            def en = v.elements()
+            Enumeration<String> en = v.elements()
             for (e in en) {
                 assert e.toUpperCase() == 'OOO'
             }
@@ -229,25 +247,27 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
                 assert e.toUpperCase() in ['OOO','GROOVY']
                 if (e=='ooo') continue
             }
+
+            en = null
+            for (e in en) {
+                e.toUpperCase()
+            }
         '''
     }
 
-    void testShouldNotInferSoftReferenceAsComponentType() {
-        assertScript '''import java.lang.reflect.Field
-            import org.codehaus.groovy.ast.stmt.ForStatement
-
-            @ASTTest(phase=INSTRUCTION_SELECTION, value= {
-                def FIELD_ARRAY = make(Field).makeArray()
-                def forStmt = lookup('myLoop')[0]
-                assert forStmt instanceof ForStatement
-                def collectionType = forStmt.collectionExpression.getNodeMetaData(INFERRED_TYPE)
-                assert collectionType == FIELD_ARRAY
+    void testShouldNotInferSoftReferenceAsElementType() {
+        assertScript '''
+            @ASTTest(phase=INSTRUCTION_SELECTION, value={
+                def loop = lookup('loop')[0]
+                assert loop instanceof org.codehaus.groovy.ast.stmt.ForStatement
+                def collectionType = loop.collectionExpression.getNodeMetaData(INFERRED_TYPE)
+                assert collectionType == make(java.lang.reflect.Field).makeArray()
             })
-            void forInTest() {
-                int i = 0;
-                myLoop:
+            void test() {
+                int i = 0
+                loop:
                 for (def field : String.class.declaredFields) {
-                    i++;
+                    i++
                 }
                 assert i > 0
             }
@@ -255,56 +275,53 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
     }
 
     // GROOVY-5640
-    void testShouldInferComponentTypeAsIterableOfNodes() {
-        assertScript '''import org.codehaus.groovy.ast.stmt.ForStatement
-        class Node {}
-
-        interface Traverser {
-            Iterable<Node> nodes()
-        }
-
-        class MyTraverser implements Traverser {
-
-            Iterable<Node> nodes() {
-                []
+    void testShouldInferNodeElementTypeForIterableOfNodes() {
+        assertScript '''
+            class Node {
             }
-        }
-
-        @ASTTest(phase=INSTRUCTION_SELECTION, value= {
-            def forStmt = lookup('loop')[0]
-            assert forStmt instanceof ForStatement
-            def collectionType = forStmt.collectionExpression.getNodeMetaData(INFERRED_TYPE)
-            assert collectionType == make(Iterable)
-            assert collectionType.isUsingGenerics()
-            assert collectionType.genericsTypes.length == 1
-            assert collectionType.genericsTypes[0].type.name == 'Node'
-        })
-        void test() {
-            loop:
-            for (def node : new MyTraverser().nodes()) {
-                println node.class.name
+            interface Traverser {
+                Iterable<Node> nodes()
+            }
+            class MyTraverser implements Traverser {
+                Iterable<Node> nodes() {
+                    []
+                }
             }
-        }
 
+            @ASTTest(phase=INSTRUCTION_SELECTION, value={
+                def loop = lookup('loop')[0]
+                assert loop instanceof org.codehaus.groovy.ast.stmt.ForStatement
+                def collectionType = loop.collectionExpression.getNodeMetaData(INFERRED_TYPE)
+                assert collectionType == make(Iterable)
+                assert collectionType.isUsingGenerics()
+                assert collectionType.genericsTypes.length == 1
+                assert collectionType.genericsTypes[0].type.name == 'Node'
+            })
+            void test() {
+                loop:
+                for (def node : new MyTraverser().nodes()) {
+                    println node.class.name
+                }
+            }
         '''
     }
 
     // GROOVY-5641
     void testShouldInferLoopElementTypeWithUndeclaredType() {
-        assertScript '''import org.codehaus.groovy.ast.stmt.ForStatement
-        @ASTTest(phase=INSTRUCTION_SELECTION, value= {
-            def forStmt = lookup('loop')[0]
-            assert forStmt instanceof ForStatement
-            def collectionType = forStmt.collectionExpression.getNodeMetaData(INFERRED_TYPE)
-            assert collectionType == make(IntRange)
-        })
-        void foo() {
-            int[] perm = new int[10]
-            loop:
-            for (i in 0..<10) {
-              assert perm[i-0] == 0
+        assertScript '''
+            @ASTTest(phase=INSTRUCTION_SELECTION, value={
+                def loop = lookup('loop')[0]
+                assert loop instanceof org.codehaus.groovy.ast.stmt.ForStatement
+                def collectionType = loop.collectionExpression.getNodeMetaData(INFERRED_TYPE)
+                assert collectionType == make(IntRange)
+            })
+            void test() {
+                int[] ints = new int[10]
+                loop:
+                for (i in 0..<10) {
+                  assert ints[i-0] == 0
+                }
             }
-        }
         '''
     }
 }


[groovy] 02/02: GROOVY-10712: STC: for-in loop over iterator

Posted by em...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY_3_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit d250ef17bc7d0d37e7a19e1507b73b4eea18d06c
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Sat Jul 30 10:15:06 2022 -0500

    GROOVY-10712: STC: for-in loop over iterator
    
    2_5_X backport
---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 35 ++++++++++++----------
 src/test/groovy/transform/stc/LoopsSTCTest.groovy  |  6 ++--
 2 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index edf4617707..532457bbe4 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -1973,32 +1973,35 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
     }
 
     /**
-     * Given a loop collection type, returns the inferred type of the loop element. Used, for
-     * example, to infer the element type of a (for e in list) loop.
+     * Returns the inferred loop element type given a loop collection type. Used,
+     * for example, to infer the element type of a {@code for (e in list)} loop.
      *
      * @param collectionType the type of the collection
      * @return the inferred component type
+     * @see #inferComponentType
      */
     public static ClassNode inferLoopElementType(final ClassNode collectionType) {
         ClassNode componentType = collectionType.getComponentType();
         if (componentType == null) {
             if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) {
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
-                componentType = genericsTypes[0].getType();
-            } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) {
-                // GROOVY-6240
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
+                componentType = col.getGenericsTypes()[0].getType();
+
+            } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) { // GROOVY-6240
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
                 componentType = MAP_ENTRY_TYPE.getPlainNodeReference();
-                componentType.setGenericsTypes(genericsTypes);
-            } else if (STRING_TYPE.equals(collectionType)) {
+                componentType.setGenericsTypes(col.getGenericsTypes());
+
+            } else if (implementsInterfaceOrIsSubclassOf(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
+                componentType = col.getGenericsTypes()[0].getType();
+
+            } else if (implementsInterfaceOrIsSubclassOf(collectionType, Iterator_TYPE)) { // GROOVY-10712
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, Iterator_TYPE);
+                componentType = col.getGenericsTypes()[0].getType();
+
+            } else if (collectionType.equals(STRING_TYPE)) {
                 componentType = STRING_TYPE;
-            } else if (ENUMERATION_TYPE.equals(collectionType)) {
-                // GROOVY-6123
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
-                componentType = genericsTypes[0].getType();
             } else {
                 componentType = OBJECT_TYPE;
             }
diff --git a/src/test/groovy/transform/stc/LoopsSTCTest.groovy b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
index db2d219bb9..7291043f50 100644
--- a/src/test/groovy/transform/stc/LoopsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/LoopsSTCTest.groovy
@@ -215,14 +215,14 @@ class LoopsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
-    // GROOVY-8487
+    // GROOVY-8487, GROOVY-10712
     void testForInLoopOnIterator() {
         assertScript '''
             def list = []
             for (item in ['a','b','c'].iterator()) {
-                list.add(item)
+                list.add(item.toUpperCase())
             }
-            assert list.join('') == 'abc'
+            assert list == ['A','B','C']
         '''
     }