You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by em...@apache.org on 2021/03/22 17:53:16 UTC

[groovy] 01/01: GROOVY-9995: infer ctor call diamond type from closure target type

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch GROOVY-9995
in repository https://gitbox.apache.org/repos/asf/groovy.git

commit 23ebc17764c6c94cce722eef90c8656f803afa75
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Mon Mar 22 12:27:23 2021 -0500

    GROOVY-9995: infer ctor call diamond type from closure target type
---
 .../transform/stc/StaticTypeCheckingVisitor.java   | 25 +++++++++++++++++-----
 .../groovy/transform/stc/GenericsSTCTest.groovy    | 20 +++++++++++++++++
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 591c2bc..15738af 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -748,8 +748,12 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 lType = getType(leftExpression);
             } else {
                 lType = getType(leftExpression);
-                if (op == ASSIGN && isFunctionalInterface(lType)) {
-                    processFunctionalInterfaceAssignment(lType, rightExpression);
+                if (op == ASSIGN) {
+                    if (isFunctionalInterface(lType)) {
+                        processFunctionalInterfaceAssignment(lType, rightExpression);
+                    } else if (isClosureWithType(lType) && rightExpression instanceof ClosureExpression) {
+                        storeInferredReturnType(rightExpression, getCombinedBoundType(lType.getGenericsTypes()[0]));
+                    }
                 }
                 rightExpression.visit(this);
             }
@@ -992,6 +996,10 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
         }
     }
 
+    private static boolean isClosureWithType(final ClassNode type) { // true only for CLOSURE_TYPE carrying exactly one generics type
+        return type.equals(CLOSURE_TYPE) && type.getGenericsTypes() != null && type.getGenericsTypes().length == 1;
+    }
+
     private boolean isCompoundAssignment(final Expression exp) {
         if (!(exp instanceof BinaryExpression)) return false;
         int type = ((BinaryExpression) exp).getOperation().getType();
@@ -1273,7 +1281,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
 
         ClassNode rTypeInferred, rTypeWrapped; // check for instanceof and spreading
         if (rightExpression instanceof VariableExpression && hasInferredReturnType(rightExpression) && assignmentExpression.getOperation().getType() == ASSIGN) {
-            rTypeInferred = rightExpression.getNodeMetaData(INFERRED_RETURN_TYPE);
+            rTypeInferred = getInferredReturnType(rightExpression);
         } else {
             rTypeInferred = rightExpressionType;
         }
@@ -1845,6 +1853,8 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
                 ClassNode lType = getType(node);
                 if (isFunctionalInterface(lType)) { // GROOVY-9977
                     processFunctionalInterfaceAssignment(lType, init);
+                } else if (isClosureWithType(lType) && init instanceof ClosureExpression) {
+                    storeInferredReturnType(init, getCombinedBoundType(lType.getGenericsTypes()[0]));
                 }
                 init.visit(this);
                 ClassNode rType = getType(init);
@@ -2145,11 +2155,16 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
         Expression expression = statement.getExpression();
         ClassNode type;
         if (expression instanceof VariableExpression && hasInferredReturnType(expression)) {
-            type = expression.getNodeMetaData(INFERRED_RETURN_TYPE);
+            type = getInferredReturnType(expression);
         } else {
             type = getType(expression);
         }
         if (typeCheckingContext.getEnclosingClosure() != null) {
+            // GROOVY-9995: return ctor call with diamond operator
+            if (expression instanceof ConstructorCallExpression) {
+                ClassNode inferredClosureReturnType = getInferredReturnType(typeCheckingContext.getEnclosingClosure().getClosureExpression());
+                if (inferredClosureReturnType != null) inferDiamondType((ConstructorCallExpression) expression, inferredClosureReturnType);
+            }
             return type;
         }
         MethodNode enclosingMethod = typeCheckingContext.getEnclosingMethod();
@@ -4157,7 +4172,7 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
      * @param type the inferred type of {@code expr}
      */
     private ClassNode checkForTargetType(final Expression expr, final ClassNode type) {
-        ClassNode sourceType = (hasInferredReturnType(expr) ? expr.getNodeMetaData(INFERRED_RETURN_TYPE) : type);
+        ClassNode sourceType = (hasInferredReturnType(expr) ? getInferredReturnType(expr) : type);
 
         ClassNode targetType = null;
         MethodNode enclosingMethod = typeCheckingContext.getEnclosingMethod();
diff --git a/src/test/groovy/transform/stc/GenericsSTCTest.groovy b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
index 087eaab..5fff955 100644
--- a/src/test/groovy/transform/stc/GenericsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
@@ -475,6 +475,26 @@ class GenericsSTCTest extends StaticTypeCheckingTestCase {
         '''
     }
 
+    // GROOVY-9995: diamond ctor call returned (implicitly or via 'return') from a closure assigned to Closure<A<Long>> or Callable<A<Long>>
+    void testDiamondInferrenceFromConstructor12() {
+        [
+            ['Closure<A<Long>>', 'java.util.concurrent.Callable<A<Long>>'],
+            ['new A<>(42L)', 'return new A<>(42L)']
+        ].combinations { functionalType, returnStmt ->
+            assertScript """
+                @groovy.transform.TupleConstructor(defaults=false)
+                class A<T extends Number> {
+                    T p
+                }
+                $functionalType callable = { ->
+                    $returnStmt
+                }
+                Long n = callable.call().p
+                assert n == 42L
+            """
+        }
+    }
+
     void testLinkedListWithListArgument() {
         assertScript '''
             List<String> list = new LinkedList<String>(['1','2','3'])