You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by em...@apache.org on 2021/05/11 21:25:17 UTC
[groovy] branch master updated: GROOVY-8409, GROOVY-9762,
GROOVY-9803: STC: SAM parameter type inference
This is an automated email from the ASF dual-hosted git repository.
emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git
The following commit(s) were added to refs/heads/master by this push:
new ae9ada9 GROOVY-8409, GROOVY-9762, GROOVY-9803: STC: SAM parameter type inference
ae9ada9 is described below
commit ae9ada909b7054af68b4ac1986ea7803a8365ab0
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Tue May 11 16:15:13 2021 -0500
GROOVY-8409, GROOVY-9762, GROOVY-9803: STC: SAM parameter type inference
---
.../transform/stc/StaticTypeCheckingVisitor.java | 96 +++++++++++++++++++++-
.../groovy/transform/stc/GenericsSTCTest.groovy | 44 +++++-----
2 files changed, 118 insertions(+), 22 deletions(-)
diff --git a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
index 9c9e4f4..d4ce68f 100644
--- a/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
+++ b/src/main/java/org/codehaus/groovy/transform/stc/StaticTypeCheckingVisitor.java
@@ -5295,12 +5295,19 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
ClassNode paramType = parameters[Math.min(i, nParams - 1)].getType();
ClassNode argumentType = getDeclaredOrInferredType(expressions.get(i));
- if (isUsingGenericsOrIsArrayUsingGenerics(paramType)) {
+ if (GenericsUtils.hasUnresolvedGenerics(paramType)) {
// if supplying array param with multiple arguments or single non-array argument, infer using element type
if (isVargs && (i >= nParams || (i == nParams - 1 && (nArguments > nParams || !argumentType.isArray())))) {
paramType = paramType.getComponentType();
}
+ if (isClosureWithType(argumentType)) {
+ MethodNode sam = findSAM(paramType);
+ if (sam != null) { // implicit closure coerce
+ argumentType = convertClosureTypeToSAMType(expressions.get(i), argumentType, sam, paramType);
+ }
+ }
+
Map<GenericsTypeName, GenericsType> connections = new HashMap<>();
extractGenericsConnections(connections, wrapTypeIfNecessary(argumentType), paramType);
connections.forEach((gtn, gt) -> resolvedPlaceholders.merge(gtn, gt, (gt1, gt2) -> {
@@ -5431,6 +5438,93 @@ public class StaticTypeCheckingVisitor extends ClassCodeVisitorSupport {
}
}
+ /**
+ * Converts a Closure type to the appropriate SAM type, which can be used to
+ * infer return type generics.
+ * <p>
+ * The SAM type's placeholders are seeded from {@code samType}, optionally
+ * refined through a method pointer's candidate methods, and finally connected
+ * to the closure's return type via the SAM's declared return type.
+ *
+ * @param expression the argument expression
+ * @param closureType the inferred type of {@code expression}
+ * @param sam the abstract method located for {@code samType}; its parameters and return type drive the inference
+ * @param samType the target type for the argument expression
+ * @return SAM type augmented using information from the argument expression, or
+ * {@code closureType} unchanged when {@code expression} is a {@code ClosureExpression}
+ * and {@code sam} declares at least one parameter
+ */
+ private static ClassNode convertClosureTypeToSAMType(final Expression expression, final ClassNode closureType, final MethodNode sam, final ClassNode samType) {
+ Map<GenericsTypeName, GenericsType> samTypeConnections = GenericsUtils.extractPlaceholders(samType);
+ samTypeConnections.replaceAll((xx, gt) -> // GROOVY-9762, GROOVY-9803: reduce "? super T" to "T"
+ Optional.ofNullable(gt.getLowerBound()).map(GenericsType::new).orElse(gt)
+ );
+ // the first type argument of the closure type is taken as the closure's return type;
+ // NOTE(review): presumably the caller's isClosureWithType check guarantees it exists -- confirm
+ ClassNode closureReturnType = closureType.getGenericsTypes()[0].getType();
+
+ Parameter[] parameters = sam.getParameters();
+ if (parameters.length > 0
+ && expression instanceof MethodPointerExpression
+ && GenericsUtils.hasUnresolvedGenerics(closureReturnType)) {
+ // try to resolve referenced method type parameters in return type
+ MethodPointerExpression mp = (MethodPointerExpression) expression;
+ List<MethodNode> candidates = mp.getNodeMetaData(MethodNode.class);
+ if (candidates != null && !candidates.isEmpty()) {
+ ClassNode[] paramTypes = applyGenericsContext(samTypeConnections, extractTypesFromParameters(parameters));
+ // pick the first candidate whose collated parameter types accept the SAM's parameter types
+ ClassNode[] matchTypes = candidates.stream()
+ .map(candidate -> collateMethodReferenceParameterTypes(mp, candidate))
+ .filter(candidate -> checkSignatureSuitability(candidate, paramTypes))
+ .findFirst().orElse(null); // TODO: order signatures by param distance
+ if (matchTypes != null) {
+ Map<GenericsTypeName, GenericsType> connections = new HashMap<>();
+ for (int i = 0, n = parameters.length; i < n; i += 1) {
+ // SAM parameters should align with the referenced method's parameters
+ extractGenericsConnections(connections, paramTypes[i], matchTypes[i]);
+ }
+ // convert the method reference's generics into the SAM's generics domain
+ closureReturnType = applyGenericsContext(connections, closureReturnType);
+ // apply known generics connections to the SAM's placeholders in the return type
+ closureReturnType = applyGenericsContext(samTypeConnections, closureReturnType);
+ }
+ }
+ }
+
+ // the SAM's return type exactly corresponds to the inferred closure return type
+ extractGenericsConnections(samTypeConnections, closureReturnType, sam.getReturnType());
+
+ // repeat the same for each parameter given in the ClosureExpression
+ if (parameters.length > 0 && expression instanceof ClosureExpression) {
+ // parameter-based inference is not implemented yet: the closure type is handed back
+ // as-is, so the connections collected above are not used on this path
+ return closureType; // TODO
+ }
+
+ return applyGenericsContext(samTypeConnections, samType.redirect());
+ }
+
+ /**
+ * Produces the parameter types that a method pointer exposes for one of its
+ * candidate target methods.
+ * <ul>
+ * <li>non-static extension method: the parameters of the node returned by
+ * {@code getExtensionMethodNode()} are used</li>
+ * <li>instance method referenced through a class expression: the class type
+ * is prepended as an implicit first parameter</li>
+ * <li>otherwise: the target's own parameters are used</li>
+ * </ul>
+ *
+ * @param source the method pointer expression being analyzed
+ * @param target a candidate method for {@code source}
+ * @return the types extracted from the resulting parameter list
+ */
+ private static ClassNode[] collateMethodReferenceParameterTypes(final MethodPointerExpression source, final MethodNode target) {
+ Parameter[] params;
+
+ if (target instanceof ExtensionMethodNode && !((ExtensionMethodNode) target).isStaticExtension()) {
+ // NOTE(review): presumably this parameter list includes the leading "self" parameter -- confirm
+ params = ((ExtensionMethodNode) target).getExtensionMethodNode().getParameters();
+ } else if (!target.isStatic() && source.getExpression() instanceof ClassExpression) {
+ ClassNode thisType = ((ClassExpression) source.getExpression()).getType();
+ // there is an implicit parameter for "String::length"
+ int n = target.getParameters().length;
+ params = new Parameter[n + 1];
+ // slot 0 holds the receiver type; the declared parameters are shifted right by one
+ params[0] = new Parameter(thisType, "");
+ System.arraycopy(target.getParameters(), 0, params, 1, n);
+ } else {
+ params = target.getParameters();
+ }
+
+ return extractTypesFromParameters(params);
+ }
+
+ /**
+ * Tests whether a list of provided types can be supplied to a list of
+ * receiving types: both arrays must have the same length and each provided
+ * type must pass {@code isAssignableTo} against the receiving type at the
+ * same index.
+ *
+ * @param receiverTypes the types being supplied to (the visible caller passes a candidate method's collated parameter types)
+ * @param providerTypes the types being supplied (the visible caller passes the SAM's parameter types)
+ * @return {@code true} if the lengths match and every position is assignable; {@code false} otherwise
+ */
+ private static boolean checkSignatureSuitability(final ClassNode[] receiverTypes, final ClassNode[] providerTypes) {
+ int n = receiverTypes.length;
+ if (n != providerTypes.length) {
+ return false;
+ }
+ for (int i = 0; i < n; i += 1) {
+ // for method closure, SAM parameters act like arguments
+ if (!isAssignableTo(providerTypes[i], receiverTypes[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
private ClassNode getDeclaredOrInferredType(final Expression expression) {
ClassNode declaredOrInferred;
// in case of "T t = new ExtendsOrImplementsT()", return T for the expression type
diff --git a/src/test/groovy/transform/stc/GenericsSTCTest.groovy b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
index 3bc13a7..7cd445a 100644
--- a/src/test/groovy/transform/stc/GenericsSTCTest.groovy
+++ b/src/test/groovy/transform/stc/GenericsSTCTest.groovy
@@ -1429,7 +1429,7 @@ class GenericsSTCTest extends StaticTypeCheckingTestCase {
}
}
- @NotYetImplemented // GROOVY-9803
+ // GROOVY-9803
void testShouldUseMethodGenericType8() {
assertScript '''
def opt = Optional.of(42)
@@ -1438,31 +1438,33 @@ class GenericsSTCTest extends StaticTypeCheckingTestCase {
assert opt.get() == 42
'''
// same as above but with separate type parameter name for each location
- assertScript '''
- abstract class A<I,O> {
- abstract O apply(I input)
- }
- class C<T> {
- static <U> C<U> of(U item) {
- new C<U>()
+ ['D.&wrap', 'Collections.&singleton', '{x -> [x].toSet()}', '{Collections.singleton(it)}'].each { toSet ->
+ assertScript """
+ abstract class A<I,O> {
+ abstract O apply(I input)
}
- def <V> C<V> map(A<? super T, ? super V> func) {
- new C<V>()
+ class C<T> {
+ static <U> C<U> of(U item) {
+ new C<U>()
+ }
+ def <V> C<V> map(A<? super T, ? super V> func) {
+ new C<V>()
+ }
}
- }
- class D {
- static <W> Set<W> wrap(W o) {
+ class D {
+ static <W> Set<W> wrap(W o) {
+ }
}
- }
- void test() {
- def c = C.of(42)
- def d = c.map(D.&wrap)
- def e = d.map(x -> x.first().intValue())
- }
+ void test() {
+ def c = C.of(42)
+ def d = c.map($toSet)
+ def e = d.map(x -> x.first().intValue())
+ }
- test()
- '''
+ test()
+ """
+ }
}
// GROOVY-9945