You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@groovy.apache.org by em...@apache.org on 2023/01/03 20:06:54 UTC

[groovy] branch master updated: GROOVY-10889: `@NamedVariant`: check casts for parameter reference

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 1516992472 GROOVY-10889: `@NamedVariant`: check casts for parameter reference
1516992472 is described below

commit 1516992472d5304449d2d6659af525305f5c6dc8
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Tue Jan 3 13:52:25 2023 -0600

    GROOVY-10889: `@NamedVariant`: check casts for parameter reference
---
 .../transform/NamedVariantASTTransformation.java   |  78 +++++++-------
 .../transform/NamedVariantTransformTest.groovy     | 118 ++++++++++-----------
 2 files changed, 94 insertions(+), 102 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/transform/NamedVariantASTTransformation.java b/src/main/java/org/codehaus/groovy/transform/NamedVariantASTTransformation.java
index 0af31ef7ae..6a15eb0035 100644
--- a/src/main/java/org/codehaus/groovy/transform/NamedVariantASTTransformation.java
+++ b/src/main/java/org/codehaus/groovy/transform/NamedVariantASTTransformation.java
@@ -29,10 +29,8 @@ import org.codehaus.groovy.ast.ConstructorNode;
 import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.Parameter;
 import org.codehaus.groovy.ast.PropertyNode;
-import org.codehaus.groovy.ast.Variable;
 import org.codehaus.groovy.ast.expr.ArgumentListExpression;
 import org.codehaus.groovy.ast.expr.CastExpression;
-import org.codehaus.groovy.ast.expr.ConstantExpression;
 import org.codehaus.groovy.ast.expr.Expression;
 import org.codehaus.groovy.ast.expr.MethodCallExpression;
 import org.codehaus.groovy.ast.expr.VariableExpression;
@@ -58,10 +56,8 @@ import static org.codehaus.groovy.ast.ClassHelper.MAP_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.STRING_TYPE;
 import static org.codehaus.groovy.ast.ClassHelper.isPrimitiveType;
 import static org.codehaus.groovy.ast.ClassHelper.make;
-import static org.codehaus.groovy.ast.ClassHelper.makeWithoutCaching;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.asX;
-import static org.codehaus.groovy.ast.tools.GeneralUtils.block;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.boolX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.callThisX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.callX;
@@ -86,11 +82,12 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.varX;
 @GroovyASTTransformation(phase = CompilePhase.SEMANTIC_ANALYSIS)
 public class NamedVariantASTTransformation extends AbstractASTTransformation {
 
+    private static final ClassNode NAMED_PARAM_TYPE = make(NamedParam.class);
     private static final ClassNode NAMED_VARIANT_TYPE = make(NamedVariant.class);
+    private static final ClassNode NAMED_DELEGATE_TYPE = make(NamedDelegate.class);
+    private static final ClassNode ILLEGAL_ARGUMENT_TYPE = make(IllegalArgumentException.class);
+
     private static final String NAMED_VARIANT = "@" + NAMED_VARIANT_TYPE.getNameWithoutPackage();
-    private static final ClassNode NAMED_PARAM_TYPE = makeWithoutCaching(NamedParam.class, false);
-    private static final ClassNode NAMED_DELEGATE_TYPE = makeWithoutCaching(NamedDelegate.class, false);
-    private static final ClassNode ILLEGAL_ARGUMENT = makeWithoutCaching(IllegalArgumentException.class);
 
     @Override
     public void visit(final ASTNode[] nodes, final SourceUnit source) {
@@ -99,14 +96,14 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
         AnnotationNode anno = (AnnotationNode) nodes[0];
         if (!NAMED_VARIANT_TYPE.equals(anno.getClassNode())) return;
 
-        Parameter[] fromParams = mNode.getParameters();
-        if (fromParams.length == 0) {
+        Parameter[] mNodeParams = mNode.getParameters();
+        if (mNodeParams.length == 0) {
             addError("Error during " + NAMED_VARIANT + " processing. No-args method not supported.", mNode);
             return;
         }
 
-        boolean autoDelegate = memberHasValue(anno, "autoDelegate", true);
-        boolean coerce = memberHasValue(anno, "coerce", true);
+        boolean autoDelegate = memberHasValue(anno, "autoDelegate", Boolean.TRUE);
+        boolean coerce = memberHasValue(anno, "coerce", Boolean.TRUE);
         Parameter mapParam = param(GenericsUtils.nonGeneric(MAP_TYPE), "namedArgs");
         List<Parameter> genParams = new ArrayList<>();
         genParams.add(mapParam);
@@ -117,30 +114,30 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
 
         // first pass, just check for annotations of interest
         boolean annoFound = false;
-        for (Parameter fromParam : fromParams) {
-            if (AnnotatedNodeUtils.hasAnnotation(fromParam, NAMED_PARAM_TYPE) || AnnotatedNodeUtils.hasAnnotation(fromParam, NAMED_DELEGATE_TYPE)) {
+        for (Parameter mNodeParam : mNodeParams) {
+            if (AnnotatedNodeUtils.hasAnnotation(mNodeParam, NAMED_PARAM_TYPE) || AnnotatedNodeUtils.hasAnnotation(mNodeParam, NAMED_DELEGATE_TYPE)) {
                 annoFound = true;
                 break;
             }
         }
 
         if (!annoFound && autoDelegate) { // the first param is the delegate
-            processDelegateParam(mNode, mapParam, args, propNames, fromParams[0], coerce);
+            processDelegateParam(mNode, mapParam, args, propNames, mNodeParams[0], coerce);
         } else {
             Map<Parameter, Expression> seen = new HashMap<>();
-            for (Parameter fromParam : fromParams) {
+            for (Parameter mNodeParam : mNodeParams) {
                 if (!annoFound) {
-                    if (!processImplicitNamedParam(this, mNode, mapParam, inner, args, propNames, fromParam, coerce, seen)) return;
-                } else if (AnnotatedNodeUtils.hasAnnotation(fromParam, NAMED_PARAM_TYPE)) {
-                    if (!processExplicitNamedParam(mNode, mapParam, inner, args, propNames, fromParam, coerce, seen)) return;
-                } else if (AnnotatedNodeUtils.hasAnnotation(fromParam, NAMED_DELEGATE_TYPE)) {
-                    if (!processDelegateParam(mNode, mapParam, args, propNames, fromParam, coerce)) return;
+                    if (!processImplicitNamedParam(this, mNode, mapParam, inner, args, propNames, mNodeParam, coerce, seen)) return;
+                } else if (AnnotatedNodeUtils.hasAnnotation(mNodeParam, NAMED_PARAM_TYPE)) {
+                    if (!processExplicitNamedParam(mNode, mapParam, inner, args, propNames, mNodeParam, coerce, seen)) return;
+                } else if (AnnotatedNodeUtils.hasAnnotation(mNodeParam, NAMED_DELEGATE_TYPE)) {
+                    if (!processDelegateParam(mNode, mapParam, args, propNames, mNodeParam, coerce)) return;
                 } else {
-                    Expression arg = varX(fromParam);
-                    Expression argOrDefault = fromParam.hasInitialExpression() ? elvisX(arg, fromParam.getDefaultValue()) : arg;
-                    args.addExpression(asType(argOrDefault, fromParam.getType(), coerce));
-                    if (hasDuplicates(this, mNode, propNames, fromParam.getName())) return;
-                    genParams.add(fromParam);
+                    Expression arg = varX(mNodeParam);
+                    Expression argOrDefault = mNodeParam.hasInitialExpression() ? elvisX(arg, mNodeParam.getDefaultValue()) : arg;
+                    args.addExpression(asType(argOrDefault, mNodeParam.getType(), coerce));
+                    if (hasDuplicates(this, mNode, propNames, mNodeParam.getName())) return;
+                    genParams.add(mNodeParam);
                 }
             }
         }
@@ -168,7 +165,7 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
             inner.addStatement(new AssertStatement(boolX(containsKey(mapParam, name)),
                     plusX(constX("Missing required named argument '" + name + "'. Keys found: "), callX(varX(mapParam), "keySet"))));
         }
-        Expression defValue = earlierParamIfSeen(seen, fromParam.getInitialExpression());
+        Expression defValue = getDefaultValue(fromParam.getInitialExpression(), seen);
         Expression initExpr = namedParamValue(mapParam, name, type, coerce, defValue);
         if (seen != null) {
             seen.put(fromParam, initExpr);
@@ -177,16 +174,6 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
         return true;
     }
 
-    private static Expression earlierParamIfSeen(Map<Parameter, Expression> seen, Expression defValue) {
-        if (seen == null) return defValue;
-        // handle earlier param with or without cast
-        if (defValue instanceof CastExpression) {
-            defValue = ((CastExpression) defValue).getExpression();
-        }
-        return defValue instanceof VariableExpression ?
-            seen.getOrDefault(((VariableExpression) defValue).getAccessedVariable(), defValue) : defValue;
-    }
-
     private boolean processExplicitNamedParam(final MethodNode mNode, final Parameter mapParam, final BlockStatement inner, final ArgumentListExpression args, final List<String> propNames, final Parameter fromParam, final boolean coerce, Map<Parameter, Expression> seen) {
         AnnotationNode namedParam = fromParam.getAnnotations(NAMED_PARAM_TYPE).get(0);
 
@@ -205,7 +192,7 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
             // TODO: Check attribute type is assignable to declared param type?
         }
 
-        boolean required = memberHasValue(namedParam, "required", true);
+        boolean required = memberHasValue(namedParam, "required", Boolean.TRUE);
         if (required) {
             if (fromParam.hasInitialExpression()) {
                 addError("Error during " + NAMED_VARIANT + " processing. A required parameter can't have an initial value.", fromParam);
@@ -214,7 +201,7 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
             inner.addStatement(new AssertStatement(boolX(containsKey(mapParam, name)),
                     plusX(constX("Missing required named argument '" + name + "'. Keys found: "), callX(varX(mapParam), "keySet"))));
         }
-        Expression defValue = earlierParamIfSeen(seen, fromParam.getInitialExpression());
+        Expression defValue = getDefaultValue(fromParam.getInitialExpression(), seen);
         Expression initExpr = namedParamValue(mapParam, name, type, coerce, defValue);
         seen.put(fromParam, initExpr);
         args.addExpression(initExpr);
@@ -261,7 +248,7 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
         Parameter namedArgKey = param(STRING_TYPE, "namedArgKey");
         if (!(mNode instanceof ConstructorNode)) {
             inner.getStatements().add(0, ifS(isNullX(varX(mapParam)),
-                    throwS(ctorX(ILLEGAL_ARGUMENT, args(constX("Named parameter map cannot be null"))))));
+                    throwS(ctorX(ILLEGAL_ARGUMENT_TYPE, args(constX("Named parameter map cannot be null"))))));
         }
         inner.addStatement(
                 new ForStatement(
@@ -304,6 +291,19 @@ public class NamedVariantASTTransformation extends AbstractASTTransformation {
         }
     }
 
+    private static Expression getDefaultValue(final Expression defaultValue, final Map<Parameter, Expression> seen) {
+        if (defaultValue != null && seen != null) { // GROOVY-10561, GROOVY-10889
+            Expression v = defaultValue;
+            while (v instanceof CastExpression) {
+                v = ((CastExpression) v).getExpression();
+            }
+            if (v instanceof VariableExpression) { // maybe it's a reference to a previous parameter
+                return seen.getOrDefault(((VariableExpression) v).getAccessedVariable(), defaultValue);
+            }
+        }
+        return defaultValue;
+    }
+
     private static Expression namedParamValue(final Parameter mapParam, final String name, final ClassNode type, final boolean coerce, Expression defaultValue) {
         Expression value = propX(varX(mapParam), name); // TODO: "map.get(name)"
         if (defaultValue == null && isPrimitiveType(type)) {
diff --git a/src/test/org/codehaus/groovy/transform/NamedVariantTransformTest.groovy b/src/test/org/codehaus/groovy/transform/NamedVariantTransformTest.groovy
index 88814c19aa..ecd1319000 100644
--- a/src/test/org/codehaus/groovy/transform/NamedVariantTransformTest.groovy
+++ b/src/test/org/codehaus/groovy/transform/NamedVariantTransformTest.groovy
@@ -18,8 +18,6 @@
  */
 package org.codehaus.groovy.transform
 
-import groovy.transform.CompileStatic
-import groovy.transform.NamedVariant
 import org.junit.Test
 
 import static groovy.test.GroovyAssert.assertScript
@@ -28,15 +26,15 @@ import static groovy.test.GroovyAssert.shouldFail
 /**
  * Tests for the {@code @NamedVariant} transformation.
  */
-@CompileStatic
 final class NamedVariantTransformTest {
 
+    private final GroovyShell shell = GroovyShell.withConfig {
+        imports { star 'groovy.transform', 'org.codehaus.groovy.ast' }
+    }
+
     @Test
     void testMethod() {
-        assertScript '''
-            import groovy.transform.*
-            import org.codehaus.groovy.ast.*
-
+        assertScript shell, '''
             @ASTTest(phase=CANONICALIZATION, value={
                 def method = node.getMethod('m', new Parameter(ClassHelper.MAP_TYPE, 'map'))
                 use(org.apache.groovy.ast.tools.AnnotatedNodeUtils) {
@@ -57,9 +55,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testNamedParam() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             class Animal {
                 String type
                 String name
@@ -84,9 +80,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testNamedParamWithRename() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             @ToString(includeNames=true)
             class Color {
                 Integer r, g, b
@@ -103,9 +97,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testNamedParamConstructor() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             @ToString(includeNames=true, includeFields=true)
             class Color {
                 @NamedVariant
@@ -123,8 +115,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testConstructorVisibility() {
-        assertScript '''
-            import groovy.transform.*
+        assertScript shell, '''
             import static groovy.transform.options.Visibility.*
 
             class Color {
@@ -146,9 +137,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testNamedParamInnerClass() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             class Foo {
                 int adjust
                 @ToString(includeNames = true)
@@ -181,9 +170,7 @@ final class NamedVariantTransformTest {
 
     @Test
     void testGeneratedMethodsSkipped() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             class Storm { String front }
             class Switch { String back }
 
@@ -195,8 +182,7 @@ final class NamedVariantTransformTest {
 
     @Test // GROOVY-9158, GROOVY-10497
     void testNamedParamWithDefaultArgument() {
-        assertScript '''
-            import groovy.transform.*
+        assertScript shell, '''
             import static groovy.test.GroovyAssert.shouldFail
 
             @NamedVariant(coerce=true)
@@ -228,8 +214,7 @@ final class NamedVariantTransformTest {
             }
         '''
 
-        assertScript '''
-            import groovy.transform.*
+        assertScript shell, '''
             import static groovy.test.GroovyAssert.shouldFail
 
             @NamedVariant
@@ -247,9 +232,7 @@ final class NamedVariantTransformTest {
 
     @Test // GROOVY-10176
     void testNamedParamWithPrimitiveValues() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             @ToString(includeNames=true)
             class Color {
                 int r, g, b
@@ -274,9 +257,7 @@ final class NamedVariantTransformTest {
     @Test
     void testNamedParamRequiredVersusOptional() {
         // check dynamic case
-        def err = shouldFail '''
-            import groovy.transform.*
-
+        def err = shouldFail shell, '''
             class Color {
                 int r, g, b
             }
@@ -291,8 +272,7 @@ final class NamedVariantTransformTest {
         assert err =~ /Missing required named argument 'color'/
 
         // also check static error (GROOVY-10484)
-        err = shouldFail '''
-            import groovy.transform.*
+        err = shouldFail shell, '''
             class Color {
                 int r, g, b
             }
@@ -310,9 +290,7 @@ final class NamedVariantTransformTest {
 
     @Test // GROOVY-9183
     void testNamedDelegateWithPrimitiveValues() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             class Color {
                 int r, g, b
             }
@@ -331,9 +309,7 @@ final class NamedVariantTransformTest {
 
     @Test // GROOVY-10261
     void testNamedVariantWithDefaultArguments() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             @TupleConstructor(defaults=false)
             @ToString(includeNames=true)
             class Color {
@@ -354,9 +330,7 @@ final class NamedVariantTransformTest {
 
     @Test // GROOVY-9183, GROOVY-10500
     void testNamedDelegateWithPropertyDefaults() {
-        assertScript '''
-            import groovy.transform.*
-
+        assertScript shell, '''
             class RowMapper {
                 final Settings settings
 
@@ -394,29 +368,47 @@ final class NamedVariantTransformTest {
         '''
     }
 
-    @NamedVariant // GROOVY-10561
-    String fileInSourceSet(String language = 'java', String extension = language) {
-        return "$language -> .$extension"
-    }
-
-    @NamedVariant // GROOVY-10561
-    String foo(String a = 'a', String b = a, String c = (String) a) {
-        return "$a $b $c"
-    }
-
     @Test // GROOVY-10561
     void testReferenceToEarlierParam() {
-        assert fileInSourceSet() == 'java -> .java'
-        assert fileInSourceSet('groovy') == 'groovy -> .groovy'
-        assert fileInSourceSet(language: 'kotlin', extension: 'kt') == 'kotlin -> .kt'
-        assert fileInSourceSet(language: 'groovy') == 'groovy -> .groovy'
+        assertScript shell, '''
+            @NamedVariant
+            String fileInSourceSet(String language = 'java', String extension = language) {
+                return "$language -> .$extension"
+            }
+
+            assert fileInSourceSet() == 'java -> .java'
+            assert fileInSourceSet('groovy') == 'groovy -> .groovy'
+            assert fileInSourceSet(language: 'groovy') == 'groovy -> .groovy'
+            assert fileInSourceSet(language: 'kotlin', extension: 'kt') == 'kotlin -> .kt'
+        '''
     }
 
     @Test // GROOVY-10561
     void testEarlierParamInExpression() {
-        assert foo() == 'a a a'
-        assert foo('c') == 'c c c'
-        assert foo('c', 'd') == 'c d c'
-        assert foo('c', 'd', 'e') == 'c d e'
+        assertScript shell, '''
+            @NamedVariant
+            String foo(String a = 'a', String b = a, String c = (String) a) {
+                return "$a $b $c"
+            }
+
+            assert foo() == 'a a a'
+            assert foo('c') == 'c c c'
+            assert foo('c', 'd') == 'c d c'
+            assert foo('c', 'd', 'e') == 'c d e'
+        '''
+    }
+
+    @Test // GROOVY-10889
+    void testDefaultValueCastIsRetained() {
+        assertScript shell, '''
+            @NamedVariant
+            Tuple2<Integer,Set<String>> createSampleData(Integer integer = 0, Set<String> strings = [] as Set) {
+                Tuple.tuple(integer, strings)
+            }
+
+            def pair = createSampleData(integer: 1)
+            assert pair[0] == 1
+            assert pair[1] == Collections.emptySet()
+        '''
     }
 }