You are viewing a plain-text version of this message; the hyperlink to its canonical version was not preserved in this rendering.
Posted to commits@groovy.apache.org by pa...@apache.org on 2019/04/08 12:58:57 UTC

[groovy] 05/20: Reuse the type inference of lambda expression

This is an automated email from the ASF dual-hosted git repository.

paulk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit fdce3791fea7777a8854e431461709706be01326
Author: Daniel Sun <su...@apache.org>
AuthorDate: Mon Mar 4 23:26:10 2019 +0800

    Reuse the type inference of lambda expression
---
 .../groovy/classgen/asm/BytecodeHelper.java        | 13 ++++-
 .../asm/sc/AbstractFunctionInterfaceWriter.java    |  3 --
 ...StaticTypesMethodReferenceExpressionWriter.java | 16 +++---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 62 +++++++++++++++-------
 .../stc/MethodReferenceTest.groovy}                |  9 +---
 5 files changed, 64 insertions(+), 39 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/BytecodeHelper.java b/src/main/java/org/codehaus/groovy/classgen/asm/BytecodeHelper.java
index b010726..fb35cd9 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/BytecodeHelper.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/BytecodeHelper.java
@@ -71,10 +71,19 @@ public class BytecodeHelper implements Opcodes {
     }
 
     public static String getMethodDescriptor(ClassNode returnType, Parameter[] parameters) {
+        ClassNode[] parameterTypes = new ClassNode[parameters.length];
+        for (int i = 0; i < parameters.length; i++) {
+            parameterTypes[i] = parameters[i].getType();
+        }
+
+        return getMethodDescriptor(returnType, parameterTypes);
+    }
+
+    public static String getMethodDescriptor(ClassNode returnType, ClassNode[] parameterTypes) {
         StringBuilder buffer = new StringBuilder(100);
         buffer.append("(");
-        for (Parameter parameter : parameters) {
-            buffer.append(getTypeDescription(parameter.getType()));
+        for (ClassNode parameterType : parameterTypes) {
+            buffer.append(getTypeDescription(parameterType));
         }
         buffer.append(")");
         buffer.append(getTypeDescription(returnType));
diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionInterfaceWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionInterfaceWriter.java
index f543be8..e74eea3 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionInterfaceWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/AbstractFunctionInterfaceWriter.java
@@ -68,9 +68,6 @@ public interface AbstractFunctionInterfaceWriter {
 
     default Object[] createBootstrapMethodArguments(String abstractMethodDesc, ClassNode methodOwnerClassNode, MethodNode methodNode) {
         Parameter[] parameters = methodNode.getNodeMetaData(ORIGINAL_PARAMETERS_WITH_EXACT_TYPE);
-        if (null == parameters) {
-            parameters = methodNode.getParameters();
-        }
 
         return new Object[]{
                 Type.getType(abstractMethodDesc),
diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java
index 79c02ae..20a3d63 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesMethodReferenceExpressionWriter.java
@@ -19,27 +19,23 @@
 package org.codehaus.groovy.classgen.asm.sc;
 
 import groovy.lang.GroovyRuntimeException;
-import groovy.lang.Tuple2;
 import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
 import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.Parameter;
 import org.codehaus.groovy.ast.expr.MethodReferenceExpression;
-import org.codehaus.groovy.ast.tools.GenericsUtils;
 import org.codehaus.groovy.ast.tools.ParameterUtils;
 import org.codehaus.groovy.classgen.asm.BytecodeHelper;
 import org.codehaus.groovy.classgen.asm.MethodReferenceExpressionWriter;
 import org.codehaus.groovy.classgen.asm.WriterController;
-import org.objectweb.asm.Handle;
 import org.objectweb.asm.MethodVisitor;
-import org.objectweb.asm.Opcodes;
-import org.objectweb.asm.Type;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
 import static org.codehaus.groovy.ast.ClassHelper.getWrapper;
+import static org.codehaus.groovy.transform.stc.StaticTypesMarker.CLOSURE_ARGUMENTS;
 
 /**
  * Writer responsible for generating method reference in statically compiled mode.
@@ -66,12 +62,16 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE
         String mrMethodName = methodReferenceExpression.getMethodName().getText();
 
 
-        MethodNode mrMethodNode = findMrMethodNode(mrMethodName, createParametersWithExactType(abstractMethodNode, functionalInterfaceType), mrExpressionType);
+        ClassNode[] methodReferenceParamTypes = methodReferenceExpression.getNodeMetaData(CLOSURE_ARGUMENTS);
+        Parameter[] parametersWithExactType = createParametersWithExactType(abstractMethodNode, methodReferenceParamTypes);
+        MethodNode mrMethodNode = findMrMethodNode(mrMethodName, parametersWithExactType, mrExpressionType);
 
         if (null == mrMethodNode) {
             throw new GroovyRuntimeException("Failed to find the expected method[" + mrMethodName + "] in type[" + mrExpressionType.getName() + "]");
         }
 
+        mrMethodNode.putNodeMetaData(ORIGINAL_PARAMETERS_WITH_EXACT_TYPE, parametersWithExactType);
+
         MethodVisitor mv = controller.getMethodVisitor();
         mv.visitInvokeDynamicInsn(
                 abstractMethodNode.getName(),
@@ -82,9 +82,7 @@ public class StaticTypesMethodReferenceExpressionWriter extends MethodReferenceE
         controller.getOperandStack().push(redirect);
     }
 
-    private Parameter[] createParametersWithExactType(MethodNode abstractMethodNode, ClassNode functionInterfaceType) {
-        Tuple2<ClassNode[], ClassNode> abstractMethodNodeTypeInfo = GenericsUtils.parameterizeSAM(functionInterfaceType);
-        ClassNode[] inferredParameterTypes = abstractMethodNodeTypeInfo.getV1();
+    private Parameter[] createParametersWithExactType(MethodNode abstractMethodNode, ClassNode[] inferredParameterTypes) {
         Parameter[] parameters = abstractMethodNode.getParameters();
         if (parameters == null) {
             parameters = Parameter.EMPTY_ARRAY;
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 04ce335..adf1466 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -62,6 +62,7 @@ import org.codehaus.groovy.ast.expr.ElvisOperatorExpression;
 import org.codehaus.groovy.ast.expr.EmptyExpression;
 import org.codehaus.groovy.ast.expr.Expression;
 import org.codehaus.groovy.ast.expr.FieldExpression;
+import org.codehaus.groovy.ast.expr.LambdaExpression;
 import org.codehaus.groovy.ast.expr.ListExpression;
 import org.codehaus.groovy.ast.expr.MapEntryExpression;
 import org.codehaus.groovy.ast.expr.MapExpression;
@@ -177,8 +178,10 @@ import static org.codehaus.groovy.ast.ClassHelper.void_WRAPPER_TYPE;
 import static org.codehaus.groovy.ast.GenericsType.GenericsTypeName;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.binX;
+import static org.codehaus.groovy.ast.tools.GeneralUtils.block;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.castX;
+import static org.codehaus.groovy.ast.tools.GeneralUtils.cloneParams;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.varX;
 import static org.codehaus.groovy.ast.tools.GenericsUtils.findActualTypeByGenericsPlaceholderName;
 import static org.codehaus.groovy.ast.tools.GenericsUtils.makeDeclaringAndActualGenericsTypeMap;
@@ -264,6 +267,7 @@ import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.resolv
 import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.toMethodParametersString;
 import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.typeCheckMethodArgumentWithGenerics;
 import static org.codehaus.groovy.transform.stc.StaticTypeCheckingSupport.typeCheckMethodsWithGenerics;
+import static org.codehaus.groovy.transform.stc.StaticTypesMarker.CLOSURE_ARGUMENTS;
 import static org.codehaus.groovy.transform.stc.StaticTypesMarker.INFERRED_TYPE;
 //import static org.codehaus.groovy.syntax.Types.COMPARE_NOT_INSTANCEOF;
 
@@ -950,15 +954,15 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 if (leftExpression instanceof VariableExpression) {
                     if (rightExpression instanceof ClosureExpression) {
                         Parameter[] parameters = ((ClosureExpression) rightExpression).getParameters();
-                        leftExpression.putNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS, parameters);
+                        leftExpression.putNodeMetaData(CLOSURE_ARGUMENTS, parameters);
                     } else if (rightExpression instanceof VariableExpression &&
                             ((VariableExpression) rightExpression).getAccessedVariable() instanceof Expression &&
-                            ((Expression) ((VariableExpression) rightExpression).getAccessedVariable()).getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS) != null) {
+                            ((Expression) ((VariableExpression) rightExpression).getAccessedVariable()).getNodeMetaData(CLOSURE_ARGUMENTS) != null) {
                         Variable targetVariable = findTargetVariable((VariableExpression) leftExpression);
                         if (targetVariable instanceof ASTNode) {
                             ((ASTNode) targetVariable).putNodeMetaData(
-                                    StaticTypesMarker.CLOSURE_ARGUMENTS,
-                                    ((Expression) ((VariableExpression) rightExpression).getAccessedVariable()).getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS));
+                                    CLOSURE_ARGUMENTS,
+                                    ((Expression) ((VariableExpression) rightExpression).getAccessedVariable()).getNodeMetaData(CLOSURE_ARGUMENTS));
                         }
                     }
                 }
@@ -2882,7 +2886,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
         // to the corresponding parameters of the SAM type method
         MethodNode methodForSAM = findSAM(classForSAM);
         ClassNode[] parameterTypesForSAM = extractTypesFromParameters(methodForSAM.getParameters());
-        ClassNode[] blockParameterTypes = (ClassNode[]) openBlock.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS);
+        ClassNode[] blockParameterTypes = (ClassNode[]) openBlock.getNodeMetaData(CLOSURE_ARGUMENTS);
         if (blockParameterTypes == null) {
             Parameter[] p = openBlock.getParameters();
             if (p == null) {
@@ -2916,7 +2920,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
 
         tryToInferUnresolvedBlockParameterType(paramTypeWithReceiverInformation, methodForSAM, blockParameterTypes);
 
-        openBlock.putNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS, blockParameterTypes);
+        openBlock.putNodeMetaData(CLOSURE_ARGUMENTS, blockParameterTypes);
     }
 
     private void tryToInferUnresolvedBlockParameterType(ClassNode paramTypeWithReceiverInformation, MethodNode methodForSAM, ClassNode[] blockParameterTypes) {
@@ -3051,7 +3055,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
         if (candidates.size() == 1) {
             ClassNode[] inferred = candidates.get(0);
             if (closureParams.length == 0 && inferred.length == 1) {
-                expression.putNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS, inferred);
+                expression.putNodeMetaData(CLOSURE_ARGUMENTS, inferred);
             } else {
                 final int length = closureParams.length;
                 for (int i = 0; i < length; i++) {
@@ -3380,7 +3384,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                     GenericsType[] genericsTypes = field.getType().getGenericsTypes();
                     if (genericsTypes != null) {
                         ClassNode closureReturnType = genericsTypes[0].getType();
-                        Object data = field.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS);
+                        Object data = field.getNodeMetaData(CLOSURE_ARGUMENTS);
                         if (data != null) {
                             Parameter[] parameters = (Parameter[]) data;
                             typeCheckClosureCall(callArguments, args, parameters);
@@ -3390,7 +3394,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 } else if (objectExpression instanceof VariableExpression) {
                     Variable variable = findTargetVariable((VariableExpression) objectExpression);
                     if (variable instanceof ASTNode) {
-                        Object data = ((ASTNode) variable).getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS);
+                        Object data = ((ASTNode) variable).getNodeMetaData(CLOSURE_ARGUMENTS);
                         if (data != null) {
                             Parameter[] parameters = (Parameter[]) data;
                             typeCheckClosureCall(callArguments, args, parameters);
@@ -3616,24 +3620,46 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
 
     private void inferMethodReferenceType(MethodCallExpression call, ClassNode receiver, ArgumentListExpression argumentList) {
         MethodNode selectedMethod = call.getNodeMetaData(StaticTypesMarker.DIRECT_METHOD_CALL_TARGET);
+        Parameter[] parameters = selectedMethod.getParameters();
         List<Expression> argumentExpressionList = argumentList.getExpressions();
 
-        int methodReferenceParamCnt = 0;
+        List<Integer> methodReferenceParamIndexList = new LinkedList<>();
+        List<Expression> newArgumentExpressionList = new LinkedList<>();
         for (int i = 0, n = argumentExpressionList.size(); i < n; i++) {
             Expression argumentExpression = argumentExpressionList.get(i);
             if (!(argumentExpression instanceof MethodReferenceExpression)) {
+                newArgumentExpressionList.add(argumentExpression);
                 continue;
             }
 
-            // TODO transform method reference to lambda expression
-            methodReferenceParamCnt++;
+            Parameter param = parameters[i];
+            ClassNode paramType = param.getType();
+            MethodNode abstractMethodNode = ClassHelper.findSAM(paramType);
+
+            Parameter[] abstractMethodNodeParameters = abstractMethodNode.getParameters();
+            if (null == abstractMethodNodeParameters) {
+                abstractMethodNodeParameters = Parameter.EMPTY_ARRAY;
+            }
+
+            LambdaExpression lambdaExpression =
+                    new LambdaExpression(
+                            cloneParams(abstractMethodNodeParameters),
+                            block()
+                    );
+
+            newArgumentExpressionList.add(lambdaExpression);
+            methodReferenceParamIndexList.add(i);
         }
 
-        if (0 == methodReferenceParamCnt) return;
+        if (methodReferenceParamIndexList.isEmpty()) return;
 
-        visitMethodCallArguments(receiver, argumentList, true, selectedMethod);
+        visitMethodCallArguments(receiver, new ArgumentListExpression(newArgumentExpressionList), true, selectedMethod);
 
-        // TODO get the inferred types and store them in the node metadata
+        for (Integer methodReferenceParamIndex : methodReferenceParamIndexList) {
+            LambdaExpression lambdaExpression = (LambdaExpression) newArgumentExpressionList.get(methodReferenceParamIndex);
+            ClassNode[] argumentTypes = lambdaExpression.getNodeMetaData(CLOSURE_ARGUMENTS);
+            argumentExpressionList.get(methodReferenceParamIndex).putNodeMetaData(CLOSURE_ARGUMENTS, argumentTypes);
+        }
     }
 
     // adjust data to handle cases like nested .with since we didn't have enough information earlier
@@ -5017,7 +5043,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
 
     private ClassNode getTypeFromClosureArguments(Parameter parameter, TypeCheckingContext.EnclosingClosure enclosingClosure) {
         ClosureExpression closureExpression = enclosingClosure.getClosureExpression();
-        ClassNode[] closureParamTypes = (ClassNode[]) closureExpression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS);
+        ClassNode[] closureParamTypes = (ClassNode[]) closureExpression.getNodeMetaData(CLOSURE_ARGUMENTS);
         if (closureParamTypes == null) return null;
         final Parameter[] parameters = closureExpression.getParameters();
         String name = parameter.getName();
@@ -5369,8 +5395,8 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 List<ClassNode[]> genericsToConnect = new LinkedList<ClassNode[]>();
                 Parameter[] closureParams = ((ClosureExpression) expression).getParameters();
                 ClassNode[] closureParamTypes = extractTypesFromParameters(closureParams);
-                if (expression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS) != null) {
-                    closureParamTypes = expression.getNodeMetaData(StaticTypesMarker.CLOSURE_ARGUMENTS);
+                if (expression.getNodeMetaData(CLOSURE_ARGUMENTS) != null) {
+                    closureParamTypes = expression.getNodeMetaData(CLOSURE_ARGUMENTS);
                 }
                 final Parameter[] parameters = sam.getParameters();
                 for (int i = 0; i < parameters.length; i++) {
diff --git a/src/test/groovy/bugs/Groovy9008.groovy b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
similarity index 90%
rename from src/test/groovy/bugs/Groovy9008.groovy
rename to src/test/groovy/transform/stc/MethodReferenceTest.groovy
index e774c13..29856d2 100644
--- a/src/test/groovy/bugs/Groovy9008.groovy
+++ b/src/test/groovy/transform/stc/MethodReferenceTest.groovy
@@ -16,13 +16,10 @@
  *  specific language governing permissions and limitations
  *  under the License.
  */
-package groovy.bugs
-
-class Groovy9008 extends GroovyTestCase {
-    private static final boolean SKIP = true // TODO remove it
+package groovy.transform.stc
 
+class MethodReferenceTest extends GroovyTestCase {
     void testMethodReferenceFunction() {
-        if (SKIP) return
 
         assertScript '''
             import java.util.stream.Collectors
@@ -39,8 +36,6 @@ class Groovy9008 extends GroovyTestCase {
     }
 
     void testMethodReferenceBinaryOperator() {
-        if (SKIP) return
-
         assertScript '''
             import java.util.stream.Stream