You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by em...@apache.org on 2023/07/24 21:59:57 UTC

[groovy] branch master updated: GROOVY-9144: prevent illegal access

This is an automated email from the ASF dual-hosted git repository.

emilles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new bff4797be7 GROOVY-9144: prevent illegal access
bff4797be7 is described below

commit bff4797be7b17010885342c5c74eb03635ca4a21
Author: Eric Milles <er...@thomsonreuters.com>
AuthorDate: Mon Jul 24 13:48:04 2023 -0500

    GROOVY-9144: prevent illegal access
---
 .../TimedInterruptibleASTTransformation.groovy     |  79 +++++++--------
 .../codehaus/groovy/reflection/CachedClass.java    |   3 +
 .../codehaus/groovy/reflection/CachedField.java    |  62 +++++++-----
 .../groovy/runtime/DefaultGroovyMethods.java       |  45 ++++-----
 .../org/codehaus/groovy/vmplugin/v8/Selector.java  |   6 +-
 .../groovy/transform/TimedInterruptTest.groovy     |  11 +--
 .../org/apache/groovy/macrolib/MacroLibTest.groovy | 108 ++++++++++-----------
 .../NotYetImplementedASTTransformation.java        |  36 ++++---
 8 files changed, 181 insertions(+), 169 deletions(-)

diff --git a/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy b/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy
index 592d9e6913..91617d3b10 100644
--- a/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy
+++ b/src/main/groovy/org/codehaus/groovy/transform/TimedInterruptibleASTTransformation.groovy
@@ -40,13 +40,12 @@ import org.codehaus.groovy.ast.stmt.ForStatement
 import org.codehaus.groovy.ast.stmt.LoopingStatement
 import org.codehaus.groovy.ast.stmt.Statement
 import org.codehaus.groovy.ast.stmt.WhileStatement
-import org.codehaus.groovy.control.CompilePhase
 import org.codehaus.groovy.control.SourceUnit
 
+import java.lang.reflect.Modifier
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeoutException
 
-import static org.codehaus.groovy.ast.ClassHelper.make
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args
 import static org.codehaus.groovy.ast.tools.GeneralUtils.block
 import static org.codehaus.groovy.ast.tools.GeneralUtils.callX
@@ -59,8 +58,6 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.plusX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.propX
 import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS
 import static org.codehaus.groovy.ast.tools.GeneralUtils.varX
-import static org.objectweb.asm.Opcodes.ACC_FINAL
-import static org.objectweb.asm.Opcodes.ACC_PRIVATE
 
 /**
  * Allows "interrupt-safe" executions of scripts by adding timer expiration
@@ -70,18 +67,16 @@ import static org.objectweb.asm.Opcodes.ACC_PRIVATE
  * @see groovy.transform.ThreadInterrupt
  * @since 1.8.0
  */
-@CompileStatic
-@AutoFinal
-@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
+@AutoFinal @CompileStatic @GroovyASTTransformation
 class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
 
-    private static final ClassNode MY_TYPE = make(TimedInterrupt)
+    private static final ClassNode MY_TYPE = ClassHelper.make(TimedInterrupt)
     private static final String CHECK_METHOD_START_MEMBER = 'checkOnMethodStart'
     private static final String APPLY_TO_ALL_CLASSES = 'applyToAllClasses'
     private static final String APPLY_TO_ALL_MEMBERS = 'applyToAllMembers'
     private static final String THROWN_EXCEPTION_TYPE = 'thrown'
 
-    @SuppressWarnings('Instanceof')
+    @Override
     void visit(ASTNode[] nodes, SourceUnit source) {
         init(nodes, source)
         AnnotationNode node = (AnnotationNode) nodes[0]
@@ -94,7 +89,7 @@ class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
         def applyToAllMembers = getConstantAnnotationParameter(node, APPLY_TO_ALL_MEMBERS, Boolean.TYPE, true)
         def applyToAllClasses = applyToAllMembers ? getConstantAnnotationParameter(node, APPLY_TO_ALL_CLASSES, Boolean.TYPE, true) : false
         def maximum = getConstantAnnotationParameter(node, 'value', Long.TYPE, Long.MAX_VALUE)
-        def thrown = AbstractInterruptibleASTTransformation.getClassAnnotationParameter(node, THROWN_EXCEPTION_TYPE, make(TimeoutException))
+        def thrown = AbstractInterruptibleASTTransformation.getClassAnnotationParameter(node, THROWN_EXCEPTION_TYPE, ClassHelper.make(TimeoutException))
 
         Expression unit = node.getMember('unit') ?: propX(classX(TimeUnit), 'SECONDS')
 
@@ -136,7 +131,6 @@ class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
         }
     }
 
-    @SuppressWarnings('Instanceof')
     static getConstantAnnotationParameter(AnnotationNode node, String parameterName, Class type, defaultValue) {
         def member = node.getMember(parameterName)
         if (member) {
@@ -170,7 +164,6 @@ class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
         private final ClassNode thrown
         private final String basename
 
-        @SuppressWarnings('ParameterCount')
         TimedInterruptionVisitor(SourceUnit source, checkOnMethodStart, applyToAllClasses, applyToAllMembers, maximum, Expression unit, ClassNode thrown, hash) {
             this.sourceUnit = source
             this.checkOnMethodStart = checkOnMethodStart
@@ -189,7 +182,7 @@ class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
             ifS(
                     ltX(
                             propX(varX('this'), basename + '$expireTime'),
-                            callX(make(System), 'nanoTime')
+                            callX(ClassHelper.make(System), 'nanoTime')
                     ),
                     throwS(
                             ctorX(thrown,
@@ -226,42 +219,40 @@ class TimedInterruptibleASTTransformation extends AbstractASTTransformation {
 
         @Override
         void visitClass(ClassNode node) {
-            if (node.getDeclaredField(basename + '$expireTime')) {
-                return
-            }
-            expireTimeField = node.addField(basename + '$expireTime',
-                    ACC_FINAL | ACC_PRIVATE,
-                    ClassHelper.long_TYPE,
-                    plusX(
-                            callX(make(System), 'nanoTime'),
-                            callX(
-                                    propX(classX(TimeUnit), 'NANOSECONDS'),
-                                    'convert',
-                                    args(constX(maximum, true), unit)
-                            )
-                    )
-            )
-            expireTimeField.synthetic = true
-            ClassNode dateClass = make(Date)
-            startTimeField = node.addField(basename + '$startTime',
-                    ACC_FINAL | ACC_PRIVATE,
-                    dateClass,
-                    ctorX(dateClass)
-            )
-            startTimeField.synthetic = true
+            String startTime = basename + '$startTime'
+            String expireTime = basename + '$expireTime'
+            if (node.getDeclaredField(expireTime) == null) {
+                expireTimeField = node.addFieldFirst(
+                        expireTime,
+                        Modifier.FINAL | Modifier.PRIVATE,
+                        ClassHelper.long_TYPE,
+                        plusX(
+                                callX(ClassHelper.make(System), 'nanoTime'),
+                                callX(
+                                        propX(classX(TimeUnit), 'NANOSECONDS'),
+                                        'convert',
+                                        args(constX(maximum, true), unit)
+                                )
+                        )
+                )
+                expireTimeField.synthetic = true
+
+                ClassNode dateClass = ClassHelper.make(Date)
+                startTimeField = node.addFieldFirst(
+                        startTime,
+                        Modifier.FINAL | Modifier.PRIVATE,
+                        dateClass,
+                        ctorX(dateClass)
+                )
+                startTimeField.synthetic = true
 
-            // force these fields to be initialized first
-            node.fields.remove(expireTimeField)
-            node.fields.remove(startTimeField)
-            node.fields.add(0, startTimeField)
-            node.fields.add(0, expireTimeField)
-            if (applyToAllMembers) {
-                super.visitClass node
+                if (applyToAllMembers) {
+                    super.visitClass(node)
+                }
             }
         }
 
         @Override
-        @SuppressWarnings('Instanceof')
         void visitClosureExpression(ClosureExpression closureExpr) {
             def code = closureExpr.code
             if (code instanceof BlockStatement) {
diff --git a/src/main/java/org/codehaus/groovy/reflection/CachedClass.java b/src/main/java/org/codehaus/groovy/reflection/CachedClass.java
index 5457a63431..52f215d6ae 100644
--- a/src/main/java/org/codehaus/groovy/reflection/CachedClass.java
+++ b/src/main/java/org/codehaus/groovy/reflection/CachedClass.java
@@ -58,6 +58,9 @@ public class CachedClass {
     }
 
     private static <M extends AccessibleObject & Member> boolean isAccessibleOrCanSetAccessible(M m) {
+        if (isPublic(m.getModifiers()) && m.getDeclaringClass().getPackageName().startsWith("sun.")) {
+            return false;
+        }
         if (isProtected(m.getModifiers()) && isPublic(m.getDeclaringClass().getModifiers())) {
             return true;
         }
diff --git a/src/main/java/org/codehaus/groovy/reflection/CachedField.java b/src/main/java/org/codehaus/groovy/reflection/CachedField.java
index f7ccc80634..58b8431097 100644
--- a/src/main/java/org/codehaus/groovy/reflection/CachedField.java
+++ b/src/main/java/org/codehaus/groovy/reflection/CachedField.java
@@ -20,23 +20,31 @@ package org.codehaus.groovy.reflection;
 
 import groovy.lang.GroovyRuntimeException;
 import groovy.lang.MetaProperty;
-import org.codehaus.groovy.runtime.typehandling.DefaultTypeTransformation;
 
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
 import java.lang.reflect.Field;
 import java.lang.reflect.Modifier;
 
-import static org.codehaus.groovy.reflection.ReflectionUtils.makeAccessibleInPrivilegedAction;
+import static org.codehaus.groovy.runtime.typehandling.DefaultTypeTransformation.castToType;
 
 public class CachedField extends MetaProperty {
-    private final Field field;
 
     public CachedField(final Field field) {
         super(field.getName(), field.getType());
         this.field = field;
     }
 
+    private final Field field;
+    private boolean madeAccessible;
+    private void makeAccessible() {
+        ReflectionUtils.makeAccessibleInPrivilegedAction(field);
+        AccessPermissionChecker.checkAccessPermission(field);
+        madeAccessible = true;
+    }
+
     public Field getCachedField() {
-        makeAccessibleIfNecessary();
+        if (!madeAccessible) makeAccessible();
         return field;
     }
 
@@ -44,14 +52,6 @@ public class CachedField extends MetaProperty {
         return field.getDeclaringClass();
     }
 
-    /**
-     * {@inheritDoc}
-     */
-    @Override
-    public int getModifiers() {
-        return field.getModifiers();
-    }
-
     public boolean isFinal() {
         return Modifier.isFinal(getModifiers());
     }
@@ -60,12 +60,20 @@ public class CachedField extends MetaProperty {
         return Modifier.isStatic(getModifiers());
     }
 
+    /**
+     * {@inheritDoc}
+     */
+    @Override
+    public int getModifiers() {
+        return field.getModifiers();
+    }
+
     /**
      * {@inheritDoc}
      */
     @Override
     public Object getProperty(final Object object) {
-        makeAccessibleIfNecessary();
+        var field = getCachedField();
         try {
             return field.get(object);
         } catch (IllegalAccessException | IllegalArgumentException e) {
@@ -77,25 +85,33 @@ public class CachedField extends MetaProperty {
      * {@inheritDoc}
      */
     @Override
-    public void setProperty(final Object object, final Object newValue) {
+    public  void  setProperty(final Object object, Object newValue) {
         if (isFinal()) {
             throw new GroovyRuntimeException("Cannot set the property '" + name + "' because the backing field is final.");
         }
-        makeAccessibleIfNecessary();
-        Object goalValue = DefaultTypeTransformation.castToType(newValue, field.getType());
+        newValue = castToType(newValue, field.getType());
+        var field = getCachedField();
         try {
-            field.set(object, goalValue);
+            field.set(object, newValue);
         } catch (IllegalAccessException | IllegalArgumentException e) {
             throw new GroovyRuntimeException("Cannot set the property '" + name + "'.", e);
         }
     }
 
-    private transient boolean madeAccessible;
-    private void makeAccessibleIfNecessary() {
-        if (!madeAccessible) {
-            makeAccessibleInPrivilegedAction(field);
-            madeAccessible = true;
+    public MethodHandle asAccessMethod(final MethodHandles.Lookup lookup) throws IllegalAccessException {
+        try {
+            return lookup.unreflectGetter(field);
+        } catch (IllegalAccessException e) {
+            if (!madeAccessible) {
+                try {
+                    makeAccessible();
+                    return lookup.unreflectGetter(field);
+                } catch (IllegalAccessException ignore) {
+                } catch (Throwable t) {
+                    e.addSuppressed(t);
+                }
+            }
+            throw e;
         }
-        AccessPermissionChecker.checkAccessPermission(field);
     }
 }
diff --git a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
index 8e61dd5edd..b532fc6ed3 100644
--- a/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
+++ b/src/main/java/org/codehaus/groovy/runtime/DefaultGroovyMethods.java
@@ -57,6 +57,7 @@ import groovy.util.PermutationGenerator;
 import groovy.util.ProxyGenerator;
 import org.apache.groovy.io.StringBuilderWriter;
 import org.apache.groovy.util.ReversedList;
+import org.apache.groovy.util.SystemUtil;
 import org.codehaus.groovy.classgen.Verifier;
 import org.codehaus.groovy.reflection.ClassInfo;
 import org.codehaus.groovy.reflection.MixinInMetaClass;
@@ -114,7 +115,6 @@ import java.math.BigInteger;
 import java.math.RoundingMode;
 import java.net.URL;
 import java.security.CodeSource;
-import java.security.PrivilegedAction;
 import java.text.MessageFormat;
 import java.util.AbstractCollection;
 import java.util.AbstractMap;
@@ -439,47 +439,38 @@ public class DefaultGroovyMethods extends DefaultGroovyMethodsSupport {
             return "null";
         }
         StringBuilder buffer = new StringBuilder("<");
-        Class klass = self.getClass();
+        Class<?> klass = self.getClass();
         buffer.append(klass.getName());
         buffer.append("@");
         buffer.append(Integer.toHexString(self.hashCode()));
         boolean groovyObject = self instanceof GroovyObject;
-
         while (klass != null) {
-            for (final Field field : klass.getDeclaredFields()) {
-                if ((field.getModifiers() & Modifier.STATIC) == 0) {
-                    if (groovyObject && field.getName().equals("metaClass")) {
-                        continue;
-                    }
-                    trySetAccessible(field);
-                    buffer.append(" ");
-                    buffer.append(field.getName());
-                    buffer.append("=");
-                    try {
-                        buffer.append(FormatHelper.toString(field.get(self)));
-                    } catch (IllegalAccessException e) {
+            for (Field field : klass.getDeclaredFields()) {
+                if (Modifier.isStatic(field.getModifiers()) || (groovyObject && field.getName().equals("metaClass"))) {
+                    continue;
+                }
+                buffer.append(" ");
+                buffer.append(field.getName());
+                buffer.append("=");
+                if (!field.canAccess(self)) { // GROOVY-9144
+                    if (!SystemUtil.getBooleanSafe("groovy.force.illegal.access")
+                            || ReflectionUtils.makeAccessibleInPrivilegedAction(field).isEmpty()) {
                         buffer.append("inaccessible");
-                    } catch (Exception e) {
-                        buffer.append(e);
+                        continue;
                     }
                 }
+                try {
+                    buffer.append(FormatHelper.toString(field.get(self)));
+                } catch (Exception e) {
+                    buffer.append(e);
+                }
             }
-
             klass = klass.getSuperclass();
         }
-
         buffer.append(">");
         return buffer.toString();
     }
 
-    @SuppressWarnings("removal") // TODO a future Groovy version should perform the accessible check not as a privileged action
-    private static void trySetAccessible(final Field field) {
-        java.security.AccessController.doPrivileged((PrivilegedAction<Object>) () -> {
-            ReflectionUtils.trySetAccessible(field);
-            return null;
-        });
-    }
-
     /**
      * Retrieves the list of {@link groovy.lang.MetaProperty} objects for 'self' and wraps it
      * in a list of {@link groovy.lang.PropertyValue} objects that additionally provide
diff --git a/src/main/java/org/codehaus/groovy/vmplugin/v8/Selector.java b/src/main/java/org/codehaus/groovy/vmplugin/v8/Selector.java
index 8851bc0da3..06c618943d 100644
--- a/src/main/java/org/codehaus/groovy/vmplugin/v8/Selector.java
+++ b/src/main/java/org/codehaus/groovy/vmplugin/v8/Selector.java
@@ -336,9 +336,9 @@ public abstract class Selector {
                 insertName = true; // pass "name" field as argument
             } else if (mp instanceof CachedField && !Modifier.isStatic(mp.getModifiers())) {
                 try {
-                    MethodHandles.Lookup lookup = Modifier.isPublic(mp.getModifiers()) ? LOOKUP
-                      : ((Java8) VMPluginFactory.getPlugin()).newLookup(sender); // GROOVY-9596
-                    handle = lookup.unreflectGetter(((CachedField) mp).getCachedField());
+                    // GROOVY-9144, GROOVY-9596: get lookup for sender and unreflect before forcing access
+                    MethodHandles.Lookup lookup = ((Java8) VMPluginFactory.getPlugin()).newLookup(sender);
+                    handle = ((CachedField) mp).asAccessMethod(lookup);
                 } catch (IllegalAccessException e) {
                     throw new GroovyBugError(e);
                 }
diff --git a/src/test/groovy/transform/TimedInterruptTest.groovy b/src/test/groovy/transform/TimedInterruptTest.groovy
index 2108c83a26..b784686a64 100644
--- a/src/test/groovy/transform/TimedInterruptTest.groovy
+++ b/src/test/groovy/transform/TimedInterruptTest.groovy
@@ -205,13 +205,13 @@ final class TimedInterruptTest {
             new C()
         '''
         def system = new StubFor(System)
-        // start time initialized to the Long of the Beast
         system.demand.nanoTime(4) { 666L } // 2 times to cover full instantiation
         system.demand.nanoTime() { 1000000667L }
         system.use {
             def instance = shell.evaluate(script)
             // may get false positives if multiple annotations with the same expireTime defined in test script
-            assert instance.dump().matches('.*timedInterrupt\\S+\\$expireTime=1000000666 .*')
+            def expired = instance.class.declaredFields.find { it.name =~ 'timedInterrupt\\S+\\$expireTime' }
+            assert instance.@(expired.name) == 1000000666L
 
             shouldFail(TimeoutException) {
                 instance.m()
@@ -223,7 +223,6 @@ final class TimedInterruptTest {
 
     static void assertPassesNormalFailsSlowExecution(Map<String,?> args, Class type) {
         def system = new StubFor(System)
-        // start time initialized to ...
         system.demand.nanoTime() { 666L }
         def instance
         system.use {
@@ -232,7 +231,8 @@ final class TimedInterruptTest {
         long expireTime = args.getOrDefault('expireTime', 1000000666L)
         String methodName = args.getOrDefault('methodName', 'myMethod')
         // may get false positives if multiple annotations with the same expireTime defined
-        assert instance.dump().matches('.*timedInterrupt\\S+\\$expireTime=' + expireTime + ' .*')
+        def expired = instance.class.declaredFields.find { it.name =~ 'timedInterrupt\\S+\\$expireTime' }
+        assert instance.@(expired.name) == expireTime
 
         system.demand.nanoTime() { expireTime }
         system.use {
@@ -242,7 +242,7 @@ final class TimedInterruptTest {
         // one nanosecond too slow
         system.demand.nanoTime() { expireTime + 1 }
         system.use {
-            def err = shouldFail(args.getOrDefault('exception', java.util.concurrent.TimeoutException)) {
+            def err = shouldFail(args.getOrDefault('exception', TimeoutException)) {
                 instance.(methodName)()
             }
             assert err.message.contains('Execution timed out after ' + args.getOrDefault('units', '1') + ' ' + args.getOrDefault('timeUnitName', 'seconds'))
@@ -251,7 +251,6 @@ final class TimedInterruptTest {
 
     static void assertPassesSlowExecution(Class c) {
         def system = new StubFor(System)
-        // start time initialized to the Long of the Beast
         system.demand.nanoTime() { 666L }
         def instance
         system.use {
diff --git a/subprojects/groovy-macro-library/src/test/groovy/org/apache/groovy/macrolib/MacroLibTest.groovy b/subprojects/groovy-macro-library/src/test/groovy/org/apache/groovy/macrolib/MacroLibTest.groovy
index da74e2f213..bc63e8b315 100644
--- a/subprojects/groovy-macro-library/src/test/groovy/org/apache/groovy/macrolib/MacroLibTest.groovy
+++ b/subprojects/groovy-macro-library/src/test/groovy/org/apache/groovy/macrolib/MacroLibTest.groovy
@@ -18,89 +18,87 @@
  */
 package org.apache.groovy.macrolib
 
-import groovy.test.GroovyTestCase
-import groovy.transform.CompileStatic
+import org.junit.Test
 
-@CompileStatic
-class MacroLibTest extends GroovyTestCase {
+import static groovy.test.GroovyAssert.assertScript
 
-    def BASE = '''
-    def num = 42
-    def list = [1 ,2, 3]
-    def range = 0..5
-    def string = 'foo'
+final class MacroLibTest {
+
+    private static final String BASE = '''\
+        def num = 42
+        def list = [1 ,2, 3]
+        def range = 0..5
+        def string = 'foo'
     '''
 
+    @Test
     void testSV() {
-        assertScript """
-        $BASE
-        assert SV(num, list, range, string).toString() == 'num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
-        """
+        assertScript BASE + '''\
+            assert SV(num, list, range, string).toString() == 'num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
+        '''
     }
 
+    @Test
     void testSVInClosure() {
-        assertScript """
-        $BASE
-        def cl = {
-            SV(num, list, range, string).toString()
-        }
-
-        assert cl().toString() == 'num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
-        """
+        assertScript BASE + '''\
+            def cl = {
+                SV(num, list, range, string).toString()
+            }
+            assert cl().toString() == 'num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
+        '''
     }
 
+    @Test
     void testList() {
-        assertScript """
-        $BASE
-        assert [SV(num, list), SV(range, string)].toString() == '[num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo]'
-        """
+        assertScript BASE + '''\
+            assert [SV(num, list), SV(range, string)].toString() == '[num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo]'
+        '''
     }
 
+    @Test
     void testSVInclude() {
-        assertScript """
-        $BASE
-        def numSV = SV(num)
-        assert SV(numSV, list, range, string).toString() == 'numSV=num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
-        """
+        assertScript BASE + '''\
+            def numSV = SV(num)
+            assert SV(numSV, list, range, string).toString() == 'numSV=num=42, list=[1, 2, 3], range=[0, 1, 2, 3, 4, 5], string=foo'
+        '''
     }
 
+    @Test
     void testNested() {
-        assertScript """
-        $BASE
-        def result = SV(SV(num), string)
-        def strip = 'org.codehaus.groovy.macro.runtime.MacroStub.INSTANCE.'
-        assert result - strip == 'macroMethod()=num=42, string=foo'
-        """
+        assertScript BASE + '''\
+            def result = SV(SV(num), string)
+            def strip = 'org.codehaus.groovy.macro.runtime.MacroStub.INSTANCE.'
+            assert result - strip == 'macroMethod()=num=42, string=foo'
+        '''
     }
 
+    @Test
     void testSVI() {
-        assertScript """
-        $BASE
-        assert SVI(num, list, range, string).toString() == /num=42, list=[1, 2, 3], range=0..5, string='foo'/
-        """
+        assertScript BASE + '''\
+            assert SVI(num, list, range, string).toString() == /num=42, list=[1, 2, 3], range=0..5, string='foo'/
+        '''
     }
 
+    @Test
     void testSVD() {
-        assertScript """
-        $BASE
-        def result = SVD(num, list, range, string)
-        def trimmed = result.replaceAll(/@[^>]+/, '@...')
-        assert trimmed == /num=<ja...@...>, list=<ja...@...>, range=<gr...@...>, string=<ja...@...>/
-        """
+        assertScript BASE + '''\
+            def result = SVD(num, list, range, string)
+            def trimmed = result.replaceAll(/@[^>]+/, '@...')
+            assert trimmed == /num=<ja...@...>, list=<ja...@...>, range=<gr...@...>, string=<ja...@...>/
+        '''
     }
 
+    @Test
     void testNV() {
-        assertScript """
-        $BASE
-        assert NV(num).toString() == 'num=42'
-        """
+        assertScript BASE + '''\
+            assert NV(num).toString() == 'num=42'
+        '''
     }
 
+    @Test
     void testNVL() {
-        assertScript """
-        $BASE
-        assert NVL(num, list, range, string).toString() == "[num=42, list=[1, 2, 3], range=0..5, string='foo']"
-        """
+        assertScript BASE + '''\
+            assert NVL(num, list, range, string).toString() == "[num=42, list=[1, 2, 3], range=0..5, string='foo']"
+        '''
     }
-
 }
diff --git a/subprojects/groovy-test/src/main/java/org/apache/groovy/test/transform/NotYetImplementedASTTransformation.java b/subprojects/groovy-test/src/main/java/org/apache/groovy/test/transform/NotYetImplementedASTTransformation.java
index 8abf5102dc..c8f74e5e49 100644
--- a/subprojects/groovy-test/src/main/java/org/apache/groovy/test/transform/NotYetImplementedASTTransformation.java
+++ b/subprojects/groovy-test/src/main/java/org/apache/groovy/test/transform/NotYetImplementedASTTransformation.java
@@ -25,21 +25,23 @@ import org.codehaus.groovy.ast.ClassNode;
 import org.codehaus.groovy.ast.ConstructorNode;
 import org.codehaus.groovy.ast.MethodNode;
 import org.codehaus.groovy.ast.Parameter;
+import org.codehaus.groovy.ast.expr.Expression;
 import org.codehaus.groovy.ast.stmt.BlockStatement;
 import org.codehaus.groovy.ast.stmt.EmptyStatement;
 import org.codehaus.groovy.ast.stmt.ReturnStatement;
 import org.codehaus.groovy.ast.stmt.ThrowStatement;
 import org.codehaus.groovy.ast.stmt.TryCatchStatement;
-import org.codehaus.groovy.control.CompilePhase;
 import org.codehaus.groovy.control.SourceUnit;
 import org.codehaus.groovy.transform.AbstractASTTransformation;
 import org.codehaus.groovy.transform.GroovyASTTransformation;
 
 import static org.codehaus.groovy.ast.tools.GeneralUtils.args;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.block;
+import static org.codehaus.groovy.ast.tools.GeneralUtils.castX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.catchS;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.constX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.ctorX;
+import static org.codehaus.groovy.ast.tools.GeneralUtils.nullX;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.param;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.throwS;
 import static org.codehaus.groovy.ast.tools.GeneralUtils.tryCatchS;
@@ -49,23 +51,33 @@ import static org.codehaus.groovy.ast.tools.GeneralUtils.tryCatchS;
  *
  * @see groovy.test.NotYetImplemented
  */
-@GroovyASTTransformation(phase = CompilePhase.CANONICALIZATION)
+@GroovyASTTransformation
 public class NotYetImplementedASTTransformation extends AbstractASTTransformation {
-    private static final ClassNode DEFAULT_THROW_TYPE = ClassHelper.make(AssertionError.class);
 
     @Override
     public void visit(ASTNode[] nodes, SourceUnit source) {
         init(nodes, source);
-        AnnotationNode anno = (AnnotationNode) nodes[0];
         MethodNode methodNode = (MethodNode) nodes[1];
+        AnnotationNode annotation = (AnnotationNode) nodes[0];
 
-        ClassNode exception = getMemberClassValue(anno, "exception");
+        ClassNode exception = getMemberClassValue(annotation, "exception");
+        boolean   withCause = false;
         if (exception == null) {
-            exception = DEFAULT_THROW_TYPE;
-        }
-        ConstructorNode cons = exception.getDeclaredConstructor(new Parameter[]{new Parameter(ClassHelper.STRING_TYPE, "dummy")});
-        if (cons == null) {
-            addError("Error during @NotYetImplemented processing: supplied exception " + exception.getNameWithoutPackage() + " doesn't have expected String constructor", methodNode);
+            exception = ClassHelper.make(AssertionError.class);
+            withCause = true; // AssertionError(String,Throwable) is public
+        } else {
+            Parameter message = new Parameter(ClassHelper.STRING_TYPE, "message");
+            ConstructorNode ctor = exception.getDeclaredConstructor(new Parameter[]{message});
+            if (ctor != null && ctor.isPublic()) {
+                // all set
+            } else {
+                ctor = exception.getDeclaredConstructor(new Parameter[]{message, new Parameter(ClassHelper.THROWABLE_TYPE, "cause")});
+                if (ctor != null && ctor.isPublic()) {
+                    withCause = true;
+                } else {
+                    addError("Error during @NotYetImplemented processing: supplied exception " + exception.getNameWithoutPackage() + " doesn't have expected String constructor", methodNode);
+                }
+            }
         }
 
         if (methodNode.getCode() instanceof BlockStatement && !methodNode.getCode().isEmpty()) {
@@ -76,7 +88,9 @@ public class NotYetImplementedASTTransformation extends AbstractASTTransformatio
                     EmptyStatement.INSTANCE,
                     catchS(param(ClassHelper.THROWABLE_TYPE.getPlainNodeReference(), "ignore"), ReturnStatement.RETURN_NULL_OR_VOID));
 
-            ThrowStatement throwStatement = throwS(ctorX(exception, args(constX("Method is marked with @NotYetImplemented but passes unexpectedly"))));
+            Expression arguments = constX("Method is marked with @NotYetImplemented but passes unexpectedly");
+            if (withCause) arguments = args(arguments, castX(ClassHelper.THROWABLE_TYPE, nullX()));
+            ThrowStatement throwStatement = throwS(ctorX(exception, arguments));
 
             methodNode.setCode(block(tryCatchStatement, throwStatement));
         }