You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@flink.apache.org by tr...@apache.org on 2021/02/18 18:00:36 UTC

[flink] 01/02: [FLINK-20580][rpc] Separate wire value class from user values

This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bbdd769253b25b4093a1759c835c6ff1d99d390d
Author: Kezhu Wang <ke...@gmail.com>
AuthorDate: Sat Feb 13 00:06:53 2021 +0800

    [FLINK-20580][rpc] Separate wire value class from user values
---
 .../runtime/rpc/akka/AkkaInvocationHandler.java    |   5 +-
 .../flink/runtime/rpc/akka/AkkaRpcActor.java       |  18 ++-
 .../runtime/rpc/akka/AkkaRpcSerializedValue.java   |  88 +++++++++++++++
 .../flink/runtime/rpc/akka/AkkaRpcActorTest.java   |  60 ++++++++++
 .../rpc/akka/AkkaRpcSerializedValueTest.java       | 125 +++++++++++++++++++++
 .../runtime/rpc/akka/RemoteAkkaRpcActorTest.java   |  28 +++++
 6 files changed, 311 insertions(+), 13 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
index a6d6a4b..5ec3f1d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaInvocationHandler.java
@@ -34,7 +34,6 @@ import org.apache.flink.runtime.rpc.messages.RpcInvocation;
 import org.apache.flink.runtime.rpc.messages.RunAsync;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
-import org.apache.flink.util.SerializedValue;
 
 import akka.actor.ActorRef;
 import akka.pattern.Patterns;
@@ -389,9 +388,9 @@ class AkkaInvocationHandler implements InvocationHandler, AkkaBasedEndpoint, Rpc
     }
 
     static Object deserializeValueIfNeeded(Object o, Method method) {
-        if (o instanceof SerializedValue) {
+        if (o instanceof AkkaRpcSerializedValue) {
             try {
-                return ((SerializedValue<?>) o)
+                return ((AkkaRpcSerializedValue) o)
                         .deserializeValue(AkkaInvocationHandler.class.getClassLoader());
             } catch (IOException | ClassNotFoundException e) {
                 throw new CompletionException(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
index 08fe9b5..b0a2a82 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActor.java
@@ -36,7 +36,6 @@ import org.apache.flink.runtime.rpc.messages.RunAsync;
 import org.apache.flink.types.Either;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
-import org.apache.flink.util.SerializedValue;
 
 import akka.actor.AbstractActor;
 import akka.actor.ActorRef;
@@ -332,7 +331,7 @@ class AkkaRpcActor<T extends RpcEndpoint & RpcGateway> extends AbstractActor {
 
     private void sendSyncResponse(Object response, String methodName) {
         if (isRemoteSender(getSender())) {
-            Either<SerializedValue<?>, AkkaRpcException> serializedResult =
+            Either<AkkaRpcSerializedValue, AkkaRpcException> serializedResult =
                     serializeRemoteResultAndVerifySize(response, methodName);
 
             if (serializedResult.isLeft()) {
@@ -356,8 +355,10 @@ class AkkaRpcActor<T extends RpcEndpoint & RpcGateway> extends AbstractActor {
                                 promise.failure(throwable);
                             } else {
                                 if (isRemoteSender(sender)) {
-                                    Either<SerializedValue<?>, AkkaRpcException> serializedResult =
-                                            serializeRemoteResultAndVerifySize(value, methodName);
+                                    Either<AkkaRpcSerializedValue, AkkaRpcException>
+                                            serializedResult =
+                                                    serializeRemoteResultAndVerifySize(
+                                                            value, methodName);
 
                                     if (serializedResult.isLeft()) {
                                         promise.success(serializedResult.left());
@@ -380,15 +381,12 @@ class AkkaRpcActor<T extends RpcEndpoint & RpcGateway> extends AbstractActor {
         return !sender.path().address().hasLocalScope();
     }
 
-    private Either<SerializedValue<?>, AkkaRpcException> serializeRemoteResultAndVerifySize(
+    private Either<AkkaRpcSerializedValue, AkkaRpcException> serializeRemoteResultAndVerifySize(
             Object result, String methodName) {
         try {
-            SerializedValue<?> serializedResult = new SerializedValue<>(result);
+            AkkaRpcSerializedValue serializedResult = AkkaRpcSerializedValue.valueOf(result);
 
-            long resultSize =
-                    serializedResult.getByteArray() == null
-                            ? 0
-                            : serializedResult.getByteArray().length;
+            long resultSize = serializedResult.getSerializedDataLength();
             if (resultSize > maximumFramesize) {
                 return Either.Right(
                         new AkkaRpcException(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValue.java
new file mode 100644
index 0000000..fb8011c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValue.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rpc.akka;
+
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+
+/** A self-contained serialized value to decouple from user values and transfer on wire. */
+final class AkkaRpcSerializedValue implements Serializable {
+    private static final long serialVersionUID = -4388571068440835689L;
+
+    @Nullable private final byte[] serializedData;
+
+    private AkkaRpcSerializedValue(@Nullable byte[] serializedData) {
+        this.serializedData = serializedData;
+    }
+
+    @Nullable
+    public byte[] getSerializedData() {
+        return serializedData;
+    }
+
+    /** Return length of serialized data, zero if no serialized data. */
+    public int getSerializedDataLength() {
+        return serializedData == null ? 0 : serializedData.length;
+    }
+
+    @Nullable
+    public <T> T deserializeValue(ClassLoader loader) throws IOException, ClassNotFoundException {
+        Preconditions.checkNotNull(loader, "No classloader has been passed");
+        return serializedData == null
+                ? null
+                : InstantiationUtil.deserializeObject(serializedData, loader);
+    }
+
+    /**
+     * Construct a serialized value to transfer on wire.
+     *
+     * @param value nullable value
+     * @return serialized value to transfer on wire
+     * @throws IOException exception during value serialization
+     */
+    public static AkkaRpcSerializedValue valueOf(@Nullable Object value) throws IOException {
+        byte[] serializedData = value == null ? null : InstantiationUtil.serializeObject(value);
+        return new AkkaRpcSerializedValue(serializedData);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o instanceof AkkaRpcSerializedValue) {
+            AkkaRpcSerializedValue other = (AkkaRpcSerializedValue) o;
+            return Arrays.equals(serializedData, other.serializedData);
+        }
+        return false;
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(serializedData);
+    }
+
+    @Override
+    public String toString() {
+        return serializedData == null ? "AkkaRpcSerializedValue(null)" : "AkkaRpcSerializedValue";
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java
index 689e14f..c833027 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcActorTest.java
@@ -34,6 +34,7 @@ import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
 
 import akka.actor.ActorRef;
@@ -47,6 +48,8 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
@@ -535,6 +538,28 @@ public class AkkaRpcActorTest extends TestLogger {
         }
     }
 
+    @Test
+    public void canRespondWithSerializedValueLocally() throws Exception {
+        try (final SerializedValueRespondingEndpoint endpoint =
+                new SerializedValueRespondingEndpoint(akkaRpcService)) {
+            endpoint.start();
+
+            final SerializedValueRespondingGateway selfGateway =
+                    endpoint.getSelfGateway(SerializedValueRespondingGateway.class);
+
+            assertThat(
+                    selfGateway.getSerializedValueSynchronously(),
+                    equalTo(SerializedValueRespondingEndpoint.SERIALIZED_VALUE));
+
+            final CompletableFuture<SerializedValue<String>> responseFuture =
+                    selfGateway.getSerializedValue();
+
+            assertThat(
+                    responseFuture.get(),
+                    equalTo(SerializedValueRespondingEndpoint.SERIALIZED_VALUE));
+        }
+    }
+
     // ------------------------------------------------------------------------
     //  Test Actors and Interfaces
     // ------------------------------------------------------------------------
@@ -586,6 +611,41 @@ public class AkkaRpcActorTest extends TestLogger {
 
     // ------------------------------------------------------------------------
 
+    interface SerializedValueRespondingGateway extends RpcGateway {
+        CompletableFuture<SerializedValue<String>> getSerializedValue();
+
+        SerializedValue<String> getSerializedValueSynchronously();
+    }
+
+    static class SerializedValueRespondingEndpoint extends RpcEndpoint
+            implements SerializedValueRespondingGateway {
+        static final SerializedValue<String> SERIALIZED_VALUE;
+
+        static {
+            try {
+                SERIALIZED_VALUE = new SerializedValue<>("string-value");
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        }
+
+        public SerializedValueRespondingEndpoint(RpcService rpcService) {
+            super(rpcService);
+        }
+
+        @Override
+        public CompletableFuture<SerializedValue<String>> getSerializedValue() {
+            return CompletableFuture.completedFuture(SERIALIZED_VALUE);
+        }
+
+        @Override
+        public SerializedValue<String> getSerializedValueSynchronously() {
+            return SERIALIZED_VALUE;
+        }
+    }
+
+    // ------------------------------------------------------------------------
+
     private interface ExceptionalGateway extends RpcGateway {
         CompletableFuture<Integer> doStuff();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValueTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValueTest.java
new file mode 100644
index 0000000..70263a6
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/AkkaRpcSerializedValueTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.rpc.akka;
+
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Test;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.time.Instant;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+
+/** Tests for the {@link AkkaRpcSerializedValue}. */
+public class AkkaRpcSerializedValueTest extends TestLogger {
+
+    @Test
+    public void testNullValue() throws Exception {
+        AkkaRpcSerializedValue serializedValue = AkkaRpcSerializedValue.valueOf(null);
+        assertThat(serializedValue.getSerializedData(), nullValue());
+        assertThat(serializedValue.getSerializedDataLength(), equalTo(0));
+        assertThat(serializedValue.deserializeValue(getClass().getClassLoader()), nullValue());
+
+        AkkaRpcSerializedValue otherSerializedValue = AkkaRpcSerializedValue.valueOf(null);
+        assertThat(otherSerializedValue, equalTo(serializedValue));
+        assertThat(otherSerializedValue.hashCode(), equalTo(serializedValue.hashCode()));
+
+        AkkaRpcSerializedValue clonedSerializedValue = InstantiationUtil.clone(serializedValue);
+        assertThat(clonedSerializedValue.getSerializedData(), nullValue());
+        assertThat(clonedSerializedValue.getSerializedDataLength(), equalTo(0));
+        assertThat(
+                clonedSerializedValue.deserializeValue(getClass().getClassLoader()), nullValue());
+        assertThat(clonedSerializedValue, equalTo(serializedValue));
+        assertThat(clonedSerializedValue.hashCode(), equalTo(serializedValue.hashCode()));
+    }
+
+    @Test
+    public void testNotNullValues() throws Exception {
+        Set<Object> values =
+                Stream.of(
+                                true,
+                                (byte) 5,
+                                (short) 6,
+                                5,
+                                5L,
+                                5.5F,
+                                6.5,
+                                'c',
+                                "string",
+                                Instant.now(),
+                                BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.TEN),
+                                BigDecimal.valueOf(Math.PI))
+                        .collect(Collectors.toSet());
+
+        Object previousValue = null;
+        AkkaRpcSerializedValue previousSerializedValue = null;
+        for (Object value : values) {
+            AkkaRpcSerializedValue serializedValue = AkkaRpcSerializedValue.valueOf(value);
+            assertThat(value.toString(), serializedValue.getSerializedData(), notNullValue());
+            assertThat(value.toString(), serializedValue.getSerializedDataLength(), greaterThan(0));
+            assertThat(
+                    value.toString(),
+                    serializedValue.deserializeValue(getClass().getClassLoader()),
+                    equalTo(value));
+
+            AkkaRpcSerializedValue otherSerializedValue = AkkaRpcSerializedValue.valueOf(value);
+            assertThat(value.toString(), otherSerializedValue, equalTo(serializedValue));
+            assertThat(
+                    value.toString(),
+                    otherSerializedValue.hashCode(),
+                    equalTo(serializedValue.hashCode()));
+
+            AkkaRpcSerializedValue clonedSerializedValue = InstantiationUtil.clone(serializedValue);
+            assertThat(
+                    value.toString(),
+                    clonedSerializedValue.getSerializedData(),
+                    equalTo(serializedValue.getSerializedData()));
+            assertThat(
+                    value.toString(),
+                    clonedSerializedValue.deserializeValue(getClass().getClassLoader()),
+                    equalTo(value));
+            assertThat(value.toString(), clonedSerializedValue, equalTo(serializedValue));
+            assertThat(
+                    value.toString(),
+                    clonedSerializedValue.hashCode(),
+                    equalTo(serializedValue.hashCode()));
+
+            if (previousValue != null && !previousValue.equals(value)) {
+                assertThat(
+                        value.toString() + " " + previousValue.toString(),
+                        serializedValue,
+                        not(equalTo(previousSerializedValue)));
+            }
+
+            previousValue = value;
+            previousSerializedValue = serializedValue;
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/RemoteAkkaRpcActorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/RemoteAkkaRpcActorTest.java
index a9940d8..3cde7f1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/RemoteAkkaRpcActorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/rpc/akka/RemoteAkkaRpcActorTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.rpc.akka;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.rpc.RpcUtils;
+import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.AfterClass;
@@ -32,6 +33,7 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeoutException;
 
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.junit.Assert.assertThat;
@@ -97,4 +99,30 @@ public class RemoteAkkaRpcActorTest extends TestLogger {
             assertThat(value, is(nullValue()));
         }
     }
+
+    @Test
+    public void canRespondWithSerializedValueRemotely() throws Exception {
+        try (final AkkaRpcActorTest.SerializedValueRespondingEndpoint endpoint =
+                new AkkaRpcActorTest.SerializedValueRespondingEndpoint(rpcService)) {
+            endpoint.start();
+
+            final AkkaRpcActorTest.SerializedValueRespondingGateway remoteGateway =
+                    otherRpcService
+                            .connect(
+                                    endpoint.getAddress(),
+                                    AkkaRpcActorTest.SerializedValueRespondingGateway.class)
+                            .join();
+
+            assertThat(
+                    remoteGateway.getSerializedValueSynchronously(),
+                    equalTo(AkkaRpcActorTest.SerializedValueRespondingEndpoint.SERIALIZED_VALUE));
+
+            final CompletableFuture<SerializedValue<String>> responseFuture =
+                    remoteGateway.getSerializedValue();
+
+            assertThat(
+                    responseFuture.get(),
+                    equalTo(AkkaRpcActorTest.SerializedValueRespondingEndpoint.SERIALIZED_VALUE));
+        }
+    }
 }