You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by ch...@apache.org on 2021/11/05 13:42:02 UTC
[flink] 05/05: [FLINK-24550][rpc] Use ContextClassLoader for
message deserialization
This is an automated email from the ASF dual-hosted git repository.
chesnay pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 341c13aed36f8d3d633298645002af516dec9050
Author: Chesnay Schepler <ch...@apache.org>
AuthorDate: Thu Nov 4 12:36:49 2021 +0100
[FLINK-24550][rpc] Use ContextClassLoader for message deserialization
---
.../runtime/rpc/akka/AkkaInvocationHandler.java | 10 +--
.../rpc/akka/ContextClassLoadingSettingTest.java | 78 ++++++++++++++++++++++
2 files changed, 84 insertions(+), 4 deletions(-)
diff --git a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
index db73771..ae43b2f 100644
--- a/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
+++ b/flink-rpc/flink-rpc-akka/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
@@ -240,7 +240,9 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaBasedEndpoint, Rpc
final CompletableFuture<?> resultFuture =
ask(rpcInvocation, futureTimeout)
.thenApply(
- resultValue -> deserializeValueIfNeeded(resultValue, method));
+ resultValue ->
+ deserializeValueIfNeeded(
+ resultValue, method, flinkClassLoader));
final CompletableFuture<Object> completableFuture = new CompletableFuture<>();
resultFuture.whenComplete(
@@ -414,11 +416,11 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaBasedEndpoint, Rpc
return terminationFuture;
}
- static Object deserializeValueIfNeeded(Object o, Method method) {
+ private static Object deserializeValueIfNeeded(
+ Object o, Method method, ClassLoader flinkClassLoader) {
if (o instanceof AkkaRpcSerializedValue) {
try {
- return ((AkkaRpcSerializedValue) o)
- .deserializeValue(AkkaInvocationHandler.class.getClassLoader());
+ return ((AkkaRpcSerializedValue) o).deserializeValue(flinkClassLoader);
} catch (IOException | ClassNotFoundException e) {
throw new CompletionException(
new RpcException(
diff --git a/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
index 1a585a1..ae62839 100644
--- a/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
+++ b/flink-rpc/flink-rpc-akka/src/test/java/org/apache/flink/runtime/rpc/akka/ContextClassLoadingSettingTest.java
@@ -18,10 +18,12 @@
package org.apache.flink.runtime.rpc.akka;
import org.apache.flink.api.common.time.Time;
+import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.concurrent.akka.AkkaFutureUtils;
import org.apache.flink.runtime.rpc.RpcEndpoint;
import org.apache.flink.runtime.rpc.RpcGateway;
import org.apache.flink.runtime.rpc.RpcService;
+import org.apache.flink.runtime.rpc.RpcUtils;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.FutureUtils;
@@ -31,6 +33,11 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.Serializable;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
@@ -39,6 +46,7 @@ import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
+import java.util.function.Consumer;
import static org.apache.flink.runtime.concurrent.akka.ClassLoadingUtils.runWithContextClassLoader;
import static org.hamcrest.CoreMatchers.either;
@@ -82,6 +90,8 @@ public class ContextClassLoadingSettingTest extends TestLogger {
actorSystem,
AkkaRpcServiceConfiguration.defaultConfiguration(),
pretendFlinkClassLoader);
+
+ PickyObject.classLoaderAssertion = this::assertIsFlinkClassLoader;
}
@After
@@ -281,6 +291,50 @@ public class ContextClassLoadingSettingTest extends TestLogger {
}
@Test
+ public void testAkkaRpcInvocationHandler_ContextClassLoaderUsedForDeserialization()
+ throws Exception {
+ // setup 2 actor systems and rpc services that support remote connections (for which RPCs go
+ // through serialization)
+ final AkkaRpcService serverAkkaRpcService =
+ new AkkaRpcService(
+ AkkaUtils.createActorSystem(
+ "serverActorSystem",
+ AkkaUtils.getAkkaConfig(
+ new Configuration(), new HostAndPort("localhost", 0))),
+ AkkaRpcServiceConfiguration.defaultConfiguration());
+
+ final AkkaRpcService clientAkkaRpcService =
+ new AkkaRpcService(
+ AkkaUtils.createActorSystem(
+ "clientActorSystem",
+ AkkaUtils.getAkkaConfig(
+ new Configuration(), new HostAndPort("localhost", 0))),
+ AkkaRpcServiceConfiguration.defaultConfiguration(),
+ pretendFlinkClassLoader);
+
+ try {
+ final TestEndpoint rpcEndpoint =
+ new TestEndpoint(serverAkkaRpcService, new PickyObject());
+ rpcEndpoint.start();
+
+ final TestEndpointGateway rpcGateway =
+ rpcEndpoint.getSelfGateway(TestEndpointGateway.class);
+
+ final TestEndpointGateway connect =
+ clientAkkaRpcService
+ .connect(rpcGateway.getAddress(), TestEndpointGateway.class)
+ .get();
+
+ // if the wrong classloader is used the deserialization fails and get() throws an
+ // exception
+ connect.getPickyObject().get();
+ } finally {
+ RpcUtils.terminateRpcService(clientAkkaRpcService, TIMEOUT);
+ RpcUtils.terminateRpcService(serverAkkaRpcService, TIMEOUT);
+ }
+ }
+
+ @Test
public void testSupervisorActor_TerminationFutureCompletedWithFlinkContextClassLoader()
throws Exception {
final TestEndpoint testEndpoint = new TestEndpoint(akkaRpcService);
@@ -315,6 +369,18 @@ public class ContextClassLoadingSettingTest extends TestLogger {
CompletableFuture<ClassLoader> doRunAsync();
void doSomethingWithoutReturningAnything();
+
+ CompletableFuture<PickyObject> getPickyObject();
+ }
+
+ /** An object that only allows deserialization if its favorite ContextClassLoader is doing it. */
+ private static class PickyObject implements Serializable {
+ static Consumer<ClassLoader> classLoaderAssertion = null;
+
+ private void readObject(ObjectInputStream aInputStream)
+ throws ClassNotFoundException, IOException {
+ classLoaderAssertion.accept(Thread.currentThread().getContextClassLoader());
+ }
}
private static class TestEndpoint extends RpcEndpoint implements TestEndpointGateway {
@@ -325,8 +391,15 @@ public class ContextClassLoadingSettingTest extends TestLogger {
new CompletableFuture<>();
private final CompletableFuture<Void> rpcResponseFuture = new CompletableFuture<>();
+ @Nullable private final PickyObject pickyObject;
+
protected TestEndpoint(RpcService rpcService) {
+ this(rpcService, null);
+ }
+
+ protected TestEndpoint(RpcService rpcService, @Nullable PickyObject pickyObject) {
super(rpcService);
+ this.pickyObject = pickyObject;
}
@Override
@@ -368,6 +441,11 @@ public class ContextClassLoadingSettingTest extends TestLogger {
voidOperationClassLoader.complete(Thread.currentThread().getContextClassLoader());
}
+ @Override
+ public CompletableFuture<PickyObject> getPickyObject() {
+ return CompletableFuture.completedFuture(pickyObject);
+ }
+
public void completeRPCFuture() {
rpcResponseFuture.complete(null);
}