You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by em...@apache.org on 2021/05/11 21:25:17 UTC

[groovy] branch master updated: GROOVY-8409, GROOVY-9762, GROOVY-9803: STC: SAM parameter type inference

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new ae9ada9  GROOVY-8409, GROOVY-9762, GROOVY-9803: STC: SAM parameter type inference
ae9ada9 is described below

commit ae9ada909b7054af68b4ac1986ea7803a8365ab0
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Tue May 11 16:15:13 2021 -0500

    GROOVY-8409, GROOVY-9762, GROOVY-9803: STC: SAM parameter type inference
---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 96 +++++++++++++++++++++-
 .../groovy/transform/stc/GenericsSTCTest.groovy    | 44 +++++-----
 2 files changed, 118 insertions(+), 22 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 9c9e4f4..d4ce68f 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -5295,12 +5295,19 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                     ClassNode paramType = parameters[Math.min(i, nParams - 1)].getType();
                     ClassNode argumentType = getDeclaredOrInferredType(expressions.get(i));
 
-                    if (isUsingGenericsOrIsArrayUsingGenerics(paramType)) {
+                    if (GenericsUtils.hasUnresolvedGenerics(paramType)) {
                         // if supplying array param with multiple arguments or single non-array argument, infer using element type
                         if (isVargs && (i >= nParams || (i == nParams - 1 && (nArguments > nParams || !argumentType.isArray())))) {
                             paramType = paramType.getComponentType();
                         }
 
+                        if (isClosureWithType(argumentType)) {
+                            MethodNode sam = findSAM(paramType);
+                            if (sam != null) { // implicit closure coerce
+                                argumentType = convertClosureTypeToSAMType(expressions.get(i), argumentType, sam, paramType);
+                            }
+                        }
+
                         Map<GenericsTypeName, GenericsType> connections = new HashMap<>();
                         extractGenericsConnections(connections, wrapTypeIfNecessary(argumentType), paramType);
                         connections.forEach((gtn, gt) -> resolvedPlaceholders.merge(gtn, gt, (gt1, gt2) -> {
@@ -5431,6 +5438,93 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
         }
     }
 
+    /**
+     * Converts a Closure type to the appropriate SAM type for inferring return type generics.
+     *
+     * @param expression  the argument expression
+     * @param closureType the inferred type of {@code expression}
+     * @param sam         the single abstract method of {@code samType}
+     * @param samType     the target type for the argument expression
+     * @return SAM type augmented using information from the argument expression
+     */
+    private static ClassNode convertClosureTypeToSAMType(final Expression expression, final ClassNode closureType, final MethodNode sam, final ClassNode samType) {
+        Map<GenericsTypeName, GenericsType> samTypeConnections = GenericsUtils.extractPlaceholders(samType);
+        samTypeConnections.replaceAll((xx, gt) -> // GROOVY-9762, GROOVY-9803: reduce "? super T" to "T"
+            Optional.ofNullable(gt.getLowerBound()).map(GenericsType::new).orElse(gt)
+        );
+        ClassNode closureReturnType = closureType.getGenericsTypes()[0].getType(); // first type argument of the closure type
+
+        Parameter[] parameters = sam.getParameters();
+        if (parameters.length > 0
+                && expression instanceof MethodPointerExpression
+                && GenericsUtils.hasUnresolvedGenerics(closureReturnType)) {
+            // try to resolve referenced method type parameters in return type
+            MethodPointerExpression mp = (MethodPointerExpression) expression;
+            List<MethodNode> candidates = mp.getNodeMetaData(MethodNode.class); // methods recorded on the expression as node metadata
+            if (candidates != null && !candidates.isEmpty()) {
+                ClassNode[] paramTypes = applyGenericsContext(samTypeConnections, extractTypesFromParameters(parameters));
+                ClassNode[] matchTypes = candidates.stream()
+                        .map(candidate -> collateMethodReferenceParameterTypes(mp, candidate))
+                        .filter(candidate -> checkSignatureSuitability(candidate, paramTypes))
+                        .findFirst().orElse(null); // TODO: order signatures by param distance
+                if (matchTypes != null) {
+                    Map<GenericsTypeName, GenericsType> connections = new HashMap<>();
+                    for (int i = 0, n = parameters.length; i < n; i += 1) {
+                        // SAM parameters should align with the referenced method's parameters (equal lengths were checked by the filter above)
+                        extractGenericsConnections(connections, paramTypes[i], matchTypes[i]);
+                    }
+                    // convert the method reference's generics into the SAM's generics domain
+                    closureReturnType = applyGenericsContext(connections, closureReturnType);
+                    // apply known generics connections to the SAM's placeholders in the return type
+                    closureReturnType = applyGenericsContext(samTypeConnections, closureReturnType);
+                }
+            }
+        }
+
+        // the SAM's return type exactly corresponds to the inferred closure return type
+        extractGenericsConnections(samTypeConnections, closureReturnType, sam.getReturnType());
+
+        // repeat the same for each parameter given in the ClosureExpression
+        if (parameters.length > 0 && expression instanceof ClosureExpression) {
+            return closureType; // TODO: parameter handling is not implemented; the closure type is returned unconverted
+        }
+
+        return applyGenericsContext(samTypeConnections, samType.redirect());
+    }
+
+    private static ClassNode[] collateMethodReferenceParameterTypes(final MethodPointerExpression source, final MethodNode target) {
+        Parameter[] params;
+        // Collects the parameter types of "target" as they appear when it is referenced through "source".
+        if (target instanceof ExtensionMethodNode && !((ExtensionMethodNode) target).isStaticExtension()) {
+            params = ((ExtensionMethodNode) target).getExtensionMethodNode().getParameters(); // presumably the wrapped method declares the receiver as its first parameter -- TODO confirm
+        } else if (!target.isStatic() && source.getExpression() instanceof ClassExpression) {
+            ClassNode thisType = ((ClassExpression) source.getExpression()).getType();
+            // instance method referenced via its class, e.g. "String::length": the class type is prepended as an implicit first parameter
+            int n = target.getParameters().length;
+            params = new Parameter[n + 1];
+            params[0] = new Parameter(thisType, "");
+            System.arraycopy(target.getParameters(), 0, params, 1, n);
+        } else {
+            params = target.getParameters(); // all other cases: the declared parameters are used as-is
+        }
+
+        return extractTypesFromParameters(params);
+    }
+
+    private static boolean checkSignatureSuitability(final ClassNode[] receiverTypes, final ClassNode[] providerTypes) { // true if each provider type is assignable to the receiver type at the same index
+        int n = receiverTypes.length;
+        if (n != providerTypes.length) { // signatures of different length never match
+            return false;
+        }
+        for (int i = 0; i < n; i += 1) {
+            // for method closure, SAM parameters act like arguments: the provider type is tested against the receiver type
+            if (!isAssignableTo(providerTypes[i], receiverTypes[i])) {
+                return false;
+            }
+        }
+        return true; // same length and every position passed the assignability test
+    }
+
     private ClassNode getDeclaredOrInferredType(final Expression expression) {
         ClassNode declaredOrInferred;
         // in case of "T t = new ExtendsOrImplementsT()", return T for the expression type
diff --git a/src/test/groovy/transform/stc/GenericsSTCTest.groovy b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
index 3bc13a7..7cd445a 100644
--- a/src/test/groovy/transform/stc/GenericsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
@@ -1429,7 +1429,7 @@ class GenericsSTCTest extends StaticTypeCheckingTestCase {
         }
     }
 
-    @NotYetImplemented // GROOVY-9803
+    // GROOVY-9803
     void testShouldUseMethodGenericType8() {
         assertScript '''
             def opt = Optional.of(42)
@@ -1438,31 +1438,33 @@ class GenericsSTCTest extends StaticTypeCheckingTestCase {
             assert opt.get() == 42
         '''
         // same as above but with separate type parameter name for each location
-        assertScript '''
-            abstract class A<I,O> {
-                abstract O apply(I input)
-            }
-            class C<T> {
-                static <U> C<U> of(U item) {
-                    new C<U>()
+        ['D.&wrap', 'Collections.&singleton', '{x -> [x].toSet()}', '{Collections.singleton(it)}'].each { toSet ->
+            assertScript """
+                abstract class A<I,O> {
+                    abstract O apply(I input)
                 }
-                def <V> C<V> map(A<? super T, ? super V> func) {
-                    new C<V>()
+                class C<T> {
+                    static <U> C<U> of(U item) {
+                        new C<U>()
+                    }
+                    def <V> C<V> map(A<? super T, ? super V> func) {
+                        new C<V>()
+                    }
                 }
-            }
-            class D {
-                static <W> Set<W> wrap(W o) {
+                class D {
+                    static <W> Set<W> wrap(W o) {
+                    }
                 }
-            }
 
-            void test() {
-                def c = C.of(42)
-                def d = c.map(D.&wrap)
-                def e = d.map(x -> x.first().intValue())
-            }
+                void test() {
+                    def c = C.of(42)
+                    def d = c.map($toSet)
+                    def e = d.map(x -> x.first().intValue())
+                }
 
-            test()
-        '''
+                test()
+            """
+        }
     }
 
     // GROOVY-9945