You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@groovy.apache.org by su...@apache.org on 2019/12/07 18:04:36 UTC

[groovy] branch master updated: Avoid unnecessarily capturing the instance of the enclosing class

This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/master by this push:
     new 23b53b4  Avoid unnecessarily capturing the instance of the enclosing class
23b53b4 is described below

commit 23b53b4db1720fc365cc18eeb6bc2600a9ef368d
Author: Daniel Sun <su...@apache.org>
AuthorDate: Sun Dec 8 02:03:03 2019 +0800

    Avoid unnecessarily capturing the instance of the enclosing class
    
    If the lambda expression does not access the instance of the enclosing class (e.g. it accesses no instance fields and calls no instance methods), that instance need not be captured.
---
 .../classgen/asm/sc/StaticTypesLambdaWriter.java   |  46 +++-
 .../groovy/control/StaticImportVisitor.java        |  15 +-
 src/test/groovy/transform/stc/LambdaTest.groovy    | 255 +++++++++++++++++++--
 3 files changed, 282 insertions(+), 34 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
index 5005e33..4f7f5d9 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
@@ -19,6 +19,7 @@
 
 package org.codehaus.groovy.classgen.asm.sc;
 
+import org.apache.groovy.util.ObjectHolder;
 import org.codehaus.groovy.GroovyBugError;
 import org.codehaus.groovy.ast.ClassCodeVisitorSupport;
 import org.codehaus.groovy.ast.ClassHelper;
@@ -135,8 +136,8 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun
                 addDeserializeLambdaMethod();
             }
 
-            newGroovyLambdaWrapperAndLoad(lambdaWrapperClassNode, expression);
-            loadEnclosingClassInstance();
+            newGroovyLambdaWrapperAndLoad(lambdaWrapperClassNode, syntheticLambdaMethodNode, expression);
+            loadEnclosingClassInstance(syntheticLambdaMethodNode);
         }
 
         MethodVisitor mv = controller.getMethodVisitor();
@@ -160,12 +161,12 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun
         return new Parameter[]{new Parameter(ClassHelper.SERIALIZEDLAMBDA_TYPE, SERIALIZED_LAMBDA_PARAM_NAME)};
     }
 
-    private void loadEnclosingClassInstance() {
+    private void loadEnclosingClassInstance(MethodNode syntheticLambdaMethodNode) {
         MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
         CompileStack compileStack = controller.getCompileStack();
 
-        if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall()) {
+        if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall() || !isAccessingInstanceMembers(syntheticLambdaMethodNode)) {
             operandStack.pushConstant(ConstantExpression.NULL);
         } else {
             mv.visitVarInsn(ALOAD, 0);
@@ -173,13 +174,46 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun
         }
     }
 
-    private void newGroovyLambdaWrapperAndLoad(ClassNode lambdaWrapperClassNode, LambdaExpression expression) {
+    private boolean isAccessingInstanceMembers(MethodNode syntheticLambdaMethodNode) {
+        ObjectHolder<Boolean> objectHolder = new ObjectHolder<>(false);
+        ClassCodeVisitorSupport classCodeVisitorSupport = new ClassCodeVisitorSupport() {
+            @Override
+            public void visitVariableExpression(VariableExpression expression) {
+                if (expression.isThisExpression()) {
+                    objectHolder.setObject(true);
+                }
+            }
+
+            @Override
+            public void visitMethodCallExpression(MethodCallExpression call) {
+                if (!call.getMethodTarget().isStatic()) {
+                    Expression objectExpression = call.getObjectExpression();
+                    if (objectExpression instanceof VariableExpression && ENCLOSING_THIS.equals(((VariableExpression) objectExpression).getName())) {
+                        objectHolder.setObject(true);
+                    }
+                }
+
+                super.visitMethodCallExpression(call);
+            }
+
+            @Override
+            protected SourceUnit getSourceUnit() {
+                return null;
+            }
+        };
+
+        classCodeVisitorSupport.visitMethod(syntheticLambdaMethodNode);
+
+        return objectHolder.getObject();
+    }
+
+    private void newGroovyLambdaWrapperAndLoad(ClassNode lambdaWrapperClassNode, MethodNode syntheticLambdaMethodNode, LambdaExpression expression) {
         MethodVisitor mv = controller.getMethodVisitor();
         String lambdaWrapperClassInternalName = BytecodeHelper.getClassInternalName(lambdaWrapperClassNode);
         mv.visitTypeInsn(NEW, lambdaWrapperClassInternalName);
         mv.visitInsn(DUP);
 
-        loadEnclosingClassInstance();
+        loadEnclosingClassInstance(syntheticLambdaMethodNode);
         controller.getOperandStack().dup();
 
         loadSharedVariables(expression);
diff --git a/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java b/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
index 7d2db8d..e8557a7 100644
--- a/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
+++ b/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
@@ -100,25 +100,26 @@ public class StaticImportVisitor extends ClassCodeExpressionTransformer {
 
     public Expression transform(Expression exp) {
         if (exp == null) return null;
-        if (exp.getClass() == VariableExpression.class) {
+        Class<? extends Expression> clazz = exp.getClass();
+        if (clazz == VariableExpression.class) {
             return transformVariableExpression((VariableExpression) exp);
         }
-        if (exp.getClass() == BinaryExpression.class) {
+        if (clazz == BinaryExpression.class) {
             return transformBinaryExpression((BinaryExpression) exp);
         }
-        if (exp.getClass() == PropertyExpression.class) {
+        if (clazz == PropertyExpression.class) {
             return transformPropertyExpression((PropertyExpression) exp);
         }
-        if (exp.getClass() == MethodCallExpression.class) {
+        if (clazz == MethodCallExpression.class) {
             return transformMethodCallExpression((MethodCallExpression) exp);
         }
-        if (exp.getClass() == ClosureExpression.class) {
+        if (exp instanceof ClosureExpression) {
             return transformClosureExpression((ClosureExpression) exp);
         }
-        if (exp.getClass() == ConstructorCallExpression.class) {
+        if (clazz == ConstructorCallExpression.class) {
             return transformConstructorCallExpression((ConstructorCallExpression) exp);
         }
-        if (exp.getClass() == ArgumentListExpression.class) {
+        if (clazz == ArgumentListExpression.class) {
             Expression result = exp.transformExpression(this);
             if (inPropertyExpression) {
                 foundArgs = result;
diff --git a/src/test/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/transform/stc/LambdaTest.groovy
index 4612741..672aab3 100644
--- a/src/test/groovy/transform/stc/LambdaTest.groovy
+++ b/src/test/groovy/transform/stc/LambdaTest.groovy
@@ -955,6 +955,26 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testInitializeBlocks() {
+        assertScript '''
+            import java.util.stream.Collectors
+            
+            @groovy.transform.CompileStatic
+            class Test1 {
+                static sl
+                def il
+                static { sl = [1, 2, 3].stream().map(e -> e + 1).toList() }
+                 
+                {
+                    il = [1, 2, 3].stream().map(e -> e + 2).toList()
+                }
+            }
+            
+            assert [2, 3, 4] == Test1.sl
+            assert [3, 4, 5] == new Test1().il
+        '''
+    }
+
     void testSerialize() {
         assertScript '''
         import java.util.function.Function
@@ -962,8 +982,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -980,12 +999,11 @@ class LambdaTest extends GroovyTestCase {
     }
 
     void testSerializeFailed() {
-        shouldFail(NotSerializableException, '''
+        def errMsg = shouldFail(NotSerializableException, '''
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -999,6 +1017,8 @@ class LambdaTest extends GroovyTestCase {
 
         new Test1().p()
         ''')
+
+        assert errMsg.contains('$Lambda$')
     }
 
     void testDeserialize() {
@@ -1007,8 +1027,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -1031,14 +1050,83 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testDeserializeLambdaInInitializeBlock() {
+        assertScript '''
+            package tests.lambda
+            import java.util.function.Function
+            
+            @groovy.transform.CompileStatic
+            class Test1 implements Serializable {
+                private static final long serialVersionUID = -1L;
+                String a = 'a'
+                SerializableFunction<Integer, String> f
+                 
+                {
+                    f = ((Integer e) -> a + e)
+                }
+                
+                byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+                }
+                
+                static void main(String[] args) {
+                    new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader) {
+                        SerializableFunction<Integer, String> f = (SerializableFunction<Integer, String>) it.readObject()
+                        assert 'a1' == f.apply(1)
+                    }
+                }
+                
+                interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+            }
+        '''
+    }
+
+    void testDeserializeLambdaInInitializeBlockShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+            package tests.lambda
+            import java.util.function.Function
+            
+            @groovy.transform.CompileStatic
+            class Test1 {
+                String a = 'a'
+                SerializableFunction<Integer, String> f
+                 
+                {
+                    f = ((Integer e) -> a + e)
+                }
+                
+                byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+                }
+                
+                static void main(String[] args) {
+                    new Test1().p()
+                }
+                
+                interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+            }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
 
     void testDeserialize2() {
         assertScript '''
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             static byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -1067,8 +1155,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -1098,8 +1185,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     String c = 'a'
@@ -1153,7 +1239,7 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
-    void testDeserialize6() {
+    void testDeserialize6InstanceFields() {
         assertScript '''
         package tests.lambda
         import java.util.function.Function
@@ -1185,7 +1271,105 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
-    void testDeserialize7() {
+    void testDeserialize6InstanceFieldsShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private String c = 'a'
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c + e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader) {
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer, String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+        }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
+    void testDeserialize6InstanceMethods() {
+        assertScript '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 implements Serializable {
+            private static final long serialVersionUID = -1L;
+            private String c() { 'a' }
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c() + e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader) {
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer, String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+        }
+        '''
+    }
+
+    void testDeserialize6InstanceMethodsShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private String c() { 'a' }
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c() + e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader) {
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer, String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+        }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
+    void testDeserialize7StaticFields() {
         assertScript '''
         package tests.lambda
         import java.util.function.Function
@@ -1215,6 +1399,37 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+
+    void testDeserialize7StaticMethods() {
+        assertScript '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private static String c() { 'a' }
+            static byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c() + e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(Test1.p()).withObjectInputStream(Test1.class.classLoader) {
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer, String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
+        }
+        '''
+    }
+
     void testDeserializeNestedLambda() {
         assertScript '''
         import java.util.function.Function
@@ -1222,8 +1437,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out1 = new ByteArrayOutputStream()
                     SerializableFunction<Integer, String> f1 = (Integer e) -> 'a' + e
@@ -1279,8 +1493,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out1 = new ByteArrayOutputStream()
                     out1.withObjectOutputStream {
@@ -1336,8 +1549,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable {}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             static p() {
                     def out1 = new ByteArrayOutputStream()
                     out1.withObjectOutputStream {
@@ -1441,4 +1653,5 @@ class LambdaTest extends GroovyTestCase {
         }
         '''
     }
+
 }