You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by ch...@apache.org on 2021/11/05 13:42:02 UTC

[flink] 05/05: [FLINK-24550][rpc] Use ContextClassLoader for message deserialization

This is an automated email from the ASF dual-hosted git repository.

chesnay pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 341c13aed36f8d3d633298645002af516dec9050
Author: Chesnay Schepler <ch...@apache.org>
AuthorDate: Thu Nov 4 12:36:49 2021 +0100

    [FLINK-24550][rpc] Use ContextClassLoader for message deserialization
---
 .../runtime/rpc/akka/AkkaInvocationHandler.java    | 10 +--
 .../rpc/akka/ContextClassLoadingSettingTest.java   | 78 ++++++++++++++++++++++
 2 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
index db73771..ae43b2f 100644
--- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
+++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
@@ -240,7 +240,9 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaBasedEndpoint, Rpc
             final CompletableFuture<?> resultFuture =
                     ask(rpcInvocation, futureTimeout)
                             .thenApply(
-                                    resultValue -> deserializeValueIfNeeded(resultValue, method));
+                                    resultValue ->
+                                            deserializeValueIfNeeded(
+                                                    resultValue, method, flinkClassLoader));
 
             final CompletableFuture<Object> completableFuture = new CompletableFuture<>();
             resultFuture.whenComplete(
@@ -414,11 +416,11 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaBasedEndpoint, Rpc
         return terminationFuture;
     }
 
-    static Object deserializeValueIfNeeded(Object o, Method method) {
+    private static Object deserializeValueIfNeeded(
+            Object o, Method method, ClassLoader flinkClassLoader) {
         if (o instanceof AkkaRpcSerializedValue) {
             try {
-                return ((AkkaRpcSerializedValue) o)
-                        .deserializeValue(AkkaInvocationHandler.class.getClassLoader());
+                return ((AkkaRpcSerializedValue) o).deserializeValue(flinkClassLoader);
             } catch (IOException | ClassNotFoundException e) {
                 throw new CompletionException(
                         new RpcException(
diff --git a/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
index 1a585a1..ae62839 100644
--- a/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
+++ b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
@@ -18,10 +18,12 @@
 package org.apache.flink.runtime.rpc.akka;
 
 import org.apache.flink.api.common.time.Time;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.concurrent.akka.AkkaFutureUtils;
 import org.apache.flink.runtime.rpc.RpcEndpoint;
 import org.apache.flink.runtime.rpc.RpcGateway;
 import org.apache.flink.runtime.rpc.RpcService;
+import org.apache.flink.runtime.rpc.RpcUtils;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.concurrent.FutureUtils;
 
@@ -31,6 +33,11 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.Serializable;
 import java.net.URL;
 import java.net.URLClassLoader;
 import java.util.Arrays;
@@ -39,6 +46,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.function.Consumer;
 
 import static org.apache.flink.runtime.concurrent.akka.ClassLoadingUtils.runWithContextClassLoader;
 import static org.hamcrest.CoreMatchers.either;
@@ -82,6 +90,8 @@ public class ContextClassLoadingSettingTest extends TestLogger {
                         actorSystem,
                         AkkaRpcServiceConfiguration.defaultConfiguration(),
                         pretendFlinkClassLoader);
+
+        PickyObject.classLoaderAssertion = this::assertIsFlinkClassLoader;
     }
 
     @After
@@ -281,6 +291,50 @@ public class ContextClassLoadingSettingTest extends TestLogger {
     }
 
     @Test
+    public void testAkkaRpcInvocationHandler_ContextClassLoaderUsedForDeserialization()
+            throws Exception {
+        // setup 2 actor systems and rpc services that support remote connections (for which RPCs go
+        // through serialization)
+        final AkkaRpcService serverAkkaRpcService =
+                new AkkaRpcService(
+                        AkkaUtils.createActorSystem(
+                                "serverActorSystem",
+                                AkkaUtils.getAkkaConfig(
+                                        new Configuration(), new HostAndPort("localhost", 0))),
+                        AkkaRpcServiceConfiguration.defaultConfiguration());
+
+        final AkkaRpcService clientAkkaRpcService =
+                new AkkaRpcService(
+                        AkkaUtils.createActorSystem(
+                                "clientActorSystem",
+                                AkkaUtils.getAkkaConfig(
+                                        new Configuration(), new HostAndPort("localhost", 0))),
+                        AkkaRpcServiceConfiguration.defaultConfiguration(),
+                        pretendFlinkClassLoader);
+
+        try {
+            final TestEndpoint rpcEndpoint =
+                    new TestEndpoint(serverAkkaRpcService, new PickyObject());
+            rpcEndpoint.start();
+
+            final TestEndpointGateway rpcGateway =
+                    rpcEndpoint.getSelfGateway(TestEndpointGateway.class);
+
+            final TestEndpointGateway connect =
+                    clientAkkaRpcService
+                            .connect(rpcGateway.getAddress(), TestEndpointGateway.class)
+                            .get();
+
+            // if the wrong classloader is used the deserialization fails and get() throws an
+            // exception
+            connect.getPickyObject().get();
+        } finally {
+            RpcUtils.terminateRpcService(clientAkkaRpcService, TIMEOUT);
+            RpcUtils.terminateRpcService(serverAkkaRpcService, TIMEOUT);
+        }
+    }
+
+    @Test
     public void testSupervisorActor_TerminationFutureCompletedWithFlinkContextClassLoader()
             throws Exception {
         final TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService);
@@ -315,6 +369,18 @@ public class ContextClassLoadingSettingTest extends TestLogger {
         CompletableFuture<ClassLoader> doRunAsync();
 
         void doSomethingWithoutReturningAnything();
+
+        CompletableFuture<PickyObject> getPickyObject();
+    }
+
+    /** An object that only allows deserialization if its favorite ContextClassLoader does it. */
+    private static class PickyObject implements Serializable {
+        static Consumer<ClassLoader> classLoaderAssertion = null;
+
+        private void readObject(ObjectInputStream aInputStream)
+                throws ClassNotFoundException, IOException {
+            classLoaderAssertion.accept(Thread.currentThread().getContextClassLoader());
+        }
     }
 
     private static class TestEndpoint extends RpcEndpoint implements TestEndpointGateway {
@@ -325,8 +391,15 @@ public class ContextClassLoadingSettingTest extends TestLogger {
                 new CompletableFuture<>();
         private final CompletableFuture<Void> rpcResponseFuture = new CompletableFuture<>();
 
+        @Nullable private final PickyObject pickyObject;
+
         protected TestEndpoint(RpcService rpcService) {
+            this(rpcService, null);
+        }
+
+        protected TestEndpoint(RpcService rpcService, @Nullable PickyObject pickyObject) {
             super(rpcService);
+            this.pickyObject = pickyObject;
         }
 
         @Override
@@ -368,6 +441,11 @@ public class ContextClassLoadingSettingTest extends TestLogger {
             voidOperationClassLoader.complete(Thread.currentThread().getContextClassLoader());
         }
 
+        @Override
+        public CompletableFuture<PickyObject> getPickyObject() {
+            return CompletableFuture.completedFuture(pickyObject);
+        }
+
         public void completeRPCFuture() {
             rpcResponseFuture.complete(null);
         }