You are viewing a plain-text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@groovy.apache.org by em...@apache.org on 2022/02/09 20:10:35 UTC

[groovy] 02/03: GROOVY-10476: Stream provides iterator() but does not implement Iterable

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY_4_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 4b247376a4b1f856b2116abbf8ce27ce7abe5633
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Wed Feb 9 12:05:11 2022 -0600

    GROOVY-10476: Stream provides iterator() but does not implement Iterable
---
 .../java/org/codehaus/groovy/ast/ClassHelper.java  |   2 +
 .../asm/sc/StaticTypesStatementWriter.java         | 140 +++++++++------------
 .../transform/stc/StaticTypeCheckingVisitor.java   |  37 +++---
 .../groovy/transform/stc/MethodCallsSTCTest.groovy |  13 +-
 4 files changed, 91 insertions(+), 101 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/ast/ClassHelper.java b/src/main/java/org/codehaus/groovy/ast/ClassHelper.java
index c41dfbf..95acea2 100644
--- a/src/main/java/org/codehaus/groovy/ast/ClassHelper.java
+++ b/src/main/java/org/codehaus/groovy/ast/ClassHelper.java
@@ -71,6 +71,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.regex.Pattern;
+import java.util.stream.Stream;
 
 /**
  * Helper for {@link ClassNode} and classes handling them.  Contains a set of
@@ -151,6 +152,7 @@ public class ClassHelper {
             Enum_Type = makeWithoutCaching(Enum.class),
             CLASS_Type = makeWithoutCaching(Class.class),
             TUPLE_TYPE = makeWithoutCaching(Tuple.class),
+            STREAM_TYPE = makeWithoutCaching(Stream.class),
             ITERABLE_TYPE = makeWithoutCaching(Iterable.class),
             REFERENCE_TYPE = makeWithoutCaching(Reference.class),
             COLLECTION_TYPE = makeWithoutCaching(Collection.class),
diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
index cc02e8f..69c09c3 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesStatementWriter.java
@@ -20,34 +20,25 @@ package org.codehaus.groovy.classgen.asm.sc;
 
 import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
+import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.Parameter;
-import org.codehaus.groovy.ast.expr.ArgumentListExpression;
 import org.codehaus.groovy.ast.expr.Expression;
 import org.codehaus.groovy.ast.expr.MethodCallExpression;
 import org.codehaus.groovy.ast.stmt.BlockStatement;
 import org.codehaus.groovy.ast.stmt.ForStatement;
+import org.codehaus.groovy.ast.tools.GeneralUtils;
 import org.codehaus.groovy.classgen.AsmClassGenerator;
 import org.codehaus.groovy.classgen.asm.BytecodeVariable;
 import org.codehaus.groovy.classgen.asm.CompileStack;
 import org.codehaus.groovy.classgen.asm.MethodCaller;
 import org.codehaus.groovy.classgen.asm.OperandStack;
 import org.codehaus.groovy.classgen.asm.StatementWriter;
-import org.codehaus.groovy.classgen.asm.TypeChooser;
-import org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport;
 import org.objectweb.asm.Label;
 import org.objectweb.asm.MethodVisitor;
 
 import java.util.Enumeration;
+import java.util.Objects;
 
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveType;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveBoolean;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveByte;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveChar;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveDouble;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveFloat;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveInt;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveLong;
-import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveShort;
 import static org.objectweb.asm.Opcodes.AALOAD;
 import static org.objectweb.asm.Opcodes.ALOAD;
 import static org.objectweb.asm.Opcodes.ARRAYLENGTH;
@@ -72,65 +63,60 @@ import static org.objectweb.asm.Opcodes.SALOAD;
  */
 public class StaticTypesStatementWriter extends StatementWriter {
 
-    private static final ClassNode ITERABLE_CLASSNODE = ClassHelper.make(Iterable.class);
     private static final ClassNode ENUMERATION_CLASSNODE = ClassHelper.make(Enumeration.class);
     private static final MethodCaller ENUMERATION_NEXT_METHOD = MethodCaller.newInterface(Enumeration.class, "nextElement");
     private static final MethodCaller ENUMERATION_HASMORE_METHOD = MethodCaller.newInterface(Enumeration.class, "hasMoreElements");
 
-    public StaticTypesStatementWriter(StaticTypesWriterController controller) {
+    public StaticTypesStatementWriter(final StaticTypesWriterController controller) {
         super(controller);
     }
 
     @Override
-    public void writeBlockStatement(BlockStatement statement) {
+    public void writeBlockStatement(final BlockStatement statement) {
         controller.switchToFastPath();
         super.writeBlockStatement(statement);
         controller.switchToSlowPath();
     }
 
+    //--------------------------------------------------------------------------
+
     @Override
     protected void writeForInLoop(final ForStatement loop) {
-        controller.getAcg().onLineNumber(loop,"visitForLoop");
+        controller.getAcg().onLineNumber(loop, "visitForLoop");
         writeStatementLabel(loop);
 
         CompileStack compileStack = controller.getCompileStack();
-        MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
 
         compileStack.pushLoop(loop.getVariableScope(), loop.getStatementLabels());
 
-        // Identify type of collection
-        TypeChooser typeChooser = controller.getTypeChooser();
+        // identify type of collection
         Expression collectionExpression = loop.getCollectionExpression();
-        ClassNode collectionType = typeChooser.resolveType(collectionExpression, controller.getClassNode());
+        ClassNode collectionType = controller.getTypeChooser().resolveType(collectionExpression, controller.getClassNode());
+
+        int mark = operandStack.getStackLength();
         Parameter loopVariable = loop.getVariable();
-        int size = operandStack.getStackLength();
         if (collectionType.isArray() && loopVariable.getOriginType().equals(collectionType.getComponentType())) {
-            writeOptimizedForEachLoop(compileStack, operandStack, mv, loop, collectionExpression, collectionType, loopVariable);
-        } else if (ENUMERATION_CLASSNODE.equals(collectionType)) {
+            writeOptimizedForEachLoop(loop, loopVariable, collectionExpression, collectionType);
+        } else if (GeneralUtils.isOrImplements(collectionType, ENUMERATION_CLASSNODE)) {
             writeEnumerationBasedForEachLoop(loop, collectionExpression, collectionType);
         } else {
             writeIteratorBasedForEachLoop(loop, collectionExpression, collectionType);
         }
-        operandStack.popDownTo(size);
+        operandStack.popDownTo(mark);
         compileStack.pop();
     }
 
-    private void writeOptimizedForEachLoop(
-            CompileStack compileStack,
-            OperandStack operandStack,
-            MethodVisitor mv,
-            ForStatement loop,
-            Expression collectionExpression,
-            ClassNode collectionType,
-            Parameter loopVariable) {
-        BytecodeVariable variable = compileStack.defineVariable(loopVariable, false);
+    private void writeOptimizedForEachLoop(final ForStatement loop, final Parameter loopVariable, final Expression collectionExpression, final ClassNode collectionType) {
+        CompileStack compileStack = controller.getCompileStack();
+        OperandStack operandStack = controller.getOperandStack();
+        MethodVisitor mv = controller.getMethodVisitor();
+        AsmClassGenerator acg = controller.getAcg();
 
+        BytecodeVariable variable = compileStack.defineVariable(loopVariable, false);
         Label continueLabel = compileStack.getContinueLabel();
         Label breakLabel = compileStack.getBreakLabel();
 
-        AsmClassGenerator acg = controller.getAcg();
-
         // load array on stack
         collectionExpression.visit(acg);
         mv.visitInsn(DUP);
@@ -155,9 +141,9 @@ public class StaticTypesStatementWriter extends StatementWriter {
         mv.visitJumpInsn(IF_ICMPGE, breakLabel);
 
         // get array element
-        loadFromArray(mv, variable, array, loopIdx);
+        loadFromArray(mv, operandStack, variable, array, loopIdx);
 
-        // $idx++
+        // $idx += 1
         mv.visitIincInsn(loopIdx, 1);
 
         // loop body
@@ -172,27 +158,24 @@ public class StaticTypesStatementWriter extends StatementWriter {
         compileStack.removeVar(array);
     }
 
-    private void loadFromArray(MethodVisitor mv, BytecodeVariable variable, int array, int iteratorIdx) {
-        OperandStack os = controller.getOperandStack();
+    private static void loadFromArray(final MethodVisitor mv, final OperandStack os, final BytecodeVariable variable, final int array, final int index) {
         mv.visitVarInsn(ALOAD, array);
-        mv.visitVarInsn(ILOAD, iteratorIdx);
-
+        mv.visitVarInsn(ILOAD, index);
         ClassNode varType = variable.getType();
-
-        if (isPrimitiveType(varType)) {
-            if (isPrimitiveInt(varType)) {
+        if (ClassHelper.isPrimitiveType(varType)) {
+            if (ClassHelper.isPrimitiveInt(varType)) {
                 mv.visitInsn(IALOAD);
-            } else if (isPrimitiveLong(varType)) {
+            } else if (ClassHelper.isPrimitiveLong(varType)) {
                 mv.visitInsn(LALOAD);
-            } else if (isPrimitiveByte(varType) || isPrimitiveBoolean(varType)) {
+            } else if (ClassHelper.isPrimitiveByte(varType) || ClassHelper.isPrimitiveBoolean(varType)) {
                 mv.visitInsn(BALOAD);
-            } else if (isPrimitiveChar(varType)) {
+            } else if (ClassHelper.isPrimitiveChar(varType)) {
                 mv.visitInsn(CALOAD);
-            } else if (isPrimitiveShort(varType)) {
+            } else if (ClassHelper.isPrimitiveShort(varType)) {
                 mv.visitInsn(SALOAD);
-            } else if (isPrimitiveFloat(varType)) {
+            } else if (ClassHelper.isPrimitiveFloat(varType)) {
                 mv.visitInsn(FALOAD);
-            } else if (isPrimitiveDouble(varType)) {
+            } else if (ClassHelper.isPrimitiveDouble(varType)) {
                 mv.visitInsn(DALOAD);
             }
         } else {
@@ -202,46 +185,19 @@ public class StaticTypesStatementWriter extends StatementWriter {
         os.storeVar(variable);
     }
 
-    private void writeIteratorBasedForEachLoop(
-            ForStatement loop,
-            Expression collectionExpression,
-            ClassNode collectionType) {
-
-        if (StaticTypeCheckingSupport.implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_CLASSNODE)) {
-            MethodCallExpression iterator = new MethodCallExpression(collectionExpression, "iterator", new ArgumentListExpression());
-            iterator.setMethodTarget(collectionType.getMethod("iterator", Parameter.EMPTY_ARRAY));
-            iterator.setImplicitThis(false);
-            iterator.visit(controller.getAcg());
-        } else {
-            collectionExpression.visit(controller.getAcg());
-            controller.getMethodVisitor().visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "iterator", "(Ljava/lang/Object;)Ljava/util/Iterator;", false);
-            controller.getOperandStack().replace(ClassHelper.Iterator_TYPE);
-        }
-
-        writeForInLoopControlAndBlock(loop);
-    }
-
-    private void writeEnumerationBasedForEachLoop(
-            ForStatement loop,
-            Expression collectionExpression,
-            ClassNode collectionType) {
-
+    private void writeEnumerationBasedForEachLoop(final ForStatement loop, final Expression collectionExpression, final ClassNode collectionType) {
         CompileStack compileStack = controller.getCompileStack();
-        MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
+        MethodVisitor mv = controller.getMethodVisitor();
 
-        // Declare the loop counter.
         BytecodeVariable variable = compileStack.defineVariable(loop.getVariable(), false);
+        Label continueLabel = compileStack.getContinueLabel();
+        Label breakLabel = compileStack.getBreakLabel();
 
         collectionExpression.visit(controller.getAcg());
 
-        // Then get the iterator and generate the loop control
-
         int enumIdx = compileStack.defineTemporaryVariable("$enum", ENUMERATION_CLASSNODE, true);
 
-        Label continueLabel = compileStack.getContinueLabel();
-        Label breakLabel = compileStack.getBreakLabel();
-
         mv.visitLabel(continueLabel);
         mv.visitVarInsn(ALOAD, enumIdx);
         ENUMERATION_HASMORE_METHOD.call(mv);
@@ -253,12 +209,30 @@ public class StaticTypesStatementWriter extends StatementWriter {
         operandStack.push(ClassHelper.OBJECT_TYPE);
         operandStack.storeVar(variable);
 
-        // Generate the loop body
         loop.getLoopBlock().visit(controller.getAcg());
 
         mv.visitJumpInsn(GOTO, continueLabel);
         mv.visitLabel(breakLabel);
-
     }
 
+    private void writeIteratorBasedForEachLoop(final ForStatement loop, final Expression collectionExpression, final ClassNode collectionType) {
+        // GROOVY-10476: BaseStream provides an iterator() but does not implement Iterable
+        MethodNode iterator = collectionType.getMethod("iterator", Parameter.EMPTY_ARRAY);
+        if (iterator == null) {
+            iterator = GeneralUtils.getInterfacesAndSuperInterfaces(collectionType).stream()
+                    .map(in -> in.getMethod("iterator", Parameter.EMPTY_ARRAY))
+                    .filter(Objects::nonNull).findFirst().orElse(null);
+        }
+        if (iterator != null && iterator.getReturnType().equals(ClassHelper.Iterator_TYPE)) {
+            MethodCallExpression call = GeneralUtils.callX(collectionExpression, "iterator");
+            call.setImplicitThis(false);
+            call.setMethodTarget(iterator);
+            call.visit(controller.getAcg());
+        } else {
+            collectionExpression.visit(controller.getAcg());
+            controller.getMethodVisitor().visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/DefaultGroovyMethods", "iterator", "(Ljava/lang/Object;)Ljava/util/Iterator;", false);
+            controller.getOperandStack().replace(ClassHelper.Iterator_TYPE);
+        }
+        writeForInLoopControlAndBlock(loop);
+    }
 }
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index b121c4b..3dabf27 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -163,6 +163,7 @@ import static org.codehaus.groovy.ast.ClassHelper.OBJECT_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.PATTERN_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.RANGE_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.SET_TYPE;
+import static org.codehaus.groovy.ast.ClassHelper.STREAM_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.STRING_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.Short_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.VOID_TYPE;
@@ -1985,32 +1986,34 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
     }
 
     /**
-     * Given a loop collection type, returns the inferred type of the loop element. Used, for
-     * example, to infer the element type of a (for e in list) loop.
+     * Returns the inferred loop element type given a loop collection type. Used,
+     * for example, to infer the element type of a {@code for (e in list)} loop.
      *
      * @param collectionType the type of the collection
      * @return the inferred component type
+     * @see #inferComponentType
      */
     public static ClassNode inferLoopElementType(final ClassNode collectionType) {
         ClassNode componentType = collectionType.getComponentType();
         if (componentType == null) {
-            if (implementsInterfaceOrIsSubclassOf(collectionType, ITERABLE_TYPE)) {
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
-                componentType = genericsTypes[0].getType();
-            } else if (implementsInterfaceOrIsSubclassOf(collectionType, MAP_TYPE)) {
-                // GROOVY-6240
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
-                componentType = MAP_ENTRY_TYPE.getPlainNodeReference();
-                componentType.setGenericsTypes(genericsTypes);
+            if (isOrImplements(collectionType, ITERABLE_TYPE)) {
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, ITERABLE_TYPE);
+                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+            } else if (isOrImplements(collectionType, MAP_TYPE)) { // GROOVY-6240
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, MAP_TYPE);
+                componentType = makeClassSafe0(MAP_ENTRY_TYPE, col.getGenericsTypes());
+
+            } else if (isOrImplements(collectionType, STREAM_TYPE)) { // GROOVY-10476
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, STREAM_TYPE);
+                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
+            } else if (isOrImplements(collectionType, ENUMERATION_TYPE)) { // GROOVY-6123
+                ClassNode col = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
+                componentType = getCombinedBoundType(col.getGenericsTypes()[0]);
+
             } else if (isStringType(collectionType)) {
                 componentType = STRING_TYPE;
-            } else if (ENUMERATION_TYPE.equals(collectionType)) {
-                // GROOVY-6123
-                ClassNode intf = GenericsUtils.parameterizeType(collectionType, ENUMERATION_TYPE);
-                GenericsType[] genericsTypes = intf.getGenericsTypes();
-                componentType = genericsTypes[0].getType();
             } else {
                 componentType = OBJECT_TYPE;
             }
diff --git a/src/test/groovy/transform/stc/MethodCallsSTCTest.groovy b/src/test/groovy/transform/stc/MethodCallsSTCTest.groovy
index dc0fb19..4d29461 100644
--- a/src/test/groovy/transform/stc/MethodCallsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/MethodCallsSTCTest.groovy
@@ -1161,7 +1161,7 @@ class MethodCallsSTCTest extends StaticTypeCheckingTestCase {
     }
 
     // GROOVY-8133
-    void testSpreadDotOperator() {
+    void testSpreadDot() {
         assertScript '''
             def list = ['a','b','c'].stream()*.toUpperCase()
             assert list == ['A', 'B', 'C']
@@ -1186,6 +1186,17 @@ class MethodCallsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-10476
+    void testForInLoop() {
+        assertScript '''
+            def list = []
+            for (item in ['a','b','c'].stream()) {
+                list.add(item.toUpperCase())
+            }
+            assert list == ['A', 'B', 'C']
+        '''
+    }
+
     void testBoxingShouldCostMore() {
         assertScript '''
             int foo(int x) { 1 }