You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by uc...@apache.org on 2016/08/09 14:47:42 UTC
[08/10] flink git commit: [FLINK-3779] [runtime] Add KvState network
client and server
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
new file mode 100644
index 0000000..0ae60f6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
@@ -0,0 +1,518 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty.message;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.ByteBufOutputStream;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.query.KvStateID;
+import org.apache.flink.runtime.query.netty.KvStateClient;
+import org.apache.flink.runtime.query.netty.KvStateServer;
+import org.apache.flink.runtime.util.DataInputDeserializer;
+import org.apache.flink.runtime.util.DataOutputSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Serialization and deserialization of messages exchanged between
+ * {@link KvStateClient} and {@link KvStateServer}.
+ *
+ * <p>The binary messages have the following format:
+ *
+ * <pre>
+ *                     <------ Frame ------------------------->
+ *                    +----------------------------------------+
+ *                    |        HEADER (8)      | PAYLOAD (VAR) |
+ * +------------------+----------------------------------------+
+ * | FRAME LENGTH (4) | VERSION (4) | TYPE (4) | CONTENT (VAR) |
+ * +------------------+----------------------------------------+
+ * </pre>
+ *
+ * <p>The frame length counts all bytes that follow the frame length field
+ * (header and payload), but not the 4 bytes of the field itself. The TYPE
+ * field holds the ordinal of the {@link KvStateRequestType}.
+ *
+ * <p>The concrete content of a message depends on the {@link KvStateRequestType}.
+ *
+ * <p>All {@code serialize*} methods return a newly allocated buffer; the
+ * caller takes ownership of it.
+ */
+public final class KvStateRequestSerializer {
+
+	/** The serialization version ID. */
+	private static final int VERSION = 0x79a1b710;
+
+	/** Byte length of the header (version and message type, 4 bytes each). */
+	private static final int HEADER_LENGTH = 8;
+
+	/** Byte length of the frame length field that precedes every frame. */
+	private static final int FRAME_LENGTH_FIELD_LENGTH = 4;
+
+	/** Magic byte written between the serialized key and the serialized namespace. */
+	private static final byte MAGIC_NUMBER = 42;
+
+	// ------------------------------------------------------------------------
+	// Serialization
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Allocates a buffer and serializes the KvState request into it.
+	 *
+	 * @param alloc                     ByteBuf allocator for the buffer to
+	 *                                  serialize message into
+	 * @param requestId                 ID for this request
+	 * @param kvStateId                 ID of the requested KvState instance
+	 * @param serializedKeyAndNamespace Serialized key and namespace to request
+	 *                                  from the KvState instance.
+	 * @return Serialized KvState request message
+	 * @throws NullPointerException If the KvState ID or the serialized key and
+	 *                              namespace are <code>null</code>
+	 */
+	public static ByteBuf serializeKvStateRequest(
+			ByteBufAllocator alloc,
+			long requestId,
+			KvStateID kvStateId,
+			byte[] serializedKeyAndNamespace) {
+
+		// Validate before allocating so that a bad argument cannot leak a buffer
+		Preconditions.checkNotNull(kvStateId, "KvStateID");
+		Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace");
+
+		// Header + request ID + KvState ID + length-prefixed key and namespace
+		int frameLength = HEADER_LENGTH + 8 + (8 + 8) + (4 + serializedKeyAndNamespace.length);
+		ByteBuf buf = alloc.ioBuffer(FRAME_LENGTH_FIELD_LENGTH + frameLength);
+
+		buf.writeInt(frameLength);
+		writeHeader(buf, KvStateRequestType.REQUEST);
+
+		buf.writeLong(requestId);
+		buf.writeLong(kvStateId.getLowerPart());
+		buf.writeLong(kvStateId.getUpperPart());
+		buf.writeInt(serializedKeyAndNamespace.length);
+		buf.writeBytes(serializedKeyAndNamespace);
+
+		return buf;
+	}
+
+	/**
+	 * Allocates a buffer and serializes the KvState request result into it.
+	 *
+	 * @param alloc            ByteBuf allocator for the buffer to serialize message into
+	 * @param requestId        ID for this request
+	 * @param serializedResult Serialized Result
+	 * @return Serialized KvState request result message
+	 * @throws NullPointerException If the serialized result is <code>null</code>
+	 */
+	public static ByteBuf serializeKvStateRequestResult(
+			ByteBufAllocator alloc,
+			long requestId,
+			byte[] serializedResult) {
+
+		Preconditions.checkNotNull(serializedResult, "Serialized result");
+
+		// Header + request ID + length-prefixed serialized result
+		int frameLength = HEADER_LENGTH + 8 + 4 + serializedResult.length;
+
+		// The frame length field itself is not part of the frame length,
+		// so reserve room for it explicitly.
+		ByteBuf buf = alloc.ioBuffer(FRAME_LENGTH_FIELD_LENGTH + frameLength);
+
+		buf.writeInt(frameLength);
+		writeHeader(buf, KvStateRequestType.REQUEST_RESULT);
+		buf.writeLong(requestId);
+
+		buf.writeInt(serializedResult.length);
+		buf.writeBytes(serializedResult);
+
+		return buf;
+	}
+
+	/**
+	 * Allocates a buffer and serializes the KvState request failure into it.
+	 *
+	 * <p>If serialization fails, the allocated buffer is released before the
+	 * exception is propagated.
+	 *
+	 * @param alloc     ByteBuf allocator for the buffer to serialize message into
+	 * @param requestId ID of the request responding to
+	 * @param cause     Failure cause
+	 * @return Serialized KvState request failure message
+	 * @throws IOException Serialization failures are forwarded
+	 */
+	public static ByteBuf serializeKvStateRequestFailure(
+			ByteBufAllocator alloc,
+			long requestId,
+			Throwable cause) throws IOException {
+
+		ByteBuf buf = alloc.ioBuffer();
+		boolean success = false;
+		try {
+			// Frame length placeholder, patched once the payload size is known
+			buf.writeInt(0);
+			writeHeader(buf, KvStateRequestType.REQUEST_FAILURE);
+
+			// Message
+			buf.writeLong(requestId);
+			writeSerializedObject(buf, cause);
+
+			setFrameLength(buf);
+
+			success = true;
+			return buf;
+		} finally {
+			if (!success) {
+				// Do not leak the (possibly pooled) buffer on failure
+				buf.release();
+			}
+		}
+	}
+
+	/**
+	 * Allocates a buffer and serializes the server failure into it.
+	 *
+	 * <p>The cause must not be or contain any user types as causes.
+	 *
+	 * <p>If serialization fails, the allocated buffer is released before the
+	 * exception is propagated.
+	 *
+	 * @param alloc ByteBuf allocator for the buffer to serialize message into
+	 * @param cause Failure cause
+	 * @return Serialized server failure message
+	 * @throws IOException Serialization failures are forwarded
+	 */
+	public static ByteBuf serializeServerFailure(ByteBufAllocator alloc, Throwable cause) throws IOException {
+		ByteBuf buf = alloc.ioBuffer();
+		boolean success = false;
+		try {
+			// Frame length placeholder, patched once the payload size is known
+			buf.writeInt(0);
+			writeHeader(buf, KvStateRequestType.SERVER_FAILURE);
+
+			// A server failure is not tied to a request, so there is no request ID
+			writeSerializedObject(buf, cause);
+
+			setFrameLength(buf);
+
+			success = true;
+			return buf;
+		} finally {
+			if (!success) {
+				// Do not leak the (possibly pooled) buffer on failure
+				buf.release();
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	// Deserialization
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Deserializes the header and returns the request type.
+	 *
+	 * @param buf Buffer to deserialize (expected to be at header position)
+	 * @return Deserialized request type
+	 * @throws IllegalArgumentException If unexpected message version or message type
+	 */
+	public static KvStateRequestType deserializeHeader(ByteBuf buf) {
+		// Check the version
+		int version = buf.readInt();
+		if (version != VERSION) {
+			throw new IllegalArgumentException("Illegal message version " + version +
+					". Expected: " + VERSION + ".");
+		}
+
+		// Get the message type (encoded as the ordinal of KvStateRequestType).
+		// Valid indices are [0, values.length).
+		int msgType = buf.readInt();
+		KvStateRequestType[] values = KvStateRequestType.values();
+		if (msgType >= 0 && msgType < values.length) {
+			return values[msgType];
+		} else {
+			throw new IllegalArgumentException("Illegal message type with index " + msgType);
+		}
+	}
+
+	/**
+	 * Deserializes the KvState request message.
+	 *
+	 * <p>The serialized key and namespace are copied out of the buffer, so the
+	 * returned request does not reference {@code buf} and the buffer can be
+	 * released independently of it.
+	 *
+	 * @param buf Buffer to deserialize (expected to be positioned after header)
+	 * @return Deserialized KvStateRequest
+	 * @throws IllegalArgumentException If the serialized key and namespace
+	 *                                  length is negative
+	 */
+	public static KvStateRequest deserializeKvStateRequest(ByteBuf buf) {
+		long requestId = buf.readLong();
+		KvStateID kvStateId = new KvStateID(buf.readLong(), buf.readLong());
+
+		// Serialized key and namespace
+		int length = buf.readInt();
+
+		if (length < 0) {
+			throw new IllegalArgumentException("Negative length for serialized key and namespace. " +
+					"This indicates a serialization error.");
+		}
+
+		// Copy the bytes in order to be able to safely recycle the ByteBuf
+		byte[] serializedKeyAndNamespace = new byte[length];
+		if (length > 0) {
+			buf.readBytes(serializedKeyAndNamespace);
+		}
+
+		return new KvStateRequest(requestId, kvStateId, serializedKeyAndNamespace);
+	}
+
+	/**
+	 * Deserializes the KvState request result.
+	 *
+	 * <p>The serialized result is copied out of the buffer.
+	 *
+	 * @param buf Buffer to deserialize (expected to be positioned after header)
+	 * @return Deserialized KvStateRequestResult
+	 * @throws IllegalArgumentException If the serialized result length is negative
+	 */
+	public static KvStateRequestResult deserializeKvStateRequestResult(ByteBuf buf) {
+		long requestId = buf.readLong();
+
+		// Serialized KvState
+		int length = buf.readInt();
+
+		if (length < 0) {
+			throw new IllegalArgumentException("Negative length for serialized result. " +
+					"This indicates a serialization error.");
+		}
+
+		byte[] serializedValue = new byte[length];
+		if (length > 0) {
+			buf.readBytes(serializedValue);
+		}
+
+		return new KvStateRequestResult(requestId, serializedValue);
+	}
+
+	/**
+	 * Deserializes the KvState request failure.
+	 *
+	 * @param buf Buffer to deserialize (expected to be positioned after header)
+	 * @return Deserialized KvStateRequestFailure
+	 * @throws IOException            Serialization failures are forwarded
+	 * @throws ClassNotFoundException If the failure cause type can not be loaded
+	 */
+	public static KvStateRequestFailure deserializeKvStateRequestFailure(ByteBuf buf) throws IOException, ClassNotFoundException {
+		long requestId = buf.readLong();
+		Throwable cause = (Throwable) readSerializedObject(buf);
+		return new KvStateRequestFailure(requestId, cause);
+	}
+
+	/**
+	 * Deserializes the server failure.
+	 *
+	 * @param buf Buffer to deserialize (expected to be positioned after header)
+	 * @return Deserialized server failure cause
+	 * @throws IOException            Serialization failures are forwarded
+	 * @throws ClassNotFoundException If the failure cause type can not be loaded
+	 */
+	public static Throwable deserializeServerFailure(ByteBuf buf) throws IOException, ClassNotFoundException {
+		return (Throwable) readSerializedObject(buf);
+	}
+
+	// ------------------------------------------------------------------------
+	// Generic serialization utils
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Serializes the key and namespace into a byte array.
+	 *
+	 * <p>The serialized format matches the RocksDB state backend key format, i.e.
+	 * the key and namespace don't have to be deserialized for RocksDB lookups.
+	 * The format is: serialized key, one magic byte (42), serialized namespace.
+	 *
+	 * @param key                 Key to serialize
+	 * @param keySerializer       Serializer for the key
+	 * @param namespace           Namespace to serialize
+	 * @param namespaceSerializer Serializer for the namespace
+	 * @param <K>                 Key type
+	 * @param <N>                 Namespace type
+	 * @return Byte array holding the serialized key and namespace
+	 * @throws IOException Serialization errors are forwarded
+	 */
+	public static <K, N> byte[] serializeKeyAndNamespace(
+			K key,
+			TypeSerializer<K> keySerializer,
+			N namespace,
+			TypeSerializer<N> namespaceSerializer) throws IOException {
+
+		DataOutputSerializer dos = new DataOutputSerializer(32);
+
+		keySerializer.serialize(key, dos);
+		dos.writeByte(MAGIC_NUMBER);
+		namespaceSerializer.serialize(namespace, dos);
+
+		return dos.getCopyOfBuffer();
+	}
+
+	/**
+	 * Deserializes the key and namespace into a {@link Tuple2}.
+	 *
+	 * @param serializedKeyAndNamespace Serialized key and namespace
+	 * @param keySerializer             Serializer for the key
+	 * @param namespaceSerializer       Serializer for the namespace
+	 * @param <K>                       Key type
+	 * @param <N>                       Namespace
+	 * @return Tuple2 holding deserialized key and namespace
+	 * @throws IOException              Serialization errors are forwarded
+	 * @throws IllegalArgumentException If unexpected magic number between key
+	 *                                  and namespace, or if bytes remain after
+	 *                                  the namespace
+	 */
+	public static <K, N> Tuple2<K, N> deserializeKeyAndNamespace(
+			byte[] serializedKeyAndNamespace,
+			TypeSerializer<K> keySerializer,
+			TypeSerializer<N> namespaceSerializer) throws IOException {
+
+		DataInputDeserializer dis = new DataInputDeserializer(
+				serializedKeyAndNamespace,
+				0,
+				serializedKeyAndNamespace.length);
+
+		K key = keySerializer.deserialize(dis);
+		byte magicNumber = dis.readByte();
+		if (magicNumber != MAGIC_NUMBER) {
+			throw new IllegalArgumentException("Unexpected magic number " + magicNumber +
+					". This indicates a mismatch in the key serializers used by the " +
+					"KvState instance and this access.");
+		}
+		N namespace = namespaceSerializer.deserialize(dis);
+
+		if (dis.available() > 0) {
+			throw new IllegalArgumentException("Unconsumed bytes in the serialized key " +
+					"and namespace. This indicates a mismatch in the key/namespace " +
+					"serializers used by the KvState instance and this access.");
+		}
+
+		return new Tuple2<>(key, namespace);
+	}
+
+	/**
+	 * Serializes the value with the given serializer.
+	 *
+	 * @param value      Value of type T to serialize
+	 * @param serializer Serializer for T
+	 * @param <T>        Type of the value
+	 * @return Serialized value or <code>null</code> if value <code>null</code>
+	 * @throws IOException On failure during serialization
+	 */
+	public static <T> byte[] serializeValue(T value, TypeSerializer<T> serializer) throws IOException {
+		if (value == null) {
+			return null;
+		}
+
+		DataOutputSerializer dos = new DataOutputSerializer(32);
+		serializer.serialize(value, dos);
+		return dos.getCopyOfBuffer();
+	}
+
+	/**
+	 * Deserializes the value with the given serializer.
+	 *
+	 * @param serializedValue Serialized value of type T
+	 * @param serializer      Serializer for T
+	 * @param <T>             Type of the value
+	 * @return Deserialized value or <code>null</code> if the serialized value
+	 * is <code>null</code>
+	 * @throws IOException On failure during deserialization
+	 */
+	public static <T> T deserializeValue(byte[] serializedValue, TypeSerializer<T> serializer) throws IOException {
+		if (serializedValue == null) {
+			return null;
+		}
+
+		DataInputDeserializer deser = new DataInputDeserializer(serializedValue, 0, serializedValue.length);
+		return serializer.deserialize(deser);
+	}
+
+	/**
+	 * Serializes all values of the Iterable with the given serializer.
+	 *
+	 * <p>Each serialized element is followed by a single separator byte (0).
+	 *
+	 * @param values     Values of type T to serialize
+	 * @param serializer Serializer for T
+	 * @param <T>        Type of the values
+	 * @return Serialized values or <code>null</code> if values <code>null</code> or empty
+	 * @throws IOException On failure during serialization
+	 */
+	public static <T> byte[] serializeList(Iterable<T> values, TypeSerializer<T> serializer) throws IOException {
+		if (values == null) {
+			return null;
+		}
+
+		Iterator<T> it = values.iterator();
+		if (!it.hasNext()) {
+			return null;
+		}
+
+		DataOutputSerializer dos = new DataOutputSerializer(32);
+		while (it.hasNext()) {
+			serializer.serialize(it.next(), dos);
+
+			// This byte added here in order to have the binary format
+			// prescribed by RocksDB.
+			dos.write(0);
+		}
+
+		return dos.getCopyOfBuffer();
+	}
+
+	/**
+	 * Deserializes all values with the given serializer.
+	 *
+	 * @param serializedValue Serialized value of type {@code List<T>}
+	 * @param serializer      Serializer for T
+	 * @param <T>             Type of the value
+	 * @return Deserialized list or <code>null</code> if the serialized value
+	 * is <code>null</code>
+	 * @throws IOException On failure during deserialization
+	 */
+	public static <T> List<T> deserializeList(byte[] serializedValue, TypeSerializer<T> serializer) throws IOException {
+		if (serializedValue == null) {
+			return null;
+		}
+
+		DataInputDeserializer in = new DataInputDeserializer(serializedValue, 0, serializedValue.length);
+
+		List<T> result = new ArrayList<>();
+		while (in.available() > 0) {
+			result.add(serializer.deserialize(in));
+
+			// The expected binary format has a single byte separator. We
+			// want a consistent binary format in order to not need any
+			// special casing during deserialization. A "cleaner" format
+			// would skip this extra byte, but would require a memory copy
+			// for RocksDB, which stores the data serialized in this way
+			// for lists.
+			if (in.available() > 0) {
+				in.readByte();
+			}
+		}
+
+		return result;
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Helper for writing the header.
+	 *
+	 * <p>The message type is written as the ordinal of the request type, which
+	 * {@link #deserializeHeader(ByteBuf)} maps back via {@code values()}.
+	 *
+	 * @param buf         Buffer to serialize header into
+	 * @param requestType Message type to serialize
+	 */
+	private static void writeHeader(ByteBuf buf, KvStateRequestType requestType) {
+		buf.writeInt(VERSION);
+		buf.writeInt(requestType.ordinal());
+	}
+
+	/**
+	 * Writes the frame length into the placeholder at index 0 of a buffer whose
+	 * content starts with the frame length field.
+	 *
+	 * @param buf Buffer holding a complete frame, prefixed by the frame length field
+	 */
+	private static void setFrameLength(ByteBuf buf) {
+		buf.setInt(0, buf.readableBytes() - FRAME_LENGTH_FIELD_LENGTH);
+	}
+
+	/**
+	 * Appends the Java-serialized form of the given object to the buffer.
+	 *
+	 * @param buf Buffer to serialize into
+	 * @param obj Object to serialize (may be <code>null</code>)
+	 * @throws IOException Serialization failures are forwarded
+	 */
+	private static void writeSerializedObject(ByteBuf buf, Object obj) throws IOException {
+		try (ByteBufOutputStream bbos = new ByteBufOutputStream(buf);
+				ObjectOutputStream out = new ObjectOutputStream(bbos)) {
+
+			out.writeObject(obj);
+		}
+	}
+
+	/**
+	 * Reads a Java-serialized object from the current position of the buffer.
+	 *
+	 * <p>NOTE(review): this applies Java native deserialization to bytes
+	 * received from a remote peer. If peers cannot be fully trusted, consider
+	 * restricting the resolvable classes (e.g. via a class whitelist).
+	 *
+	 * @param buf Buffer to deserialize from
+	 * @return The deserialized object
+	 * @throws IOException            Serialization failures are forwarded
+	 * @throws ClassNotFoundException If the serialized type can not be loaded
+	 */
+	private static Object readSerializedObject(ByteBuf buf) throws IOException, ClassNotFoundException {
+		try (ByteBufInputStream bbis = new ByteBufInputStream(buf);
+				ObjectInputStream in = new ObjectInputStream(bbis)) {
+
+			return in.readObject();
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java
new file mode 100644
index 0000000..de7270a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestType.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty.message;
+
+import org.apache.flink.runtime.query.netty.KvStateServer;
+
+/**
+ * Expected message types when communicating with the {@link KvStateServer}.
+ *
+ * <p><strong>Important</strong>: {@link KvStateRequestSerializer} writes the
+ * {@code ordinal()} of a constant as the TYPE field of the message header and
+ * resolves received types by indexing into {@code values()}. The declaration
+ * order of the constants is therefore part of the binary format: reordering
+ * or removing constants breaks compatibility with peers using the old order.
+ */
+public enum KvStateRequestType {
+
+ /** Request a KvState instance (carries request ID, KvState ID, and serialized key and namespace). */
+ REQUEST,
+
+ /** Successful response to a KvStateRequest, carrying the serialized result. */
+ REQUEST_RESULT,
+
+ /** Failure response to a KvStateRequest, carrying the serialized failure cause. */
+ REQUEST_FAILURE,
+
+ /**
+  * Generic server failure that is not tied to a specific request (the
+  * message carries no request ID, only the serialized failure cause).
+  */
+ SERVER_FAILURE
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java
new file mode 100644
index 0000000..7e8de40
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/package-info.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This package contains all Netty-based client/server classes used to query
+ * KvState instances.
+ *
+ * <h2>Server and Client</h2>
+ *
+ * <p>Both server and client expect received binary messages to contain a frame
+ * length field. Netty's {@link io.netty.handler.codec.LengthFieldBasedFrameDecoder}
+ * is used to fully receive the frame before giving it to the respective client
+ * or server handler.
+ *
+ * <p>Connection establishment and release happens by the client. The server
+ * only closes a connection if a fatal failure happens that cannot be resolved
+ * otherwise.
+ *
+ * <p>There is a single server per task manager, and a single client can be
+ * shared by multiple threads.
+ *
+ * <p>See also:
+ * <ul>
+ * <li>{@link org.apache.flink.runtime.query.netty.KvStateServer}</li>
+ * <li>{@link org.apache.flink.runtime.query.netty.KvStateServerHandler}</li>
+ * <li>{@link org.apache.flink.runtime.query.netty.KvStateClient}</li>
+ * <li>{@link org.apache.flink.runtime.query.netty.KvStateClientHandler}</li>
+ * </ul>
+ *
+ * <h2>Serialization</h2>
+ *
+ * <p>The exchanged binary messages have the following format:
+ *
+ * <pre>
+ *                     <------ Frame ------------------------->
+ *                    +----------------------------------------+
+ *                    |        HEADER (8)      | PAYLOAD (VAR) |
+ * +------------------+----------------------------------------+
+ * | FRAME LENGTH (4) | VERSION (4) | TYPE (4) | CONTENT (VAR) |
+ * +------------------+----------------------------------------+
+ * </pre>
+ *
+ * <p>For frame decoding, both server and client use Netty's {@link
+ * io.netty.handler.codec.LengthFieldBasedFrameDecoder}. Message serialization
+ * is done via static helpers in {@link org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer}.
+ * The serialization helpers return {@link io.netty.buffer.ByteBuf} instances,
+ * which are ready to be sent to the client or server respectively as they
+ * contain the frame length.
+ *
+ * <p>See also:
+ * <ul>
+ * <li>{@link org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer}</li>
+ * </ul>
+ *
+ * <h2>Statistics</h2>
+ *
+ * <p>Both server and client keep track of request statistics via {@link
+ * org.apache.flink.runtime.query.netty.KvStateRequestStats}.
+ *
+ * <p>See also:
+ * <ul>
+ * <li>{@link org.apache.flink.runtime.query.netty.KvStateRequestStats}</li>
+ * </ul>
+ */
+package org.apache.flink.runtime.query.netty;
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
index bdccdd1..9822a83 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/DataInputDeserializer.java
@@ -95,6 +95,14 @@ public class DataInputDeserializer implements DataInputView, java.io.Serializabl
// ----------------------------------------------------------------------------------------
// Data Input
// ----------------------------------------------------------------------------------------
+
+	/**
+	 * Returns the number of bytes that can still be read, i.e. the distance
+	 * between the current read position and the end of the readable range.
+	 *
+	 * <p>The previous implementation returned {@code end - position - 1},
+	 * under-reporting by one byte (e.g. a single remaining byte was reported
+	 * as 0 available bytes).
+	 *
+	 * @return Number of remaining readable bytes, or 0 if the end has been reached
+	 */
+	public int available() {
+		if (position < end) {
+			return end - position;
+		} else {
+			return 0;
+		}
+	}
@Override
public boolean readBoolean() throws IOException {
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java
new file mode 100644
index 0000000..31a9620
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientHandlerTest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.junit.Test;
+
+import java.nio.channels.ClosedChannelException;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link KvStateClientHandler}.
+ */
+public class KvStateClientHandlerTest {
+
+	/**
+	 * Tests that on reads the expected callback methods are called and read
+	 * buffers are recycled.
+	 *
+	 * <p>The {@code onFailure} invocation counts verified below accumulate
+	 * across the sections of this test, because the same mocked callback is
+	 * used throughout.
+	 */
+	@Test
+	public void testReadCallbacksAndBufferRecycling() throws Exception {
+		KvStateClientHandlerCallback callback = mock(KvStateClientHandlerCallback.class);
+
+		EmbeddedChannel channel = new EmbeddedChannel(new KvStateClientHandler(callback));
+
+		//
+		// Request success
+		//
+		ByteBuf buf = KvStateRequestSerializer.serializeKvStateRequestResult(
+				channel.alloc(),
+				1222112277,
+				new byte[0]);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onRequestResult(eq(1222112277L), any(byte[].class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Request failure
+		//
+		buf = KvStateRequestSerializer.serializeKvStateRequestFailure(
+				channel.alloc(),
+				1222112278,
+				new RuntimeException("Expected test Exception"));
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onRequestFailure(eq(1222112278L), any(RuntimeException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Server failure
+		//
+		buf = KvStateRequestSerializer.serializeServerFailure(
+				channel.alloc(),
+				new RuntimeException("Expected test Exception"));
+		buf.skipBytes(4); // skip frame length
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(1)).onFailure(any(RuntimeException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Unexpected messages
+		//
+		// A lone int whose value is not the serialization version: the version
+		// check in KvStateRequestSerializer#deserializeHeader rejects it with an
+		// IllegalArgumentException.
+		buf = channel.alloc().buffer(4).writeInt(1223823);
+
+		// Verify callback
+		channel.writeInbound(buf);
+		verify(callback, times(2)).onFailure(any(IllegalArgumentException.class));
+		assertEquals("Buffer not recycled", 0, buf.refCnt());
+
+		//
+		// Exception caught
+		//
+		channel.pipeline().fireExceptionCaught(new RuntimeException("Expected test Exception"));
+		verify(callback, times(3)).onFailure(any(RuntimeException.class));
+
+		//
+		// Channel inactive
+		//
+		channel.pipeline().fireChannelInactive();
+		verify(callback, times(4)).onFailure(any(ClosedChannelException.class));
+	}
+
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
new file mode 100644
index 0000000..72d9f61
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateClientTest.java
@@ -0,0 +1,718 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty;
+
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.query.KvStateID;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.query.KvStateServerAddress;
+import org.apache.flink.runtime.query.netty.message.KvStateRequest;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.runtime.state.memory.MemValueState;
+import org.apache.flink.util.NetUtils;
+import org.junit.AfterClass;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+
+import java.net.ConnectException;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+import java.nio.channels.ClosedChannelException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class KvStateClientTest {
+
+ private static final Logger LOG = LoggerFactory.getLogger(KvStateClientTest.class);
+
+ // Thread pool for client bootstrap (shared between tests)
+ private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();
+
+ private final static FiniteDuration TEST_TIMEOUT = new FiniteDuration(100, TimeUnit.SECONDS);
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ if (NIO_GROUP != null) {
+ NIO_GROUP.shutdownGracefully();
+ }
+ }
+
+ /**
+ * Tests simple queries, of which half succeed and half fail.
+ */
+ @Test
+ public void testSimpleRequests() throws Exception {
+ Deadline deadline = TEST_TIMEOUT.fromNow();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ KvStateClient client = null;
+ Channel serverChannel = null;
+
+ try {
+ client = new KvStateClient(1, stats);
+
+ // Random result
+ final byte[] expected = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(expected);
+
+ final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
+ final AtomicReference<Channel> channel = new AtomicReference<>();
+
+ serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ channel.set(ctx.channel());
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ received.add((ByteBuf) msg);
+ }
+ });
+
+ KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+ List<Future<byte[]>> futures = new ArrayList<>();
+
+ int numQueries = 1024;
+
+ for (int i = 0; i < numQueries; i++) {
+ futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
+ }
+
+ // Respond to messages
+ Exception testException = new RuntimeException("Expected test Exception");
+
+ for (int i = 0; i < numQueries; i++) {
+ ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+ assertNotNull("Receive timed out", buf);
+
+ Channel ch = channel.get();
+ assertNotNull("Channel not active", ch);
+
+ assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf));
+ KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf);
+
+ buf.release();
+
+ if (i % 2 == 0) {
+ ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult(
+ serverChannel.alloc(),
+ request.getRequestId(),
+ expected);
+
+ ch.writeAndFlush(response);
+ } else {
+ ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestFailure(
+ serverChannel.alloc(),
+ request.getRequestId(),
+ testException);
+
+ ch.writeAndFlush(response);
+ }
+ }
+
+ for (int i = 0; i < numQueries; i++) {
+ if (i % 2 == 0) {
+ byte[] serializedResult = Await.result(futures.get(i), deadline.timeLeft());
+ assertArrayEquals(expected, serializedResult);
+ } else {
+ try {
+ Await.result(futures.get(i), deadline.timeLeft());
+ fail("Did not throw expected Exception");
+ } catch (RuntimeException ignored) {
+ // Expected
+ }
+ }
+ }
+
+ assertEquals(numQueries, stats.getNumRequests());
+ int expectedRequests = numQueries / 2;
+
+ // Counts can take some time to propagate
+ while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != expectedRequests ||
+ stats.getNumFailed() != expectedRequests)) {
+ Thread.sleep(100);
+ }
+
+ assertEquals(expectedRequests, stats.getNumSuccessful());
+ assertEquals(expectedRequests, stats.getNumFailed());
+ } finally {
+ if (client != null) {
+ client.shutDown();
+ }
+
+ if (serverChannel != null) {
+ serverChannel.close();
+ }
+
+ assertEquals("Channel leak", 0, stats.getNumConnections());
+ }
+ }
+
+ /**
+ * Tests that a request to an unavailable host is failed with ConnectException.
+ */
+ @Test
+ public void testRequestUnavailableHost() throws Exception {
+ Deadline deadline = TEST_TIMEOUT.fromNow();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+ KvStateClient client = null;
+
+ try {
+ client = new KvStateClient(1, stats);
+
+ int availablePort = NetUtils.getAvailablePort();
+
+ KvStateServerAddress serverAddress = new KvStateServerAddress(
+ InetAddress.getLocalHost(),
+ availablePort);
+
+ Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);
+
+ try {
+ Await.result(future, deadline.timeLeft());
+ fail("Did not throw expected ConnectException");
+ } catch (ConnectException ignored) {
+ // Expected
+ }
+ } finally {
+ if (client != null) {
+ client.shutDown();
+ }
+
+ assertEquals("Channel leak", 0, stats.getNumConnections());
+ }
+ }
+
+ /**
+ * Multiple threads concurrently fire queries.
+ */
+ @Test
+ public void testConcurrentQueries() throws Exception {
+ Deadline deadline = TEST_TIMEOUT.fromNow();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ ExecutorService executor = null;
+ KvStateClient client = null;
+ Channel serverChannel = null;
+
+ final byte[] serializedResult = new byte[1024];
+ ThreadLocalRandom.current().nextBytes(serializedResult);
+
+ try {
+ int numQueryTasks = 4;
+ final int numQueriesPerTask = 1024;
+
+ executor = Executors.newFixedThreadPool(numQueryTasks);
+
+ client = new KvStateClient(1, stats);
+
+ serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ ByteBuf buf = (ByteBuf) msg;
+ assertEquals(KvStateRequestType.REQUEST, KvStateRequestSerializer.deserializeHeader(buf));
+ KvStateRequest request = KvStateRequestSerializer.deserializeKvStateRequest(buf);
+
+ buf.release();
+
+ ByteBuf response = KvStateRequestSerializer.serializeKvStateRequestResult(
+ ctx.alloc(),
+ request.getRequestId(),
+ serializedResult);
+
+ ctx.channel().writeAndFlush(response);
+ }
+ });
+
+ final KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+ final KvStateClient finalClient = client;
+ Callable<List<Future<byte[]>>> queryTask = new Callable<List<Future<byte[]>>>() {
+ @Override
+ public List<Future<byte[]>> call() throws Exception {
+ List<Future<byte[]>> results = new ArrayList<>(numQueriesPerTask);
+
+ for (int i = 0; i < numQueriesPerTask; i++) {
+ results.add(finalClient.getKvState(
+ serverAddress,
+ new KvStateID(),
+ new byte[0]));
+ }
+
+ return results;
+ }
+ };
+
+ // Submit query tasks
+ List<java.util.concurrent.Future<List<Future<byte[]>>>> futures = new ArrayList<>();
+ for (int i = 0; i < numQueryTasks; i++) {
+ futures.add(executor.submit(queryTask));
+ }
+
+ // Verify results
+ for (java.util.concurrent.Future<List<Future<byte[]>>> future : futures) {
+ List<Future<byte[]>> results = future.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+ for (Future<byte[]> result : results) {
+ byte[] actual = Await.result(result, deadline.timeLeft());
+ assertArrayEquals(serializedResult, actual);
+ }
+ }
+
+ int totalQueries = numQueryTasks * numQueriesPerTask;
+
+ // Counts can take some time to propagate
+ while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != totalQueries ||
+ stats.getNumFailed() != totalQueries)) {
+ Thread.sleep(100);
+ }
+
+ assertEquals(totalQueries, stats.getNumRequests());
+ assertEquals(totalQueries, stats.getNumSuccessful());
+ } finally {
+ if (executor != null) {
+ executor.shutdown();
+ }
+
+ if (serverChannel != null) {
+ serverChannel.close();
+ }
+
+ if (client != null) {
+ client.shutDown();
+ }
+
+ assertEquals("Channel leak", 0, stats.getNumConnections());
+ }
+ }
+
+ /**
+ * Tests that a server failure closes the connection and removes it from
+ * the established connections.
+ */
+ @Test
+ public void testFailureClosesChannel() throws Exception {
+ Deadline deadline = TEST_TIMEOUT.fromNow();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ KvStateClient client = null;
+ Channel serverChannel = null;
+
+ try {
+ client = new KvStateClient(1, stats);
+
+ final LinkedBlockingQueue<ByteBuf> received = new LinkedBlockingQueue<>();
+ final AtomicReference<Channel> channel = new AtomicReference<>();
+
+ serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ channel.set(ctx.channel());
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ received.add((ByteBuf) msg);
+ }
+ });
+
+ KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+ // Requests
+ List<Future<byte[]>> futures = new ArrayList<>();
+ futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
+ futures.add(client.getKvState(serverAddress, new KvStateID(), new byte[0]));
+
+ ByteBuf buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+ assertNotNull("Receive timed out", buf);
+ buf.release();
+
+ buf = received.poll(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+ assertNotNull("Receive timed out", buf);
+ buf.release();
+
+ assertEquals(1, stats.getNumConnections());
+
+ Channel ch = channel.get();
+ assertNotNull("Channel not active", ch);
+
+ // Respond with failure
+ ch.writeAndFlush(KvStateRequestSerializer.serializeServerFailure(
+ serverChannel.alloc(),
+ new RuntimeException("Expected test server failure")));
+
+ try {
+ Await.result(futures.remove(0), deadline.timeLeft());
+ fail("Did not throw expected server failure");
+ } catch (RuntimeException ignored) {
+ // Expected
+ }
+
+ try {
+ Await.result(futures.remove(0), deadline.timeLeft());
+ fail("Did not throw expected server failure");
+ } catch (RuntimeException ignored) {
+ // Expected
+ }
+
+ assertEquals(0, stats.getNumConnections());
+
+ // Counts can take some time to propagate
+ while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 ||
+ stats.getNumFailed() != 2)) {
+ Thread.sleep(100);
+ }
+
+ assertEquals(2, stats.getNumRequests());
+ assertEquals(0, stats.getNumSuccessful());
+ assertEquals(2, stats.getNumFailed());
+ } finally {
+ if (client != null) {
+ client.shutDown();
+ }
+
+ if (serverChannel != null) {
+ serverChannel.close();
+ }
+
+ assertEquals("Channel leak", 0, stats.getNumConnections());
+ }
+ }
+
+ /**
+ * Tests that a server channel close, closes the connection and removes it
+ * from the established connections.
+ */
+ @Test
+ public void testServerClosesChannel() throws Exception {
+ Deadline deadline = TEST_TIMEOUT.fromNow();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ KvStateClient client = null;
+ Channel serverChannel = null;
+
+ try {
+ client = new KvStateClient(1, stats);
+
+ final AtomicBoolean received = new AtomicBoolean();
+ final AtomicReference<Channel> channel = new AtomicReference<>();
+
+ serverChannel = createServerChannel(new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelActive(ChannelHandlerContext ctx) throws Exception {
+ channel.set(ctx.channel());
+ }
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ received.set(true);
+ }
+ });
+
+ KvStateServerAddress serverAddress = getKvStateServerAddress(serverChannel);
+
+ // Requests
+ Future<byte[]> future = client.getKvState(serverAddress, new KvStateID(), new byte[0]);
+
+ while (!received.get() && deadline.hasTimeLeft()) {
+ Thread.sleep(50);
+ }
+ assertTrue("Receive timed out", received.get());
+
+ assertEquals(1, stats.getNumConnections());
+
+ channel.get().close().await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+ try {
+ Await.result(future, deadline.timeLeft());
+ fail("Did not throw expected server failure");
+ } catch (ClosedChannelException ignored) {
+ // Expected
+ }
+
+ assertEquals(0, stats.getNumConnections());
+
+ // Counts can take some time to propagate
+ while (deadline.hasTimeLeft() && (stats.getNumSuccessful() != 0 ||
+ stats.getNumFailed() != 1)) {
+ Thread.sleep(100);
+ }
+
+ assertEquals(1, stats.getNumRequests());
+ assertEquals(0, stats.getNumSuccessful());
+ assertEquals(1, stats.getNumFailed());
+ } finally {
+ if (client != null) {
+ client.shutDown();
+ }
+
+ if (serverChannel != null) {
+ serverChannel.close();
+ }
+
+ assertEquals("Channel leak", 0, stats.getNumConnections());
+ }
+ }
+
+ /**
+ * Tests multiple clients querying multiple servers until 100k queries have
+ * been processed. At this point, the client is shut down and its verified
+ * that all ongoing requests are failed.
+ */
+ @Test
+ public void testClientServerIntegration() throws Exception {
+ // Config
+ final int numServers = 2;
+ final int numServerEventLoopThreads = 2;
+ final int numServerQueryThreads = 2;
+
+ final int numClientEventLoopThreads = 4;
+ final int numClientsTasks = 8;
+
+ final int batchSize = 16;
+
+ final FiniteDuration timeout = new FiniteDuration(10, TimeUnit.SECONDS);
+
+ AtomicKvStateRequestStats clientStats = new AtomicKvStateRequestStats();
+
+ KvStateClient client = null;
+ ExecutorService clientTaskExecutor = null;
+ final KvStateServer[] server = new KvStateServer[numServers];
+
+ try {
+ client = new KvStateClient(numClientEventLoopThreads, clientStats);
+ clientTaskExecutor = Executors.newFixedThreadPool(numClientsTasks);
+
+ // Create state
+ ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+ desc.setQueryable("any");
+
+ MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
+ IntSerializer.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE,
+ desc);
+
+ // Create servers
+ KvStateRegistry[] registry = new KvStateRegistry[numServers];
+ AtomicKvStateRequestStats[] serverStats = new AtomicKvStateRequestStats[numServers];
+ final KvStateID[] ids = new KvStateID[numServers];
+
+ for (int i = 0; i < numServers; i++) {
+ registry[i] = new KvStateRegistry();
+ serverStats[i] = new AtomicKvStateRequestStats();
+ server[i] = new KvStateServer(
+ InetAddress.getLocalHost(),
+ 0,
+ numServerEventLoopThreads,
+ numServerQueryThreads,
+ registry[i],
+ serverStats[i]);
+
+ server[i].start();
+
+ // Value per server
+ kvState.setCurrentKey(1010 + i);
+ kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
+ kvState.update(201 + i);
+
+ // Register KvState (one state instance for all server)
+ ids[i] = registry[i].registerKvState(new JobID(), new JobVertexID(), 0, "any", kvState);
+ }
+
+ final KvStateClient finalClient = client;
+ Callable<Void> queryTask = new Callable<Void>() {
+ @Override
+ public Void call() throws Exception {
+ while (true) {
+ if (Thread.interrupted()) {
+ throw new InterruptedException();
+ }
+
+ // Random server permutation
+ List<Integer> random = new ArrayList<>();
+ for (int j = 0; j < batchSize; j++) {
+ random.add(j);
+ }
+ Collections.shuffle(random);
+
+ // Dispatch queries
+ List<Future<byte[]>> futures = new ArrayList<>(batchSize);
+
+ for (int j = 0; j < batchSize; j++) {
+ int targetServer = random.get(j) % numServers;
+
+ byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ 1010 + targetServer,
+ IntSerializer.INSTANCE,
+ VoidNamespace.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE);
+
+ futures.add(finalClient.getKvState(
+ server[targetServer].getAddress(),
+ ids[targetServer],
+ serializedKeyAndNamespace));
+ }
+
+ // Verify results
+ for (int j = 0; j < batchSize; j++) {
+ int targetServer = random.get(j) % numServers;
+
+ Future<byte[]> future = futures.get(j);
+ byte[] buf = Await.result(future, timeout);
+ int value = KvStateRequestSerializer.deserializeValue(buf, IntSerializer.INSTANCE);
+ assertEquals(201 + targetServer, value);
+ }
+ }
+ }
+ };
+
+ // Submit tasks
+ List<java.util.concurrent.Future<Void>> taskFutures = new ArrayList<>();
+ for (int i = 0; i < numClientsTasks; i++) {
+ taskFutures.add(clientTaskExecutor.submit(queryTask));
+ }
+
+ long numRequests;
+ while ((numRequests = clientStats.getNumRequests()) < 100_000) {
+ Thread.sleep(100);
+ LOG.info("Number of requests {}/100_000", numRequests);
+ }
+
+ // Shut down
+ client.shutDown();
+
+ for (java.util.concurrent.Future<Void> future : taskFutures) {
+ try {
+ future.get();
+ fail("Did not throw expected Exception after shut down");
+ } catch (ExecutionException t) {
+ if (t.getCause() instanceof ClosedChannelException ||
+ t.getCause() instanceof IllegalStateException) {
+ // Expected
+ } else {
+ t.printStackTrace();
+ fail("Failed with unexpected Exception type: " + t.getClass().getName());
+ }
+ }
+ }
+
+ assertEquals("Connection leak (client)", 0, clientStats.getNumConnections());
+ for (int i = 0; i < numServers; i++) {
+ boolean success = false;
+ int numRetries = 0;
+ while (!success) {
+ try {
+ assertEquals("Connection leak (server)", 0, serverStats[i].getNumConnections());
+ success = true;
+ } catch (Throwable t) {
+ if (numRetries < 10) {
+ LOG.info("Retrying connection leak check (server)");
+ Thread.sleep((numRetries + 1) * 50);
+ numRetries++;
+ } else {
+ throw t;
+ }
+ }
+ }
+ }
+ } finally {
+ if (client != null) {
+ client.shutDown();
+ }
+
+ for (int i = 0; i < numServers; i++) {
+ if (server[i] != null) {
+ server[i].shutDown();
+ }
+ }
+
+ if (clientTaskExecutor != null) {
+ clientTaskExecutor.shutdown();
+ }
+ }
+ }
+
+ // ------------------------------------------------------------------------
+
+ private Channel createServerChannel(final ChannelHandler... handlers) throws UnknownHostException, InterruptedException {
+ ServerBootstrap bootstrap = new ServerBootstrap()
+ // Bind address and port
+ .localAddress(InetAddress.getLocalHost(), 0)
+ // NIO server channels
+ .group(NIO_GROUP)
+ .channel(NioServerSocketChannel.class)
+ // See initializer for pipeline details
+ .childHandler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ ch.pipeline()
+ .addLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+ .addLast(handlers);
+ }
+ });
+
+ return bootstrap.bind().sync().channel();
+ }
+
+ private KvStateServerAddress getKvStateServerAddress(Channel serverChannel) {
+ InetSocketAddress localAddress = (InetSocketAddress) serverChannel.localAddress();
+
+ return new KvStateServerAddress(localAddress.getAddress(), localAddress.getPort());
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
new file mode 100644
index 0000000..6ad7ece
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerHandlerTest.java
@@ -0,0 +1,622 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.query.KvStateID;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestFailure;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.KvState;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.runtime.state.memory.MemValueState;
+import org.junit.AfterClass;
+import org.junit.Test;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeoutException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class KvStateServerHandlerTest {
+
+	/** Single-threaded query executor shared by all tests of this class. */
+	private static final ExecutorService TEST_THREAD_POOL = Executors.newSingleThreadExecutor();
+
+	/** Timeout in milliseconds for blocking waits on the test channel. */
+	private static final int READ_TIMEOUT_MILLIS = 10000;
+
+	@AfterClass
+	public static void tearDown() throws Exception {
+		if (TEST_THREAD_POOL == null) {
+			return;
+		}
+
+		TEST_THREAD_POOL.shutdown();
+	}
+
+	/**
+	 * Tests a simple successful query via an EmbeddedChannel.
+	 */
+	@Test
+	public void testSimpleQuery() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		// Register state
+		ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+		desc.setQueryable("any");
+
+		MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
+				IntSerializer.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				desc);
+
+		KvStateID kvStateId = registry.registerKvState(
+				new JobID(),
+				new JobVertexID(),
+				0,
+				"vanilla",
+				kvState);
+
+		// Update the KvState and request it
+		int expectedValue = 712828289;
+
+		int key = 99812822;
+		kvState.setCurrentKey(key);
+		kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+		kvState.update(expectedValue);
+
+		byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+				key,
+				IntSerializer.INSTANCE,
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE);
+
+		// Request ID beyond the int range to exercise the long encoding
+		long requestId = Integer.MAX_VALUE + 182828L;
+		ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
+				channel.alloc(),
+				requestId,
+				kvStateId,
+				serializedKeyAndNamespace);
+
+		// Write the request and wait for the response
+		channel.writeInbound(request);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf));
+		KvStateRequestResult response = KvStateRequestSerializer.deserializeKvStateRequestResult(buf);
+
+		assertEquals(requestId, response.getRequestId());
+
+		int actualValue = KvStateRequestSerializer.deserializeValue(response.getSerializedResult(), IntSerializer.INSTANCE);
+		assertEquals(expectedValue, actualValue);
+
+		// The response buffer is reference counted and no longer needed
+		buf.release();
+
+		assertEquals(1, stats.getNumRequests());
+		assertEquals(1, stats.getNumSuccessful());
+	}
+
+	/**
+	 * Tests the failure response with {@link UnknownKvStateID} as cause on
+	 * queries for unregistered KvStateIDs.
+	 */
+	@Test
+	public void testQueryUnknownKvStateID() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		// Query a fresh KvStateID that was never registered
+		long requestId = Integer.MAX_VALUE + 182828L;
+		ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
+				channel.alloc(),
+				requestId,
+				new KvStateID(),
+				new byte[0]);
+
+		// Write the request and wait for the response
+		channel.writeInbound(request);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf));
+		KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf);
+
+		assertEquals(requestId, response.getRequestId());
+
+		assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKvStateID);
+
+		// The response buffer is reference counted and no longer needed
+		buf.release();
+
+		assertEquals(1, stats.getNumRequests());
+		assertEquals(1, stats.getNumFailed());
+	}
+
+	/**
+	 * Tests the failure response with {@link UnknownKeyOrNamespace} as cause
+	 * on queries for non-existing keys.
+	 */
+	@Test
+	public void testQueryUnknownKey() throws Exception {
+		KvStateRegistry registry = new KvStateRegistry();
+		AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+		KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats);
+		EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+		// Register state (no value is ever written to it)
+		ValueStateDescriptor<Integer> desc = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+		desc.setQueryable("any");
+
+		MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
+				IntSerializer.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				desc);
+
+		KvStateID kvStateId = registry.registerKvState(
+				new JobID(),
+				new JobVertexID(),
+				0,
+				"vanilla",
+				kvState);
+
+		byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+				1238283,
+				IntSerializer.INSTANCE,
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE);
+
+		long requestId = Integer.MAX_VALUE + 22982L;
+		ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
+				channel.alloc(),
+				requestId,
+				kvStateId,
+				serializedKeyAndNamespace);
+
+		// Write the request and wait for the response
+		channel.writeInbound(request);
+
+		ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+		buf.skipBytes(4); // skip frame length
+
+		// Verify the response
+		assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf));
+		KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf);
+
+		assertEquals(requestId, response.getRequestId());
+
+		assertTrue("Did not respond with expected failure cause", response.getCause() instanceof UnknownKeyOrNamespace);
+
+		// The response buffer is reference counted and no longer needed
+		buf.release();
+
+		assertEquals(1, stats.getNumRequests());
+		assertEquals(1, stats.getNumFailed());
+	}
+
+ /**
+ * Tests the failure response on a failure on the {@link KvState#getSerializedValue(byte[])}
+ * call.
+ */
+ @Test
+ public void testFailureOnGetSerializedValue() throws Exception {
+ KvStateRegistry registry = new KvStateRegistry();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats);
+ EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+ // Failing KvState
+ KvState<?, ?, ?, ?, ?> kvState = mock(KvState.class);
+ when(kvState.getSerializedValue(any(byte[].class)))
+ .thenThrow(new RuntimeException("Expected test Exception"));
+
+ KvStateID kvStateId = registry.registerKvState(
+ new JobID(),
+ new JobVertexID(),
+ 0,
+ "vanilla",
+ kvState);
+
+ ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(),
+ 282872,
+ kvStateId,
+ new byte[0]);
+
+ // Write the request and wait for the response
+ channel.writeInbound(request);
+
+ ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+ buf.skipBytes(4); // skip frame length
+
+ // Verify the response
+ assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(buf));
+ KvStateRequestFailure response = KvStateRequestSerializer.deserializeKvStateRequestFailure(buf);
+
+ assertTrue(response.getCause().getMessage().contains("Expected test Exception"));
+
+ assertEquals(1, stats.getNumRequests());
+ assertEquals(1, stats.getNumFailed());
+ }
+
+ /**
+ * Verifies that an exception fired into the pipeline is reported back to the
+ * peer as a {@link KvStateRequestType#SERVER_FAILURE} and that the channel is
+ * closed afterwards.
+ */
+ @Test
+ public void testCloseChannelOnExceptionCaught() throws Exception {
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+ KvStateServerHandler handler = new KvStateServerHandler(new KvStateRegistry(), TEST_THREAD_POOL, stats);
+
+ // No frame decoder: the exception is injected directly, no bytes are decoded
+ EmbeddedChannel channel = new EmbeddedChannel(handler);
+
+ RuntimeException injected = new RuntimeException("Expected test Exception");
+ channel.pipeline().fireExceptionCaught(injected);
+
+ ByteBuf failureBuf = (ByteBuf) readInboundBlocking(channel);
+ failureBuf.skipBytes(4); // skip frame length
+
+ KvStateRequestType messageType = KvStateRequestSerializer.deserializeHeader(failureBuf);
+ assertEquals(KvStateRequestType.SERVER_FAILURE, messageType);
+
+ Throwable serverFailure = KvStateRequestSerializer.deserializeServerFailure(failureBuf);
+ assertTrue(serverFailure.getMessage().contains("Expected test Exception"));
+
+ // After reporting the failure, the handler is expected to close the channel
+ channel.closeFuture().await(READ_TIMEOUT_MILLIS);
+ assertFalse(channel.isActive());
+ }
+
+ /**
+ * Verifies that a request is answered with a {@link KvStateRequestType#REQUEST_FAILURE}
+ * mentioning {@code RejectedExecutionException} when the handler's query executor
+ * has already been shut down, and that the request is counted as failed.
+ */
+ @Test
+ public void testQueryExecutorShutDown() throws Exception {
+ KvStateRegistry registry = new KvStateRegistry();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ // An executor that rejects every submission because it is already shut down
+ ExecutorService rejectingExecutor = Executors.newSingleThreadExecutor();
+ rejectingExecutor.shutdown();
+ assertTrue(rejectingExecutor.isShutdown());
+
+ EmbeddedChannel channel = new EmbeddedChannel(
+ getFrameDecoder(),
+ new KvStateServerHandler(registry, rejectingExecutor, stats));
+
+ // Register a queryable value state
+ ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+ descriptor.setQueryable("any");
+
+ MemValueState<Integer, VoidNamespace, Integer> state = new MemValueState<>(
+ IntSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
+
+ KvStateID stateId = registry.registerKvState(new JobID(), new JobVertexID(), 0, "vanilla", state);
+
+ // Send the request; processing it requires the (rejecting) executor
+ channel.writeInbound(KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(), 282872, stateId, new byte[0]));
+
+ ByteBuf responseBuf = (ByteBuf) readInboundBlocking(channel);
+ responseBuf.skipBytes(4); // skip frame length
+
+ KvStateRequestType messageType = KvStateRequestSerializer.deserializeHeader(responseBuf);
+ assertEquals(KvStateRequestType.REQUEST_FAILURE, messageType);
+
+ KvStateRequestFailure failure = KvStateRequestSerializer.deserializeKvStateRequestFailure(responseBuf);
+ assertTrue(failure.getCause().getMessage().contains("RejectedExecutionException"));
+
+ assertEquals(1, stats.getNumRequests());
+ assertEquals(1, stats.getNumFailed());
+ }
+
+ /**
+ * Tests response on unexpected messages.
+ *
+ * <p>Two unexpected inputs are sent: a frame whose payload is a single arbitrary
+ * int, and a well-formed {@code REQUEST_RESULT} message (which only a client
+ * should receive). Both must be answered with a
+ * {@link KvStateRequestType#SERVER_FAILURE}; the second one specifically with an
+ * {@link IllegalArgumentException}. Neither counts as a request in the stats.
+ */
+ @Test
+ public void testUnexpectedMessage() throws Exception {
+ KvStateRegistry registry = new KvStateRegistry();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ KvStateServerHandler handler = new KvStateServerHandler(registry, TEST_THREAD_POOL, stats);
+ EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+ // Hand-built frame: 4-byte length prefix followed by an arbitrary int payload
+ ByteBuf unexpectedMessage = Unpooled.buffer(8);
+ unexpectedMessage.writeInt(4);
+ unexpectedMessage.writeInt(123238213);
+
+ // Write the message and wait for the response
+ channel.writeInbound(unexpectedMessage);
+
+ ByteBuf buf = (ByteBuf) readInboundBlocking(channel);
+ buf.skipBytes(4); // skip frame length
+
+ // Verify the response
+ assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf));
+ Throwable response = KvStateRequestSerializer.deserializeServerFailure(buf);
+ assertTrue("Missing failure cause for undecodable message", response != null);
+
+ assertEquals(0, stats.getNumRequests());
+ assertEquals(0, stats.getNumFailed());
+
+ // A well-formed message of a type the server does not accept
+ unexpectedMessage = KvStateRequestSerializer.serializeKvStateRequestResult(
+ channel.alloc(),
+ 192,
+ new byte[0]);
+
+ channel.writeInbound(unexpectedMessage);
+
+ buf = (ByteBuf) readInboundBlocking(channel);
+ buf.skipBytes(4); // skip frame length
+
+ // Verify the response
+ assertEquals(KvStateRequestType.SERVER_FAILURE, KvStateRequestSerializer.deserializeHeader(buf));
+ response = KvStateRequestSerializer.deserializeServerFailure(buf);
+
+ assertTrue("Unexpected failure cause " + response.getClass().getName(), response instanceof IllegalArgumentException);
+
+ assertEquals(0, stats.getNumRequests());
+ assertEquals(0, stats.getNumFailed());
+ }
+
+ /**
+ * Verifies that the handler releases every inbound buffer it is given, both
+ * for a well-formed request and for an undecodable frame.
+ */
+ @Test
+ public void testIncomingBufferIsRecycled() throws Exception {
+ KvStateServerHandler handler = new KvStateServerHandler(
+ new KvStateRegistry(), TEST_THREAD_POOL, new AtomicKvStateRequestStats());
+ EmbeddedChannel channel = new EmbeddedChannel(getFrameDecoder(), handler);
+
+ // Well-formed request for a KvStateID that was never registered
+ ByteBuf regularRequest = KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(), 282872, new KvStateID(), new byte[0]);
+
+ assertEquals(1, regularRequest.refCnt());
+ channel.writeInbound(regularRequest);
+ assertEquals("Buffer not recycled", 0, regularRequest.refCnt());
+
+ // Frame with a 4-byte length prefix and an arbitrary int payload
+ ByteBuf unexpectedFrame = channel.alloc().buffer(8);
+ unexpectedFrame.writeInt(4);
+ unexpectedFrame.writeInt(4);
+
+ assertEquals(1, unexpectedFrame.refCnt());
+ channel.writeInbound(unexpectedFrame);
+ assertEquals("Buffer not recycled", 0, unexpectedFrame.refCnt());
+ }
+
+ /**
+ * Verifies that requests whose key and/or namespace were serialized with
+ * serializers other than the registered state's are answered with a
+ * {@link KvStateRequestType#REQUEST_FAILURE} mentioning
+ * {@code IllegalArgumentException}, and that each one counts as a failed request.
+ */
+ @Test
+ public void testSerializerMismatch() throws Exception {
+ KvStateRegistry registry = new KvStateRegistry();
+ AtomicKvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ EmbeddedChannel channel = new EmbeddedChannel(
+ getFrameDecoder(),
+ new KvStateServerHandler(registry, TEST_THREAD_POOL, stats));
+
+ // Register an Integer-keyed, VoidNamespace-scoped value state
+ ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null);
+ descriptor.setQueryable("any");
+
+ MemValueState<Integer, VoidNamespace, Integer> state = new MemValueState<>(
+ IntSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
+
+ KvStateID stateId = registry.registerKvState(new JobID(), new JobVertexID(), 0, "vanilla", state);
+
+ int key = 99812822;
+ state.setCurrentKey(key);
+ state.setCurrentNamespace(VoidNamespace.INSTANCE);
+ state.update(712828289);
+
+ // Both key and namespace use the wrong serializer
+ byte[] badKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ "wrong-key-type", StringSerializer.INSTANCE,
+ "wrong-namespace-type", StringSerializer.INSTANCE);
+
+ // Correct key, wrong namespace serializer
+ byte[] badNamespaceOnly = KvStateRequestSerializer.serializeKeyAndNamespace(
+ key, IntSerializer.INSTANCE,
+ "wrong-namespace-type", StringSerializer.INSTANCE);
+
+ long firstRequestId = 182828;
+ channel.writeInbound(KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(), firstRequestId, stateId, badKeyAndNamespace));
+
+ ByteBuf responseBuf = (ByteBuf) readInboundBlocking(channel);
+ responseBuf.skipBytes(4); // skip frame length
+
+ assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(responseBuf));
+ KvStateRequestFailure failure = KvStateRequestSerializer.deserializeKvStateRequestFailure(responseBuf);
+ assertEquals(firstRequestId, failure.getRequestId());
+ assertTrue(failure.getCause().getMessage().contains("IllegalArgumentException"));
+
+ long secondRequestId = 182829;
+ channel.writeInbound(KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(), secondRequestId, stateId, badNamespaceOnly));
+
+ responseBuf = (ByteBuf) readInboundBlocking(channel);
+ responseBuf.skipBytes(4); // skip frame length
+
+ assertEquals(KvStateRequestType.REQUEST_FAILURE, KvStateRequestSerializer.deserializeHeader(responseBuf));
+ failure = KvStateRequestSerializer.deserializeKvStateRequestFailure(responseBuf);
+ assertEquals(secondRequestId, failure.getRequestId());
+ assertTrue(failure.getCause().getMessage().contains("IllegalArgumentException"));
+
+ assertEquals(2, stats.getNumRequests());
+ assertEquals(2, stats.getNumFailed());
+ }
+
+ /**
+ * Verifies that a result larger than the channel's write-buffer high-water mark
+ * is written as a {@link ChunkedByteBuf} rather than as a single buffer.
+ */
+ @Test
+ public void testChunkedResponse() throws Exception {
+ KvStateRegistry registry = new KvStateRegistry();
+ KvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ EmbeddedChannel channel = new EmbeddedChannel(
+ getFrameDecoder(),
+ new KvStateServerHandler(registry, TEST_THREAD_POOL, stats));
+
+ // Register a byte[]-valued state
+ ValueStateDescriptor<byte[]> descriptor = new ValueStateDescriptor<>("any", BytePrimitiveArraySerializer.INSTANCE, null);
+ descriptor.setQueryable("any");
+
+ MemValueState<Integer, VoidNamespace, byte[]> state = new MemValueState<>(
+ IntSerializer.INSTANCE, VoidNamespaceSerializer.INSTANCE, descriptor);
+
+ KvStateID stateId = registry.registerKvState(new JobID(), new JobVertexID(), 0, "vanilla", state);
+
+ // Payload twice the high-water mark, filled with a wrapping 0, 1, ..., 127, -128, ... pattern
+ byte[] payload = new byte[2 * channel.config().getWriteBufferHighWaterMark()];
+ for (int i = 0; i < payload.length; i++) {
+ payload[i] = (byte) i;
+ }
+
+ int key = 99812822;
+ state.setCurrentKey(key);
+ state.setCurrentNamespace(VoidNamespace.INSTANCE);
+ state.update(payload);
+
+ byte[] keyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ key, IntSerializer.INSTANCE,
+ VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);
+
+ // Request ID beyond the int range, to exercise the long encoding
+ long requestId = Integer.MAX_VALUE + 182828L;
+ channel.writeInbound(KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(), requestId, stateId, keyAndNamespace));
+
+ Object written = readInboundBlocking(channel);
+ assertTrue("Not ChunkedByteBuf", written instanceof ChunkedByteBuf);
+ }
+
+ // ------------------------------------------------------------------------
+
+ /**
+ * Queries the embedded channel for data.
+ *
+ * <p>Despite its name, this polls {@link EmbeddedChannel#readOutbound()}, i.e. it
+ * returns the next message the handler under test has written towards the peer.
+ * The queue is polled every 50 ms until a message is available or
+ * {@code READ_TIMEOUT_MILLIS} have elapsed. The queue is always read once more
+ * after the deadline, so a message arriving during the last sleep is not missed.
+ *
+ * @param channel the channel to poll
+ * @return the next outbound message, never {@code null}
+ * @throws InterruptedException if interrupted while waiting
+ * @throws TimeoutException if no message becomes available before the timeout
+ */
+ private Object readInboundBlocking(EmbeddedChannel channel) throws InterruptedException, TimeoutException {
+ final long pollIntervalMillis = 50;
+ // Measure elapsed time with a monotonic clock instead of summing sleep durations
+ final long deadlineNanos = System.nanoTime() + READ_TIMEOUT_MILLIS * 1000000L;
+
+ while (true) {
+ Object msg = channel.readOutbound();
+ if (msg != null) {
+ return msg;
+ }
+
+ // Check the deadline only after a read, so the last read happens at or after the deadline
+ if (System.nanoTime() - deadlineNanos >= 0) {
+ throw new TimeoutException("No outbound message within " + READ_TIMEOUT_MILLIS + " ms");
+ }
+
+ Thread.sleep(pollIntervalMillis);
+ }
+ }
+
+ /**
+ * Frame length decoder (expected by the serialized messages).
+ */
+ private ChannelHandler getFrameDecoder() {
+ return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4);
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/af07eed8/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
new file mode 100644
index 0000000..d653f73
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/KvStateServerTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.query.netty;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.query.KvStateID;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.query.KvStateServerAddress;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestResult;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestType;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.runtime.state.memory.MemValueState;
+import org.junit.AfterClass;
+import org.junit.Test;
+
+import java.net.InetAddress;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Tests for {@link KvStateServer} using a real socket connection.
+ */
+public class KvStateServerTest {
+
+ // Thread pool for client bootstrap (shared between tests, shut down once in tearDown)
+ private static final NioEventLoopGroup NIO_GROUP = new NioEventLoopGroup();
+
+ private static final int TIMEOUT_MILLIS = 10000;
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ if (NIO_GROUP != null) {
+ NIO_GROUP.shutdownGracefully();
+ }
+ }
+
+ /**
+ * Tests a simple successful query via a SocketChannel.
+ *
+ * <p>Only the client channel and the server are torn down here. The client
+ * event loop group is shared between tests and is shut down in
+ * {@link #tearDown()}, not per test.
+ */
+ @Test
+ public void testSimpleRequest() throws Exception {
+ KvStateServer server = null;
+ Channel channel = null;
+
+ try {
+ KvStateRegistry registry = new KvStateRegistry();
+ KvStateRequestStats stats = new AtomicKvStateRequestStats();
+
+ server = new KvStateServer(InetAddress.getLocalHost(), 0, 1, 1, registry, stats);
+ server.start();
+
+ KvStateServerAddress serverAddress = server.getAddress();
+
+ // Register state
+ MemValueState<Integer, VoidNamespace, Integer> kvState = new MemValueState<>(
+ IntSerializer.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE,
+ new ValueStateDescriptor<>("any", IntSerializer.INSTANCE, null));
+
+ KvStateID kvStateId = registry.registerKvState(
+ new JobID(),
+ new JobVertexID(),
+ 0,
+ "vanilla",
+ kvState);
+
+ // Update KvState
+ int expectedValue = 712828289;
+
+ int key = 99812822;
+ kvState.setCurrentKey(key);
+ kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
+ kvState.update(expectedValue);
+
+ // Request
+ byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ key,
+ IntSerializer.INSTANCE,
+ VoidNamespace.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE);
+
+ // Connect to the server; decoded frames are handed to the test thread via the queue
+ final BlockingQueue<ByteBuf> responses = new LinkedBlockingQueue<>();
+ Bootstrap bootstrap = createBootstrap(
+ new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4),
+ new ChannelInboundHandlerAdapter() {
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ responses.add((ByteBuf) msg);
+ }
+ });
+
+ channel = bootstrap
+ .connect(serverAddress.getHost(), serverAddress.getPort())
+ .sync().channel();
+
+ // Request ID beyond the int range, to exercise the long encoding
+ long requestId = Integer.MAX_VALUE + 182828L;
+ ByteBuf request = KvStateRequestSerializer.serializeKvStateRequest(
+ channel.alloc(),
+ requestId,
+ kvStateId,
+ serializedKeyAndNamespace);
+
+ channel.writeAndFlush(request);
+
+ ByteBuf buf = responses.poll(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
+ assertNotNull("No response received within " + TIMEOUT_MILLIS + " ms", buf);
+
+ try {
+ assertEquals(KvStateRequestType.REQUEST_RESULT, KvStateRequestSerializer.deserializeHeader(buf));
+ KvStateRequestResult response = KvStateRequestSerializer.deserializeKvStateRequestResult(buf);
+
+ assertEquals(requestId, response.getRequestId());
+ int actualValue = KvStateRequestSerializer.deserializeValue(response.getSerializedResult(), IntSerializer.INSTANCE);
+ assertEquals(expectedValue, actualValue);
+ } finally {
+ // The frame was handed over without being released by the pipeline
+ buf.release();
+ }
+ } finally {
+ if (channel != null) {
+ channel.close().awaitUninterruptibly(TIMEOUT_MILLIS);
+ }
+
+ if (server != null) {
+ server.shutDown();
+ }
+ }
+ }
+
+ /**
+ * Creates a client bootstrap on the shared {@link #NIO_GROUP} whose channels
+ * get the given handlers appended to their pipeline.
+ */
+ private Bootstrap createBootstrap(final ChannelHandler... handlers) {
+ return new Bootstrap().group(NIO_GROUP).channel(NioSocketChannel.class)
+ .handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ protected void initChannel(SocketChannel ch) throws Exception {
+ ch.pipeline().addLast(handlers);
+ }
+ });
+ }
+
+}