You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2017/02/23 16:06:27 UTC
[1/2] flink git commit: [FLINK-4856] Add MapState for keyed state
Repository: flink
Updated Branches:
refs/heads/master de2605ea7 -> 30c9e2b68
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index e386e0f..04e4fbc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -23,6 +23,7 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
@@ -44,6 +45,7 @@ import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
import org.apache.flink.runtime.state.ArrayListSerializer;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.HashMapSerializer;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
@@ -55,6 +57,7 @@ import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.runtime.state.internal.InternalFoldingState;
import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.runtime.state.internal.InternalReducingState;
import org.apache.flink.runtime.state.internal.InternalValueState;
import org.apache.flink.util.InstantiationUtil;
@@ -186,7 +189,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
@Override
- protected <N, T, ACC> InternalFoldingState<N, T, ACC> createFoldingState(
+ public <N, T, ACC> InternalFoldingState<N, T, ACC> createFoldingState(
TypeSerializer<N> namespaceSerializer,
FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
@@ -195,6 +198,19 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
@Override
+	public <N, UK, UV> InternalMapState<N, UK, UV> createMapState(
+			TypeSerializer<N> namespaceSerializer,
+			MapStateDescriptor<UK, UV> stateDesc) throws Exception {
+
+		// each value in the state table is a complete HashMap<UK, UV>; it is serialized
+		// with a HashMapSerializer built from the descriptor's user key/value serializers
+		TypeSerializer<HashMap<UK, UV>> userMapSerializer =
+				new HashMapSerializer<>(stateDesc.getKeySerializer(), stateDesc.getValueSerializer());
+
+		StateTable<K, N, HashMap<UK, UV>> stateTable = tryRegisterStateTable(
+				stateDesc.getName(), stateDesc.getType(), namespaceSerializer, userMapSerializer);
+
+		return new HeapMapState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
+	}
+
+ @Override
@SuppressWarnings("unchecked")
public RunnableFuture<KeyGroupsStateHandle> snapshot(
long checkpointId,
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
new file mode 100644
index 0000000..b28d661
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.internal.InternalMapState;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Heap-backed partitioned {@link MapState} that is snapshotted into files.
+ *
+ * <p>The state lives in a {@link StateTable} organised as
+ * key group index -&gt; namespace -&gt; key -&gt; user {@link HashMap}.
+ *
+ * @param <K> The type of the key.
+ * @param <N> The type of the namespace.
+ * @param <UK> The type of the keys in the state.
+ * @param <UV> The type of the values in the state.
+ */
+public class HeapMapState<K, N, UK, UV>
+		extends AbstractHeapState<K, N, HashMap<UK, UV>, MapState<UK, UV>, MapStateDescriptor<UK, UV>>
+		implements InternalMapState<N, UK, UV> {
+
+	/**
+	 * Creates a new key/value state for the given hash map of key/value pairs.
+	 *
+	 * @param backend The state backend that created this state.
+	 * @param stateDesc The state identifier for the state. This contains the name
+	 *                  and the serializers for the user keys and values.
+	 * @param stateTable The state table to use in this key/value state. May contain initial state.
+	 * @param keySerializer The serializer for the keys.
+	 * @param namespaceSerializer The serializer for the namespaces.
+	 */
+	public HeapMapState(KeyedStateBackend<K> backend,
+			MapStateDescriptor<UK, UV> stateDesc,
+			StateTable<K, N, HashMap<UK, UV>> stateTable,
+			TypeSerializer<K> keySerializer,
+			TypeSerializer<N> namespaceSerializer) {
+		super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer);
+	}
+
+	// ------------------------------------------------------------------------
+	//  MapState operations
+	// ------------------------------------------------------------------------
+
+	@Override
+	public UV get(UK userKey) {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? null : userMap.get(userKey);
+	}
+
+	@Override
+	public void put(UK userKey, UV userValue) {
+		getOrCreateCurrentUserMap().put(userKey, userValue);
+	}
+
+	@Override
+	public void putAll(Map<UK, UV> value) {
+		getOrCreateCurrentUserMap().putAll(value);
+	}
+
+	@Override
+	public void remove(UK userKey) {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		if (userMap == null) {
+			return;
+		}
+
+		userMap.remove(userKey);
+
+		// do not keep an empty user map around; clear() is inherited from AbstractHeapState
+		if (userMap.isEmpty()) {
+			clear();
+		}
+	}
+
+	@Override
+	public boolean contains(UK userKey) {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap != null && userMap.containsKey(userKey);
+	}
+
+	@Override
+	public int size() {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? 0 : userMap.size();
+	}
+
+	/**
+	 * Returns a live view of the entries for the current key and namespace,
+	 * or {@code null} if no map exists for them.
+	 */
+	@Override
+	public Iterable<Map.Entry<UK, UV>> entries() {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? null : userMap.entrySet();
+	}
+
+	/**
+	 * Returns a live view of the user keys for the current key and namespace,
+	 * or {@code null} if no map exists for them.
+	 */
+	@Override
+	public Iterable<UK> keys() {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? null : userMap.keySet();
+	}
+
+	/**
+	 * Returns a live view of the user values for the current key and namespace,
+	 * or {@code null} if no map exists for them.
+	 */
+	@Override
+	public Iterable<UV> values() {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? null : userMap.values();
+	}
+
+	/**
+	 * Returns an iterator over the entries for the current key and namespace,
+	 * or {@code null} if no map exists for them.
+	 */
+	@Override
+	public Iterator<Map.Entry<UK, UV>> iterator() {
+		HashMap<UK, UV> userMap = getCurrentUserMap();
+		return userMap == null ? null : userMap.entrySet().iterator();
+	}
+
+	@Override
+	public byte[] getSerializedValue(K key, N namespace) throws IOException {
+		Preconditions.checkState(namespace != null, "No namespace given.");
+		Preconditions.checkState(key != null, "No key given.");
+
+		int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, backend.getNumberOfKeyGroups());
+		HashMap<UK, UV> result = getUserMap(keyGroupIndex, key, namespace);
+		if (result == null) {
+			return null;
+		}
+
+		TypeSerializer<UK> userKeySerializer = stateDesc.getKeySerializer();
+		TypeSerializer<UV> userValueSerializer = stateDesc.getValueSerializer();
+
+		return KvStateRequestSerializer.serializeMap(result.entrySet(), userKeySerializer, userValueSerializer);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Internal lookup helpers
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Returns the user map for the current key and namespace, or {@code null} if none exists.
+	 *
+	 * @throws IllegalStateException if no namespace or no key is set.
+	 */
+	private HashMap<UK, UV> getCurrentUserMap() {
+		Preconditions.checkState(currentNamespace != null, "No namespace set.");
+		Preconditions.checkState(backend.getCurrentKey() != null, "No key set.");
+
+		return getUserMap(backend.getCurrentKeyGroupIndex(), backend.getCurrentKey(), currentNamespace);
+	}
+
+	/**
+	 * Looks up the user map stored for the given key and namespace in the given key group.
+	 *
+	 * @return the user map, or {@code null} if any level of the lookup is missing.
+	 */
+	private HashMap<UK, UV> getUserMap(int keyGroupIndex, K key, N namespace) {
+		Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(keyGroupIndex);
+		if (namespaceMap == null) {
+			return null;
+		}
+
+		Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(namespace);
+		return keyedMap == null ? null : keyedMap.get(key);
+	}
+
+	/**
+	 * Returns the user map for the current key and namespace. Any missing level of the
+	 * key group / namespace / key structure is created and registered along the way.
+	 *
+	 * @throws IllegalStateException if no namespace or no key is set.
+	 */
+	private HashMap<UK, UV> getOrCreateCurrentUserMap() {
+		Preconditions.checkState(currentNamespace != null, "No namespace set.");
+		Preconditions.checkState(backend.getCurrentKey() != null, "No key set.");
+
+		int keyGroupIndex = backend.getCurrentKeyGroupIndex();
+		Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = stateTable.get(keyGroupIndex);
+		if (namespaceMap == null) {
+			namespaceMap = createNewMap();
+			stateTable.set(keyGroupIndex, namespaceMap);
+		}
+
+		Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(currentNamespace);
+		if (keyedMap == null) {
+			keyedMap = createNewMap();
+			namespaceMap.put(currentNamespace, keyedMap);
+		}
+
+		K currentKey = backend.getCurrentKey();
+		HashMap<UK, UV> userMap = keyedMap.get(currentKey);
+		if (userMap == null) {
+			userMap = new HashMap<>();
+			keyedMap.put(currentKey, userMap);
+		}
+
+		return userMap;
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java
new file mode 100644
index 0000000..f2a7b41
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/internal/InternalMapState.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.internal;
+
+import org.apache.flink.api.common.state.MapState;
+
+/**
+ * The peer to the {@link MapState} in the internal state type hierarchy.
+ *
+ * <p>See {@link InternalKvState} for a description of the internal state hierarchy.
+ *
+ * @param <N> The type of the namespace
+ * @param <UK> The type of the keys in the state
+ * @param <UV> The type of the values in the state
+ */
+public interface InternalMapState<N, UK, UV> extends InternalKvState<N>, MapState<UK, UV> {}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
index 69dbe6f..dd61a3f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
@@ -23,7 +23,9 @@ import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.ByteSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
@@ -36,11 +38,15 @@ import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
import org.junit.Test;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import static org.junit.Assert.assertArrayEquals;
@@ -410,6 +416,131 @@ public class KvStateRequestSerializerTest {
KvStateRequestSerializer.deserializeList(new byte[] {1, 1, 1, 1, 1, 1, 1, 1, 2, 3},
LongSerializer.INSTANCE);
}
+
+	/**
+	 * Tests map serialization utils using a heap-backed {@link InternalMapState}.
+	 */
+	@Test
+	@SuppressWarnings("unchecked")
+	public void testMapSerialization() throws Exception {
+		final long key = 0L;
+
+		// objects for heap state map serialisation
+		final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
+			new HeapKeyedStateBackend<>(
+				mock(TaskKvStateRegistry.class),
+				LongSerializer.INSTANCE,
+				ClassLoader.getSystemClassLoader(),
+				1, new KeyGroupRange(0, 0)
+			);
+		longHeapKeyedStateBackend.setCurrentKey(key);
+
+		final InternalMapState<VoidNamespace, Long, String> mapState =
+			(InternalMapState<VoidNamespace, Long, String>) longHeapKeyedStateBackend.getPartitionedState(
+				VoidNamespace.INSTANCE,
+				VoidNamespaceSerializer.INSTANCE,
+				new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));
+
+		testMapSerialization(key, mapState);
+	}
+
+	/**
+	 * Verifies that the serialization of a map using the given map state
+	 * matches the deserialization with {@link KvStateRequestSerializer#deserializeMap}.
+	 *
+	 * <p>Also checks that a hand-assembled single-entry payload
+	 * (key bytes, null flag, value bytes) is deserialized correctly.
+	 *
+	 * @param key
+	 * 		key of the map state
+	 * @param mapState
+	 * 		map state using the {@link VoidNamespace}
+	 *
+	 * @throws Exception if serialization or deserialization fails
+	 */
+	public static void testMapSerialization(
+			final long key,
+			final InternalMapState<VoidNamespace, Long, String> mapState) throws Exception {
+
+		TypeSerializer<Long> userKeySerializer = LongSerializer.INSTANCE;
+		TypeSerializer<String> userValueSerializer = StringSerializer.INSTANCE;
+		mapState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+		// Map
+		final int numElements = 10;
+
+		final Map<Long, String> expectedValues = new HashMap<>();
+		for (int i = 1; i <= numElements; i++) {
+			final long value = ThreadLocalRandom.current().nextLong();
+			expectedValues.put(value, Long.toString(value));
+			mapState.put(value, Long.toString(value));
+		}
+
+		// null values must survive the round trip as well
+		expectedValues.put(0L, null);
+		mapState.put(0L, null);
+
+		final byte[] serializedKey =
+			KvStateRequestSerializer.serializeKeyAndNamespace(
+				key, LongSerializer.INSTANCE,
+				VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE);
+
+		final byte[] serializedValues = mapState.getSerializedValue(serializedKey);
+
+		// Map#equals compares sizes and every entry, and it also requires keys with
+		// null values to be present on both sides (a per-entry get() check would not)
+		Map<Long, String> actualValues = KvStateRequestSerializer.deserializeMap(serializedValues, userKeySerializer, userValueSerializer);
+		assertEquals(expectedValues, actualValues);
+
+		// Single value: hand-assembled payload of <key><null flag = 0 (false)><value>
+		ByteArrayOutputStream baos = new ByteArrayOutputStream();
+		long expectedKey = ThreadLocalRandom.current().nextLong();
+		String expectedValue = Long.toString(expectedKey);
+		byte[] isNull = {0};
+
+		baos.write(KvStateRequestSerializer.serializeValue(expectedKey, userKeySerializer));
+		baos.write(isNull);
+		baos.write(KvStateRequestSerializer.serializeValue(expectedValue, userValueSerializer));
+		byte[] serializedValue = baos.toByteArray();
+
+		Map<Long, String> actualValue = KvStateRequestSerializer.deserializeMap(serializedValue, userKeySerializer, userValueSerializer);
+		assertEquals(1, actualValue.size());
+		assertEquals(expectedValue, actualValue.get(expectedKey));
+	}
+
+	/**
+	 * Tests that map deserialization of an empty byte array yields an empty map.
+	 */
+	@Test
+	public void testDeserializeMapEmpty() throws Exception {
+		Map<Long, String> actualValue = KvStateRequestSerializer
+			.deserializeMap(new byte[] {}, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+		assertEquals(0, actualValue.size());
+	}
+
+	/**
+	 * Tests that map deserialization fails when the input ends inside the first key.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort1() throws Exception {
+		// a single byte cannot hold a complete Long key
+		final byte[] incompleteKey = {1};
+		KvStateRequestSerializer.deserializeMap(incompleteKey, LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests that map deserialization fails when the input ends right after the
+	 * null flag of the first value.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort2() throws Exception {
+		// Long (Key) + Boolean (false, i.e. value is not null) + missing Value
+		KvStateRequestSerializer.deserializeMap(new byte[]{1, 1, 1, 1, 1, 1, 1, 1, 0},
+			LongSerializer.INSTANCE, LongSerializer.INSTANCE);
+	}
+
+	/**
+	 * Tests that map deserialization fails when the input ends inside the second key.
+	 */
+	@Test(expected = IOException.class)
+	public void testDeserializeMapTooShort3() throws Exception {
+		// Long (Key1) + Boolean (false) + Long (Value1) + 1 byte (incomplete Key2)
+		final byte[] truncatedSecondKey = {
+			1, 1, 1, 1, 1, 1, 1, 1,  // key 1
+			0,                       // null flag of value 1
+			1, 1, 1, 1, 1, 1, 1, 1,  // value 1
+			3                        // first (and only) byte of key 2
+		};
+		KvStateRequestSerializer.deserializeMap(truncatedSecondKey, LongSerializer.INSTANCE, LongSerializer.INSTANCE);
+	}
private byte[] randomByteArray(int capacity) {
byte[] bytes = new byte[capacity];
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index 57f4572..75014e7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -65,6 +65,10 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
@Override
@Test
public void testReducingStateRestoreWithWrongSerializers() {}
+
+ @Override
+ @Test
+ public void testMapStateRestoreWithWrongSerializers() {}
@Test
public void testStateOutputStream() throws IOException {
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index c267afc..362fcd6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -59,6 +59,10 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
@Override
@Test
public void testReducingStateRestoreWithWrongSerializers() {}
+
+ @Override
+ @Test
+ public void testMapStateRestoreWithWrongSerializers() {}
@Test
@SuppressWarnings("unchecked")
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
index 66e8d02..0dbe2eb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
@@ -161,7 +161,7 @@ public class SerializationProxiesTest {
@Test
public void testFixTypeOrder() {
// ensure all elements are covered
- Assert.assertEquals(6, StateDescriptor.Type.values().length);
+ Assert.assertEquals(7, StateDescriptor.Type.values().length);
// fix the order of elements to keep serialization format stable
Assert.assertEquals(0, StateDescriptor.Type.UNKNOWN.ordinal());
Assert.assertEquals(1, StateDescriptor.Type.VALUE.ordinal());
@@ -169,5 +169,6 @@ public class SerializationProxiesTest {
Assert.assertEquals(3, StateDescriptor.Type.REDUCING.ordinal());
Assert.assertEquals(4, StateDescriptor.Type.FOLDING.ordinal());
Assert.assertEquals(5, StateDescriptor.Type.AGGREGATING.ordinal());
+ Assert.assertEquals(6, StateDescriptor.Type.MAP.ordinal());
}
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 7737ecf..3b0350d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -28,6 +28,8 @@ import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
@@ -54,8 +56,12 @@ import org.apache.flink.util.TestLogger;
import org.junit.Test;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Random;
import java.util.Timer;
import java.util.TimerTask;
@@ -784,6 +790,169 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
fail(e.getMessage());
}
}
+
+	/**
+	 * Exercises the {@link MapState} operations, snapshots and restores of a keyed backend.
+	 * Exceptions propagate so that JUnit reports the full stack trace.
+	 */
+	@Test
+	@SuppressWarnings({"unchecked", "rawtypes"})
+	public void testMapState() throws Exception {
+		CheckpointStreamFactory streamFactory = createStreamFactory();
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+
+		MapStateDescriptor<Integer, String> kvId = new MapStateDescriptor<>("id", Integer.class, String.class);
+		kvId.initializeSerializerUnlessSet(new ExecutionConfig());
+
+		TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
+		TypeSerializer<VoidNamespace> namespaceSerializer = VoidNamespaceSerializer.INSTANCE;
+		TypeSerializer<Integer> userKeySerializer = kvId.getKeySerializer();
+		TypeSerializer<String> userValueSerializer = kvId.getValueSerializer();
+
+		MapState<Integer, String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+		InternalKvState<VoidNamespace> kvState = (InternalKvState<VoidNamespace>) state;
+
+		// some modifications to the state
+		backend.setCurrentKey(1);
+		assertEquals(0, state.size());
+		assertNull(state.get(1));
+		assertNull(getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		state.put(1, "1");
+		backend.setCurrentKey(2);
+		assertEquals(0, state.size());
+		assertNull(state.get(2));
+		assertNull(getSerializedMap(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		state.put(2, "2");
+		backend.setCurrentKey(1);
+		assertEquals(1, state.size());
+		assertTrue(state.contains(1));
+		assertEquals("1", state.get(1));
+		assertEquals(new HashMap<Integer, String>() {{ put(1, "1"); }},
+			getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+
+		// draw a snapshot
+		KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
+
+		// make some more modifications
+		backend.setCurrentKey(1);
+		state.put(1, "101");
+		backend.setCurrentKey(2);
+		state.put(102, "102");
+		backend.setCurrentKey(3);
+		state.put(103, "103");
+		state.putAll(new HashMap<Integer, String>() {{ put(1031, "1031"); put(1032, "1032"); }});
+
+		// draw another snapshot
+		KeyGroupsStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory));
+
+		// validate the original state
+		backend.setCurrentKey(1);
+		assertEquals("101", state.get(1));
+		assertEquals(new HashMap<Integer, String>() {{ put(1, "101"); }},
+			getSerializedMap(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		backend.setCurrentKey(2);
+		assertEquals("102", state.get(102));
+		assertEquals(new HashMap<Integer, String>() {{ put(2, "2"); put(102, "102"); }},
+			getSerializedMap(kvState, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		backend.setCurrentKey(3);
+		assertEquals(3, state.size());
+		assertTrue(state.contains(103));
+		assertEquals("103", state.get(103));
+		assertEquals(new HashMap<Integer, String>() {{ put(103, "103"); put(1031, "1031"); put(1032, "1032"); }},
+			getSerializedMap(kvState, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+
+		List<Integer> keys = new ArrayList<>();
+		for (Integer key : state.keys()) {
+			keys.add(key);
+		}
+		List<Integer> expectedKeys = new ArrayList<Integer>() {{ add(103); add(1031); add(1032); }};
+		assertEquals(expectedKeys.size(), keys.size());
+		keys.removeAll(expectedKeys);
+		assertTrue(keys.isEmpty());
+
+		List<String> values = new ArrayList<>();
+		for (String value : state.values()) {
+			values.add(value);
+		}
+		List<String> expectedValues = new ArrayList<String>() {{ add("103"); add("1031"); add("1032"); }};
+		assertEquals(expectedValues.size(), values.size());
+		values.removeAll(expectedValues);
+		assertTrue(values.isEmpty());
+
+		// make some more modifications
+		backend.setCurrentKey(1);
+		state.clear();
+		backend.setCurrentKey(2);
+		state.remove(102);
+		backend.setCurrentKey(3);
+		final String updateSuffix = "_updated";
+		// drop entries whose value is not 4 characters long ("103"), update the others in place
+		Iterator<Map.Entry<Integer, String>> iterator = state.iterator();
+		while (iterator.hasNext()) {
+			Map.Entry<Integer, String> entry = iterator.next();
+			if (entry.getValue().length() != 4) {
+				iterator.remove();
+			} else {
+				entry.setValue(entry.getValue() + updateSuffix);
+			}
+		}
+
+		// validate the state
+		backend.setCurrentKey(1);
+		assertEquals(0, state.size());
+		backend.setCurrentKey(2);
+		assertFalse(state.contains(102));
+		backend.setCurrentKey(3);
+		// the iterator removal and the in-place updates must both be visible
+		assertEquals(2, state.size());
+		assertFalse(state.contains(103));
+		assertEquals("1031" + updateSuffix, state.get(1031));
+		assertEquals("1032" + updateSuffix, state.get(1032));
+		for (Map.Entry<Integer, String> entry : state.entries()) {
+			assertEquals(4 + updateSuffix.length(), entry.getValue().length());
+			assertTrue(entry.getValue().endsWith(updateSuffix));
+		}
+
+		backend.dispose();
+		// restore the first snapshot and validate it
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+		snapshot1.discardState();
+
+		MapState<Integer, String> restored1 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+		InternalKvState<VoidNamespace> restoredKvState1 = (InternalKvState<VoidNamespace>) restored1;
+
+		backend.setCurrentKey(1);
+		assertEquals("1", restored1.get(1));
+		assertEquals(new HashMap<Integer, String>() {{ put(1, "1"); }},
+			getSerializedMap(restoredKvState1, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		backend.setCurrentKey(2);
+		assertEquals("2", restored1.get(2));
+		assertEquals(new HashMap<Integer, String>() {{ put(2, "2"); }},
+			getSerializedMap(restoredKvState1, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+
+		backend.dispose();
+		// restore the second snapshot and validate it
+		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot2);
+		snapshot2.discardState();
+
+		MapState<Integer, String> restored2 = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+		InternalKvState<VoidNamespace> restoredKvState2 = (InternalKvState<VoidNamespace>) restored2;
+
+		backend.setCurrentKey(1);
+		assertEquals("101", restored2.get(1));
+		assertEquals(new HashMap<Integer, String>() {{ put(1, "101"); }},
+			getSerializedMap(restoredKvState2, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		backend.setCurrentKey(2);
+		assertEquals("102", restored2.get(102));
+		assertEquals(new HashMap<Integer, String>() {{ put(2, "2"); put(102, "102"); }},
+			getSerializedMap(restoredKvState2, 2, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+		backend.setCurrentKey(3);
+		assertEquals("103", restored2.get(103));
+		assertEquals(new HashMap<Integer, String>() {{ put(103, "103"); put(1031, "1031"); put(1032, "1032"); }},
+			getSerializedMap(restoredKvState2, 3, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
+
+		backend.dispose();
+	}
/**
* Verify that {@link ValueStateDescriptor} allows {@code null} as default.
@@ -917,9 +1086,36 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
backend.dispose();
}
+	/**
+	 * Verify that an empty {@code MapState} yields {@code null} from {@code entries()},
+	 * both before any mapping has been added and again after {@link MapState#clear()}.
+	 */
+	@Test
+	public void testMapStateDefaultValue() throws Exception {
+		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+
+		try {
+			MapStateDescriptor<String, String> kvId = new MapStateDescriptor<>("id", String.class, String.class);
+			kvId.initializeSerializerUnlessSet(new ExecutionConfig());
+
+			MapState<String, String> state = backend.getPartitionedState(
+					VoidNamespace.INSTANCE,
+					VoidNamespaceSerializer.INSTANCE, kvId);
+
+			backend.setCurrentKey(1);
+			assertNull(state.entries());
+			state.put("Ciao", "Hello");
+			state.put("Bello", "Nice");
+
+			// JUnit's assertEquals takes (expected, actual); keeping that order makes
+			// failure messages report the right values.
+			assertEquals(2, state.size());
+			assertEquals("Hello", state.get("Ciao"));
+			assertEquals("Nice", state.get("Bello"));
+			state.clear();
+			assertNull(state.entries());
+		} finally {
+			// release the backend even when one of the assertions above fails
+			backend.dispose();
+		}
+	}
+
/**
* This test verifies that state is correctly assigned to key groups and that restore
* restores the relevant key groups in the backend.
@@ -1172,6 +1368,58 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
fail(e.getMessage());
}
}
+
+ /**
+ * Verifies that, after restoring a snapshot, accessing the {@code MapState} with a user-key
+ * serializer different from the one used when the state was written fails with an
+ * {@link IOException} whose message contains "Trying to access state using wrong ".
+ * Any other exception type, or no exception at all, fails the test.
+ */
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testMapStateRestoreWithWrongSerializers() {
+ try {
+ CheckpointStreamFactory streamFactory = createStreamFactory();
+ AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
+
+ MapStateDescriptor<String, String> kvId = new MapStateDescriptor<>("id", StringSerializer.INSTANCE, StringSerializer.INSTANCE);
+ MapState<String, String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+ // write one mapping for each of two keys so the snapshot contains data
+ backend.setCurrentKey(1);
+ state.put("1", "First");
+ backend.setCurrentKey(2);
+ state.put("2", "Second");
+
+ // draw a snapshot
+ KeyGroupsStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory));
+
+ backend.dispose();
+ // restore the snapshot into a fresh backend; only serializer-mismatch detection is checked below
+ backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
+ snapshot1.discardState();
+
+ // deliberately wrong: a FloatSerializer disguised as a String serializer via an unchecked double cast
+ @SuppressWarnings("unchecked")
+ TypeSerializer<String> fakeStringSerializer =
+ (TypeSerializer<String>) (TypeSerializer<?>) FloatSerializer.INSTANCE;
+
+ try {
+ // same state name, but the user-key serializer no longer matches the written data
+ kvId = new MapStateDescriptor<>("id", fakeStringSerializer, StringSerializer.INSTANCE);
+
+ state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+
+ state.entries();
+
+ fail("should recognize wrong serializers");
+ } catch (IOException e) {
+ if (!e.getMessage().contains("Trying to access state using wrong ")) {
+ fail("wrong exception " + e);
+ }
+ // expected
+ } catch (Exception e) {
+ fail("wrong exception " + e);
+ }
+ // NOTE(review): not reached if an assertion above throws; consider try/finally
+ backend.dispose();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+
@Test
public void testCopyDefaultValue() throws Exception {
@@ -1357,6 +1605,31 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
}
+
+ {
+ // MapState
+ MapStateDescriptor<Integer, String> desc = new MapStateDescriptor<>("map-state", Integer.class, String.class);
+ desc.setQueryable("my-query");
+ desc.initializeSerializerUnlessSet(new ExecutionConfig());
+
+ MapState<Integer, String> state = backend.getPartitionedState(
+ VoidNamespace.INSTANCE,
+ VoidNamespaceSerializer.INSTANCE,
+ desc);
+
+ InternalKvState<VoidNamespace> kvState = (InternalKvState<VoidNamespace>) state;
+ assertTrue(kvState instanceof AbstractHeapState);
+
+ kvState.setCurrentNamespace(VoidNamespace.INSTANCE);
+ backend.setCurrentKey(1);
+ state.put(121818273, "121818273");
+
+ int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
+ StateTable stateTable = ((AbstractHeapState) kvState).getStateTable();
+ assertNotNull("State not set", stateTable.get(keyGroupIndex));
+ assertTrue(stateTable.get(keyGroupIndex) instanceof ConcurrentHashMap);
+ assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
+ }
backend.dispose();
}
@@ -1495,6 +1768,32 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
return KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer);
}
}
+
+ /**
+ * Returns the value by getting the serialized value and deserializing it
+ * if it is not null.
+ */
+ private static <UK, UV, K, N> Map<UK, UV> getSerializedMap(
+ InternalKvState<N> kvState,
+ K key,
+ TypeSerializer<K> keySerializer,
+ N namespace,
+ TypeSerializer<N> namespaceSerializer,
+ TypeSerializer<UK> userKeySerializer,
+ TypeSerializer<UV> userValueSerializer
+ ) throws Exception {
+
+ byte[] serializedKeyAndNamespace = KvStateRequestSerializer.serializeKeyAndNamespace(
+ key, keySerializer, namespace, namespaceSerializer);
+
+ byte[] serializedValue = kvState.getSerializedValue(serializedKeyAndNamespace);
+
+ if (serializedValue == null) {
+ return null;
+ } else {
+ return KvStateRequestSerializer.deserializeMap(serializedValue, userKeySerializer, userValueSerializer);
+ }
+ }
private KeyGroupsStateHandle runSnapshot(RunnableFuture<KeyGroupsStateHandle> snapshotRunnableFuture) throws Exception {
if(!snapshotRunnableFuture.isDone()) {
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java
index e6a186a..7971460 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunction.java
@@ -36,6 +36,8 @@ import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
@@ -171,6 +173,11 @@ public abstract class RichAsyncFunction<IN, OUT> extends AbstractRichFunction im
throw new UnsupportedOperationException("State is not supported in rich async functions.");
}
+ /**
+ * Not supported: always throws {@link UnsupportedOperationException}, with the same message
+ * as the other state accessors of this runtime context.
+ */
+ @Override
+ public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) {
+ throw new UnsupportedOperationException("State is not supported in rich async functions.");
+ }
+
@Override
public <V, A extends Serializable> void addAccumulator(String name, Accumulator<V, A> accumulator) {
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
index b9c9b9b..b666a2b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java
@@ -27,6 +27,8 @@ import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
@@ -136,6 +138,13 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
return keyedStateStore.getFoldingState(stateProperties);
}
+
+ @Override
+ public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) {
+ // Same sequence as the other keyed-state getters: check preconditions and obtain the store,
+ // initialize the descriptor's serializers from the ExecutionConfig (unless already set), then delegate.
+ KeyedStateStore keyedStateStore = checkPreconditionsAndGetKeyedStateStore(stateProperties);
+ stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
+ return keyedStateStore.getMapState(stateProperties);
+ }
private KeyedStateStore checkPreconditionsAndGetKeyedStateStore(StateDescriptor<?, ?> stateDescriptor) {
Preconditions.checkNotNull(stateDescriptor, "The state properties must not be null");
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java
index 815f856..562883d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/async/RichAsyncFunctionTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.metrics.MetricGroup;
@@ -165,7 +166,6 @@ public class RichAsyncFunctionTest {
} catch (UnsupportedOperationException e) {
// expected
}
-
try {
runtimeContext.getFoldingState(new FoldingStateDescriptor<>("foobar", 0, new FoldFunction<Integer, Integer>() {
@Override
@@ -178,6 +178,12 @@ public class RichAsyncFunctionTest {
}
try {
+ runtimeContext.getMapState(new MapStateDescriptor<>("foobar", Integer.class, String.class));
+ } catch (UnsupportedOperationException e) {
+ // expected
+ }
+
+ try {
runtimeContext.addAccumulator("foobar", new Accumulator<Integer, Integer>() {
private static final long serialVersionUID = -4673320336846482358L;
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
index 294b8da..36496f2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java
@@ -27,6 +27,8 @@ import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
@@ -52,6 +54,7 @@ import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import java.util.Collections;
+import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
@@ -178,7 +181,7 @@ public class StreamingRuntimeContextTest {
public void testListStateReturnsEmptyListByDefault() throws Exception {
StreamingRuntimeContext context = new StreamingRuntimeContext(
- createPlainMockOp(),
+ createListPlainMockOp(),
createMockEnvironment(),
Collections.<String, Accumulator<?, ?>>emptyMap());
@@ -190,6 +193,48 @@ public class StreamingRuntimeContextTest {
assertFalse(value.iterator().hasNext());
}
+ /**
+ * Verifies that {@code getMapState} applies the operator's {@link ExecutionConfig} to the
+ * descriptor: the value serializer seen by the backend is a {@code KryoSerializer} in which
+ * {@code Path} (registered on the config) has a registration.
+ */
+ @Test
+ public void testMapStateInstantiation() throws Exception {
+
+ final ExecutionConfig config = new ExecutionConfig();
+ config.registerKryoType(Path.class);
+
+ // presumably filled by the mock op with the descriptor that reaches the backend (helper not visible here)
+ final AtomicReference<Object> descriptorCapture = new AtomicReference<>();
+
+ StreamingRuntimeContext context = new StreamingRuntimeContext(
+ createDescriptorCapturingMockOp(descriptorCapture, config),
+ createMockEnvironment(),
+ Collections.<String, Accumulator<?, ?>>emptyMap());
+
+ MapStateDescriptor<String, TaskInfo> descr =
+ new MapStateDescriptor<>("name", String.class, TaskInfo.class);
+
+ context.getMapState(descr);
+
+ MapStateDescriptor<?, ?> descrIntercepted = (MapStateDescriptor<?, ?>) descriptorCapture.get();
+ TypeSerializer<?> valueSerializer = descrIntercepted.getValueSerializer();
+
+ // check that the Path class is really registered, i.e., the execution config was applied
+ assertTrue(valueSerializer instanceof KryoSerializer);
+ assertTrue(((KryoSerializer<?>) valueSerializer).getKryo().getRegistration(Path.class).getId() > 0);
+ }
+
+ /**
+ * Verifies that a {@code MapState} obtained through {@link StreamingRuntimeContext} returns a
+ * non-null, empty {@code entries()} iterable when no mapping has been added for the current key.
+ */
+ @Test
+ public void testMapStateReturnsEmptyMapByDefault() throws Exception {
+
+ StreamingRuntimeContext context = new StreamingRuntimeContext(
+ createMapPlainMockOp(),
+ createMockEnvironment(),
+ Collections.<String, Accumulator<?, ?>>emptyMap());
+
+ MapStateDescriptor<Integer, String> descr = new MapStateDescriptor<>("name", Integer.class, String.class);
+ MapState<Integer, String> state = context.getMapState(descr);
+
+ Iterable<Map.Entry<Integer, String>> value = state.entries();
+ assertNotNull(value);
+ assertFalse(value.iterator().hasNext());
+ }
+
// ------------------------------------------------------------------------
//
// ------------------------------------------------------------------------
@@ -221,7 +266,7 @@ public class StreamingRuntimeContextTest {
}
@SuppressWarnings("unchecked")
- private static AbstractStreamOperator<?> createPlainMockOp() throws Exception {
+ private static AbstractStreamOperator<?> createListPlainMockOp() throws Exception {
AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
ExecutionConfig config = new ExecutionConfig();
@@ -256,6 +301,42 @@ public class StreamingRuntimeContextTest {
return operatorMock;
}
+ /**
+ * Creates a mocked operator whose keyed state store wraps a mocked {@code KeyedStateBackend}.
+ * Every {@code getPartitionedState(namespace, namespaceSerializer, MapStateDescriptor)} call on
+ * that mock is answered by creating a new keyed backend from a {@link MemoryStateBackend}
+ * (single key group 0, current key 0) and returning the real state it produces.
+ */
+ @SuppressWarnings("unchecked")
+ private static AbstractStreamOperator<?> createMapPlainMockOp() throws Exception {
+
+ AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
+ ExecutionConfig config = new ExecutionConfig();
+
+ KeyedStateBackend keyedStateBackend= mock(KeyedStateBackend.class);
+
+ DefaultKeyedStateStore keyedStateStore = new DefaultKeyedStateStore(keyedStateBackend, config);
+
+ when(operatorMock.getExecutionConfig()).thenReturn(config);
+
+ doAnswer(new Answer<MapState<Integer, String>>() {
+
+ @Override
+ public MapState<Integer, String> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ // index 2 is the state descriptor argument of the stubbed getPartitionedState call below
+ MapStateDescriptor<Integer, String> descr =
+ (MapStateDescriptor<Integer, String>) invocationOnMock.getArguments()[2];
+
+ // NOTE(review): a new backend is created per call and never disposed; the backend's
+ // JobID and the task registry's JobID are two different random IDs.
+ AbstractKeyedStateBackend<Integer> backend = new MemoryStateBackend().createKeyedStateBackend(
+ new DummyEnvironment("test_task", 1, 0),
+ new JobID(),
+ "test_op",
+ IntSerializer.INSTANCE,
+ 1,
+ new KeyGroupRange(0, 0),
+ new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()));
+ backend.setCurrentKey(0);
+ return backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, descr);
+ }
+ }).when(keyedStateBackend).getPartitionedState(Matchers.any(), any(TypeSerializer.class), any(MapStateDescriptor.class));
+
+ when(operatorMock.getKeyedStateStore()).thenReturn(keyedStateStore);
+ return operatorMock;
+ }
+
private static Environment createMockEnvironment() {
Environment env = mock(Environment.class);
when(env.getUserClassLoader()).thenReturn(StreamingRuntimeContextTest.class.getClassLoader());
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
index 0562443..6e2fd62 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/query/KVStateRequestSerializerRocksDBTest.java
@@ -20,8 +20,10 @@ package org.apache.flink.test.query;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.contrib.streaming.state.PredefinedOptions;
import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
@@ -32,6 +34,7 @@ import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.runtime.state.VoidNamespaceSerializer;
import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
@@ -122,4 +125,41 @@ public final class KVStateRequestSerializerRocksDBTest {
KvStateRequestSerializerTest.testListSerialization(key, listState);
}
+
+	/**
+	 * Tests that map serialization and deserialization match when using the RocksDB state back-end.
+	 *
+	 * @see KvStateRequestSerializerTest#testMapSerialization() the same test using the heap state back-end
+	 */
+	@Test
+	public void testMapSerialization() throws Exception {
+		final long key = 0L;
+
+		// objects for RocksDB state map serialisation
+		DBOptions dbOptions = PredefinedOptions.DEFAULT.createDBOptions();
+		dbOptions.setCreateIfMissing(true);
+		ColumnFamilyOptions columnFamilyOptions = PredefinedOptions.DEFAULT.createColumnOptions();
+		final RocksDBKeyedStateBackend<Long> longRocksDBKeyedStateBackend =
+			new RocksDBKeyedStateBackend<>(
+				new JobID(), "no-op",
+				ClassLoader.getSystemClassLoader(),
+				temporaryFolder.getRoot(),
+				dbOptions,
+				columnFamilyOptions,
+				mock(TaskKvStateRegistry.class),
+				LongSerializer.INSTANCE,
+				1, new KeyGroupRange(0, 0)
+			);
+
+		// dispose the backend even if the checks fail, so the RocksDB resources it holds are released
+		try {
+			longRocksDBKeyedStateBackend.setCurrentKey(key);
+
+			final InternalMapState<VoidNamespace, Long, String> mapState = (InternalMapState<VoidNamespace, Long, String>)
+				longRocksDBKeyedStateBackend.getPartitionedState(
+					VoidNamespace.INSTANCE,
+					VoidNamespaceSerializer.INSTANCE,
+					new MapStateDescriptor<>("test", LongSerializer.INSTANCE, StringSerializer.INSTANCE));
+
+			KvStateRequestSerializerTest.testMapSerialization(key, mapState);
+		} finally {
+			longRocksDBKeyedStateBackend.dispose();
+		}
+	}
}
[2/2] flink git commit: [FLINK-4856] Add MapState for keyed state
Posted by al...@apache.org.
[FLINK-4856] Add MapState for keyed state
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/30c9e2b6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/30c9e2b6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/30c9e2b6
Branch: refs/heads/master
Commit: 30c9e2b683bf7f4776ffc23b6a860946a4429ae5
Parents: de2605e
Author: xiaogang.sxg <xi...@alibaba-inc.com>
Authored: Fri Feb 17 11:19:18 2017 +0800
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Thu Feb 23 16:56:29 2017 +0100
----------------------------------------------------------------------
docs/dev/stream/state.md | 8 +-
.../streaming/state/AbstractRocksDBState.java | 49 +-
.../state/RocksDBKeyedStateBackend.java | 10 +
.../streaming/state/RocksDBMapState.java | 546 +++++++++++++++++++
.../api/common/functions/RuntimeContext.java | 42 ++
.../util/AbstractRuntimeUDFContext.java | 9 +
.../flink/api/common/state/KeyedStateStore.java | 40 ++
.../apache/flink/api/common/state/MapState.java | 134 +++++
.../api/common/state/MapStateDescriptor.java | 147 +++++
.../flink/api/common/state/StateBinder.java | 9 +
.../flink/api/common/state/StateDescriptor.java | 2 +-
.../common/typeutils/base/MapSerializer.java | 193 +++++++
.../flink/api/java/typeutils/MapTypeInfo.java | 147 +++++
.../common/state/MapStateDescriptorTest.java | 115 ++++
.../typeutils/base/MapSerializerTest.java | 90 +++
.../flink/hdfstests/FileStateBackendTest.java | 4 +
.../netty/message/KvStateRequestSerializer.java | 67 +++
.../state/AbstractKeyedStateBackend.java | 23 +-
.../runtime/state/DefaultKeyedStateStore.java | 14 +
.../flink/runtime/state/HashMapSerializer.java | 193 +++++++
.../flink/runtime/state/UserFacingMapState.java | 103 ++++
.../state/heap/HeapKeyedStateBackend.java | 18 +-
.../flink/runtime/state/heap/HeapMapState.java | 311 +++++++++++
.../state/internal/InternalMapState.java | 32 ++
.../message/KvStateRequestSerializerTest.java | 131 +++++
.../runtime/state/FileStateBackendTest.java | 4 +
.../runtime/state/MemoryStateBackendTest.java | 4 +
.../runtime/state/SerializationProxiesTest.java | 3 +-
.../runtime/state/StateBackendTestBase.java | 299 ++++++++++
.../api/functions/async/RichAsyncFunction.java | 7 +
.../api/operators/StreamingRuntimeContext.java | 9 +
.../functions/async/RichAsyncFunctionTest.java | 8 +-
.../operators/StreamingRuntimeContextTest.java | 85 ++-
.../KVStateRequestSerializerRocksDBTest.java | 40 ++
34 files changed, 2887 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/docs/dev/stream/state.md
----------------------------------------------------------------------
diff --git a/docs/dev/stream/state.md b/docs/dev/stream/state.md
index e554e29..40522e1 100644
--- a/docs/dev/stream/state.md
+++ b/docs/dev/stream/state.md
@@ -118,6 +118,11 @@ added to the state. Contrary to `ReducingState`, the aggregate type may be diffe
of elements that are added to the state. The interface is the same as for `ListState` but elements
added using `add(T)` are folded into an aggregate using a specified `FoldFunction`.
+* `MapState<UK, UV>`: This keeps a collection of key-value mappings. You can put key-value pairs into the state and
+retrieve an `Iterable` over all currently stored mappings. Mappings are added using `put(UK, UV)` or
+`putAll(Map<UK, UV>)`. The value associated with a user key can be retrieved using `get(UK)`. The iterable
+views for mappings, keys and values can be retrieved using `entries()`, `keys()` and `values()` respectively.
+
All types of state also have a method `clear()` that clears the state for the currently
active key, i.e. the key of the input element.
@@ -136,7 +141,7 @@ To get a state handle, you have to create a `StateDescriptor`. This holds the na
that you can reference them), the type of the values that the state holds, and possibly
a user-specified function, such as a `ReduceFunction`. Depending on what type of state you
want to retrieve, you create either a `ValueStateDescriptor`, a `ListStateDescriptor`,
-a `ReducingStateDescriptor` or a `FoldingStateDescriptor`.
+a `ReducingStateDescriptor`, a `FoldingStateDescriptor` or a `MapStateDescriptor`.
State is accessed using the `RuntimeContext`, so it is only possible in *rich functions*.
Please see [here]({{ site.baseurl }}/dev/api_concepts.html#rich-functions) for
@@ -147,6 +152,7 @@ is available in a `RichFunction` has these methods for accessing state:
* `ReducingState<T> getReducingState(ReducingStateDescriptor<T>)`
* `ListState<T> getListState(ListStateDescriptor<T>)`
* `FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC>)`
+* `MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV>)`
This is an example `FlatMapFunction` that shows how all of the parts fit together:
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
index 89f41aa..569971a 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/AbstractRocksDBState.java
@@ -21,7 +21,10 @@ import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
@@ -50,7 +53,7 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
implements InternalKvState<N>, State {
/** Serializer for the namespace */
- private final TypeSerializer<N> namespaceSerializer;
+ final TypeSerializer<N> namespaceSerializer;
/** The current namespace, which the next value methods will refer to */
private N currentNamespace;
@@ -215,4 +218,48 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
value >>>= 8;
} while (value != 0);
}
+
+ /**
+ * Reads a composite key from {@code inputView} in this fixed order: key-group prefix, key,
+ * namespace. The order must match how the bytes were written (presumably by the corresponding
+ * write path of this class; not visible here).
+ *
+ * @param inputStream stream underlying {@code inputView}; only its position is used, to measure
+ *                    the length of the key and namespace when {@code ambiguousKeyPossible} is set
+ * @param inputView view to read from
+ * @return (key group, key, namespace)
+ */
+ protected Tuple3<Integer, K, N> readKeyWithGroupAndNamespace(ByteArrayInputStreamWithPos inputStream, DataInputView inputView) throws IOException {
+ int keyGroup = readKeyGroup(inputView);
+ K key = readKey(inputStream, inputView);
+ N namespace = readNamespace(inputStream, inputView);
+
+ return new Tuple3<>(keyGroup, key, namespace);
+ }
+
+	/**
+	 * Reads the key-group prefix, {@code backend.getKeyGroupPrefixBytes()} bytes long, with the
+	 * first byte read ending up as the most significant one.
+	 */
+	private int readKeyGroup(DataInputView inputView) throws IOException {
+		int result = 0;
+		int bytesRead = 0;
+		while (bytesRead < backend.getKeyGroupPrefixBytes()) {
+			result = (result << 8) | (inputView.readByte() & 0xFF);
+			bytesRead++;
+		}
+		return result;
+	}
+
+ private K readKey(ByteArrayInputStreamWithPos inputStream, DataInputView inputView) throws IOException {
+ int beforeRead = inputStream.getPosition();
+ K key = backend.getKeySerializer().deserialize(inputView);
+ if (ambiguousKeyPossible) {
+ int length = inputStream.getPosition() - beforeRead;
+ readVariableIntBytes(inputView, length);
+ }
+ return key;
+ }
+
+ private N readNamespace(ByteArrayInputStreamWithPos inputStream, DataInputView inputView) throws IOException {
+ int beforeRead = inputStream.getPosition();
+ N namespace = namespaceSerializer.deserialize(inputView);
+ if (ambiguousKeyPossible) {
+ int length = inputStream.getPosition() - beforeRead;
+ readVariableIntBytes(inputView, length);
+ }
+ return namespace;
+ }
+
+	/**
+	 * Consumes and discards as many bytes from {@code inputView} as are needed to represent
+	 * {@code value} as an unsigned number, and always at least one byte.
+	 */
+	private void readVariableIntBytes(DataInputView inputView, int value) throws IOException {
+		int byteCount = 1;
+		for (int rest = value >>> 8; rest != 0; rest >>>= 8) {
+			byteCount++;
+		}
+		for (int i = 0; i < byteCount; i++) {
+			inputView.readByte();
+		}
+	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index d8d77b6..a0efe78 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -21,6 +21,7 @@ import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
@@ -53,6 +54,7 @@ import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.runtime.state.internal.InternalFoldingState;
import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.runtime.state.internal.InternalReducingState;
import org.apache.flink.runtime.state.internal.InternalValueState;
import org.apache.flink.runtime.util.SerializableObject;
@@ -882,6 +884,14 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return new RocksDBFoldingState<>(columnFamily, namespaceSerializer, stateDesc, this);
}
+ /**
+ * Creates the RocksDB-backed map state: looks up the column family for this descriptor and
+ * namespace serializer and wraps it in a {@link RocksDBMapState} bound to this backend.
+ */
+ @Override
+ protected <N, UK, UV> InternalMapState<N, UK, UV> createMapState(TypeSerializer<N> namespaceSerializer,
+ MapStateDescriptor<UK, UV> stateDesc) throws Exception {
+ ColumnFamilyHandle columnFamily = getColumnFamily(stateDesc, namespaceSerializer);
+
+ return new RocksDBMapState<>(columnFamily, namespaceSerializer, stateDesc, this);
+ }
+
/**
* Wraps a RocksDB iterator to cache it's current key and assign an id for the key/value state to the iterator.
* Used by #MergeIterator.
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
new file mode 100644
index 0000000..e9e9d9b
--- /dev/null
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBMapState.java
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.internal.InternalMapState;
+import org.apache.flink.util.Preconditions;
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.RocksDB;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.RocksIterator;
+import org.rocksdb.WriteOptions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * {@link MapState} implementation that stores state in RocksDB.
+ *
+ * <p>Each user mapping is stored as a separate RocksDB entry. The RocksDB key has the format
+ * {@code #KeyGroup#Key#Namespace#UserKey}, so all mappings of one key/namespace pair share a
+ * common prefix and can be scanned with a prefix iteration. The RocksDB value is a boolean
+ * null-flag followed, for non-null values, by the serialized user value.
+ *
+ * @param <K> The type of the key.
+ * @param <N> The type of the namespace.
+ * @param <UK> The type of the keys in the map state.
+ * @param <UV> The type of the values in the map state.
+ */
+public class RocksDBMapState<K, N, UK, UV>
+	extends AbstractRocksDBState<K, N, MapState<UK, UV>, MapStateDescriptor<UK, UV>, Map<UK, UV>>
+	implements InternalMapState<N, UK, UV> {
+
+	private static final Logger LOG = LoggerFactory.getLogger(RocksDBMapState.class);
+
+	/** Serializer for the user keys. */
+	private final TypeSerializer<UK> userKeySerializer;
+
+	/** Serializer for the user values. */
+	private final TypeSerializer<UV> userValueSerializer;
+
+	/**
+	 * We disable writes to the write-ahead-log here. We can't have these in the base class
+	 * because JNI segfaults for some reason if they are.
+	 */
+	private final WriteOptions writeOptions;
+
+	/**
+	 * Creates a new {@code RocksDBMapState}.
+	 *
+	 * @param columnFamily The RocksDB column family that holds the entries of this state.
+	 * @param namespaceSerializer The serializer for the namespace.
+	 * @param stateDesc The state identifier for the state.
+	 * @param backend The keyed state backend that owns this state.
+	 */
+	public RocksDBMapState(ColumnFamilyHandle columnFamily,
+			TypeSerializer<N> namespaceSerializer,
+			MapStateDescriptor<UK, UV> stateDesc,
+			RocksDBKeyedStateBackend<K> backend) {
+
+		super(columnFamily, namespaceSerializer, stateDesc, backend);
+
+		this.userKeySerializer = stateDesc.getKeySerializer();
+		this.userValueSerializer = stateDesc.getValueSerializer();
+
+		writeOptions = new WriteOptions();
+		writeOptions.setDisableWAL(true);
+	}
+
+	// ------------------------------------------------------------------------
+	//  MapState Implementation
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Returns the value mapped to {@code userKey} under the current key and namespace, or
+	 * {@code null} if there is no such mapping or the stored value is {@code null}.
+	 */
+	@Override
+	public UV get(UK userKey) throws IOException, RocksDBException {
+		byte[] rawKeyBytes = serializeUserKeyWithCurrentKeyAndNamespace(userKey);
+		byte[] rawValueBytes = backend.db.get(columnFamily, rawKeyBytes);
+
+		return (rawValueBytes == null ? null : deserializeUserValue(rawValueBytes));
+	}
+
+	/** Maps {@code userKey} to {@code userValue} (which may be {@code null}) under the current key and namespace. */
+	@Override
+	public void put(UK userKey, UV userValue) throws IOException, RocksDBException {
+
+		byte[] rawKeyBytes = serializeUserKeyWithCurrentKeyAndNamespace(userKey);
+		byte[] rawValueBytes = serializeUserValue(userValue);
+
+		backend.db.put(columnFamily, writeOptions, rawKeyBytes, rawValueBytes);
+	}
+
+	/** Puts every mapping of {@code map}; a {@code null} map is ignored. */
+	@Override
+	public void putAll(Map<UK, UV> map) throws IOException, RocksDBException {
+		if (map == null) {
+			return;
+		}
+
+		for (Map.Entry<UK, UV> entry : map.entrySet()) {
+			put(entry.getKey(), entry.getValue());
+		}
+	}
+
+	/** Removes the mapping of {@code userKey} under the current key and namespace. */
+	@Override
+	public void remove(UK userKey) throws IOException, RocksDBException {
+		byte[] rawKeyBytes = serializeUserKeyWithCurrentKeyAndNamespace(userKey);
+
+		backend.db.remove(columnFamily, writeOptions, rawKeyBytes);
+	}
+
+	/**
+	 * Returns true if RocksDB holds an entry for {@code userKey} under the current key and
+	 * namespace. This is also true when the key is mapped to {@code null}.
+	 */
+	@Override
+	public boolean contains(UK userKey) throws IOException, RocksDBException {
+		byte[] rawKeyBytes = serializeUserKeyWithCurrentKeyAndNamespace(userKey);
+		byte[] rawValueBytes = backend.db.get(columnFamily, rawKeyBytes);
+
+		return (rawValueBytes != null);
+	}
+
+	/** Counts the mappings under the current key and namespace by iterating over all of them. */
+	@Override
+	public int size() throws IOException, RocksDBException {
+		Iterator<Map.Entry<UK, UV>> iterator = iterator();
+
+		int count = 0;
+		while (iterator.hasNext()) {
+			count++;
+			iterator.next();
+		}
+
+		return count;
+	}
+
+	/**
+	 * Returns all mappings under the current key and namespace, or {@code null} if there are none.
+	 *
+	 * <p>Each call to {@link Iterable#iterator()} on the returned iterable starts a fresh scan
+	 * of the current key and namespace, so the result can be iterated more than once.
+	 */
+	@Override
+	public Iterable<Map.Entry<UK, UV>> entries() throws IOException, RocksDBException {
+		// Capture the prefix now: the backend's current key may change before the iterable is used.
+		final byte[] prefixBytes = serializeCurrentKeyAndNamespace();
+
+		// Return null to make the behavior consistent with other states.
+		if (!newEntryIterator(prefixBytes).hasNext()) {
+			return null;
+		}
+
+		return new Iterable<Map.Entry<UK, UV>>() {
+			@Override
+			public Iterator<Map.Entry<UK, UV>> iterator() {
+				return newEntryIterator(prefixBytes);
+			}
+		};
+	}
+
+	/** Returns a re-iterable view of the user keys under the current key and namespace. */
+	@Override
+	public Iterable<UK> keys() throws IOException, RocksDBException {
+		final byte[] prefixBytes = serializeCurrentKeyAndNamespace();
+
+		return new Iterable<UK>() {
+			@Override
+			public Iterator<UK> iterator() {
+				return new RocksDBMapIterator<UK>(backend.db, prefixBytes) {
+					@Override
+					public UK next() {
+						RocksDBMapEntry entry = nextEntry();
+						return (entry == null ? null : entry.getKey());
+					}
+				};
+			}
+		};
+	}
+
+	/** Returns a re-iterable view of the user values under the current key and namespace. */
+	@Override
+	public Iterable<UV> values() throws IOException, RocksDBException {
+		final byte[] prefixBytes = serializeCurrentKeyAndNamespace();
+
+		return new Iterable<UV>() {
+			@Override
+			public Iterator<UV> iterator() {
+				return new RocksDBMapIterator<UV>(backend.db, prefixBytes) {
+					@Override
+					public UV next() {
+						RocksDBMapEntry entry = nextEntry();
+						return (entry == null ? null : entry.getValue());
+					}
+				};
+			}
+		};
+	}
+
+	/** Returns an iterator over the mappings under the current key and namespace. */
+	@Override
+	public Iterator<Map.Entry<UK, UV>> iterator() throws IOException, RocksDBException {
+		return newEntryIterator(serializeCurrentKeyAndNamespace());
+	}
+
+	/**
+	 * Removes all mappings under the current key and namespace. This is best-effort: errors are
+	 * logged as warnings and not rethrown.
+	 */
+	@Override
+	public void clear() {
+		try {
+			Iterator<Map.Entry<UK, UV>> iterator = iterator();
+
+			while (iterator.hasNext()) {
+				iterator.next();
+				iterator.remove();
+			}
+		} catch (Exception e) {
+			LOG.warn("Error while cleaning the state.", e);
+		}
+	}
+
+	/**
+	 * Returns the serialized map for the given serialized key and namespace, or {@code null} if
+	 * there are no mappings for them.
+	 *
+	 * @param serializedKeyAndNamespace The key and namespace, serialized by {@link KvStateRequestSerializer}.
+	 * @return The mappings serialized with {@link KvStateRequestSerializer#serializeMap}, or {@code null}.
+	 */
+	@Override
+	@SuppressWarnings("unchecked")
+	public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception {
+		Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace");
+
+		//TODO make KvStateRequestSerializer key-group aware to save this round trip and key-group computation
+		Tuple2<K, N> des = KvStateRequestSerializer.deserializeKeyAndNamespace(
+				serializedKeyAndNamespace,
+				backend.getKeySerializer(),
+				namespaceSerializer);
+
+		int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(des.f0, backend.getNumberOfKeyGroups());
+
+		ByteArrayOutputStreamWithPos outputStream = new ByteArrayOutputStreamWithPos(128);
+		DataOutputViewStreamWrapper outputView = new DataOutputViewStreamWrapper(outputStream);
+		writeKeyWithGroupAndNamespace(keyGroup, des.f0, des.f1, outputStream, outputView);
+		final byte[] keyPrefixBytes = outputStream.toByteArray();
+
+		final Iterator<Map.Entry<UK, UV>> iterator = newEntryIterator(keyPrefixBytes);
+
+		// Return null to make the behavior consistent with other backends
+		if (!iterator.hasNext()) {
+			return null;
+		}
+
+		// The iterable is consumed exactly once by serializeMap, so sharing one iterator is safe here.
+		return KvStateRequestSerializer.serializeMap(new Iterable<Map.Entry<UK, UV>>() {
+			@Override
+			public Iterator<Map.Entry<UK, UV>> iterator() {
+				return iterator;
+			}
+		}, userKeySerializer, userValueSerializer);
+	}
+
+	/** Creates an iterator over all entries whose RocksDB key starts with {@code prefixBytes}. */
+	private Iterator<Map.Entry<UK, UV>> newEntryIterator(byte[] prefixBytes) {
+		return new RocksDBMapIterator<Map.Entry<UK, UV>>(backend.db, prefixBytes) {
+			@Override
+			public Map.Entry<UK, UV> next() {
+				return nextEntry();
+			}
+		};
+	}
+
+	// ------------------------------------------------------------------------
+	//  Serialization Methods
+	// ------------------------------------------------------------------------
+
+	/** Serializes the current key group, key and namespace, i.e. the prefix of all entries of this map. */
+	private byte[] serializeCurrentKeyAndNamespace() throws IOException {
+		writeCurrentKeyWithGroupAndNamespace();
+
+		return keySerializationStream.toByteArray();
+	}
+
+	/** Serializes the current key group, key and namespace followed by the given user key. */
+	private byte[] serializeUserKeyWithCurrentKeyAndNamespace(UK userKey) throws IOException {
+		writeCurrentKeyWithGroupAndNamespace();
+		userKeySerializer.serialize(userKey, keySerializationDataOutputView);
+
+		return keySerializationStream.toByteArray();
+	}
+
+	/** Serializes a user value as a null-flag followed by the value bytes (if non-null). */
+	private byte[] serializeUserValue(UV userValue) throws IOException {
+		keySerializationStream.reset();
+
+		if (userValue == null) {
+			keySerializationDataOutputView.writeBoolean(true);
+		} else {
+			keySerializationDataOutputView.writeBoolean(false);
+			userValueSerializer.serialize(userValue, keySerializationDataOutputView);
+		}
+
+		return keySerializationStream.toByteArray();
+	}
+
+	/** Skips the key group, key and namespace in {@code rawKeyBytes} and deserializes the user key. */
+	private UK deserializeUserKey(byte[] rawKeyBytes) throws IOException {
+		ByteArrayInputStreamWithPos bais = new ByteArrayInputStreamWithPos(rawKeyBytes);
+		DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais);
+
+		readKeyWithGroupAndNamespace(bais, in);
+
+		return userKeySerializer.deserialize(in);
+	}
+
+	/** Reads the null-flag and, if the value is not null, deserializes the user value. */
+	private UV deserializeUserValue(byte[] rawValueBytes) throws IOException {
+		ByteArrayInputStreamWithPos bais = new ByteArrayInputStreamWithPos(rawValueBytes);
+		DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais);
+
+		boolean isNull = in.readBoolean();
+
+		return isNull ? null : userValueSerializer.deserialize(in);
+	}
+
+	// ------------------------------------------------------------------------
+	//  Internal Classes
+	// ------------------------------------------------------------------------
+
+	/** A map entry in RocksDBMapState, backed by the raw RocksDB key and value bytes. */
+	private class RocksDBMapEntry implements Map.Entry<UK, UV> {
+		private final RocksDB db;
+
+		/** The raw bytes of the key stored in RocksDB. Each user key is stored in RocksDB
+		 * with the format #KeyGroup#Key#Namespace#UserKey. */
+		private final byte[] rawKeyBytes;
+
+		/** The raw bytes of the value stored in RocksDB. */
+		private byte[] rawValueBytes;
+
+		/** True if the entry has been deleted through {@link #remove()}. */
+		private boolean deleted;
+
+		/** The user key and value. The deserialization is performed lazily, i.e. the key
+		 * and the value are deserialized only when they are accessed. */
+		private UK userKey = null;
+		private UV userValue = null;
+
+		RocksDBMapEntry(final RocksDB db, final byte[] rawKeyBytes, final byte[] rawValueBytes) {
+			this.db = db;
+
+			this.rawKeyBytes = rawKeyBytes;
+			this.rawValueBytes = rawValueBytes;
+			this.deleted = false;
+		}
+
+		/** Marks the entry as deleted and removes it from RocksDB. */
+		public void remove() {
+			deleted = true;
+			rawValueBytes = null;
+
+			try {
+				db.remove(columnFamily, writeOptions, rawKeyBytes);
+			} catch (RocksDBException e) {
+				throw new RuntimeException("Error while removing data from RocksDB.", e);
+			}
+		}
+
+		@Override
+		public UK getKey() {
+			if (userKey == null) {
+				try {
+					userKey = deserializeUserKey(rawKeyBytes);
+				} catch (IOException e) {
+					throw new RuntimeException("Error while deserializing the user key.", e);
+				}
+			}
+
+			return userKey;
+		}
+
+		/** Returns the user value, or {@code null} if the entry has been deleted. */
+		@Override
+		public UV getValue() {
+			if (deleted) {
+				return null;
+			} else {
+				if (userValue == null) {
+					try {
+						userValue = deserializeUserValue(rawValueBytes);
+					} catch (IOException e) {
+						throw new RuntimeException("Error while deserializing the user value.", e);
+					}
+				}
+
+				return userValue;
+			}
+		}
+
+		/**
+		 * Writes {@code value} to RocksDB for this entry's key and returns the previous value.
+		 *
+		 * @throws IllegalStateException If the entry has been deleted.
+		 */
+		@Override
+		public UV setValue(UV value) {
+			if (deleted) {
+				throw new IllegalStateException("The value has already been deleted.");
+			}
+
+			UV oldValue = getValue();
+
+			try {
+				userValue = value;
+				rawValueBytes = serializeUserValue(value);
+
+				db.put(columnFamily, writeOptions, rawKeyBytes, rawValueBytes);
+			} catch (IOException | RocksDBException e) {
+				throw new RuntimeException("Error while putting data into RocksDB.", e);
+			}
+
+			return oldValue;
+		}
+	}
+
+	/**
+	 * Iterator over all RocksDB entries that share a key prefix, i.e. all mappings of one
+	 * key/namespace pair. Entries are loaded lazily in batches into an in-memory cache;
+	 * subclasses implement {@link #next()} on top of {@link #nextEntry()}.
+	 */
+	private abstract class RocksDBMapIterator<T> implements Iterator<T> {
+
+		/** Number of entries loaded by the first RocksDB scan. */
+		static final int CACHE_SIZE_BASE = 1;
+
+		/** Upper bound on the number of entries loaded by a single RocksDB scan. */
+		static final int CACHE_SIZE_LIMIT = 128;
+
+		/** The db where data resides. */
+		private final RocksDB db;
+
+		/**
+		 * The prefix bytes of the key being accessed. All entries under the same key
+		 * have the same prefix, hence we can stop iterating once coming across an
+		 * entry with a different prefix.
+		 */
+		private final byte[] keyPrefixBytes;
+
+		/**
+		 * True if all entries have been loaded, i.e. the RocksDB iterator became invalid or
+		 * came across an entry with a different prefix.
+		 */
+		private boolean expired = false;
+
+		/** An in-memory cache for the entries in RocksDB. */
+		private final ArrayList<RocksDBMapEntry> cacheEntries = new ArrayList<>();
+		private int cacheIndex = 0;
+
+		/**
+		 * The entry returned by the most recent {@link #nextEntry()} call, or {@code null} if none.
+		 * Tracked separately from the cache because {@link #hasNext()} may reload (and clear) the
+		 * cache between {@code next()} and {@code remove()}.
+		 */
+		private RocksDBMapEntry lastReturnedEntry = null;
+
+		RocksDBMapIterator(final RocksDB db, final byte[] keyPrefixBytes) {
+			this.db = db;
+			this.keyPrefixBytes = keyPrefixBytes;
+		}
+
+		@Override
+		public boolean hasNext() {
+			loadCache();
+
+			return (cacheIndex < cacheEntries.size());
+		}
+
+		/**
+		 * Removes the entry returned by the last {@code next()} call from RocksDB.
+		 *
+		 * @throws IllegalStateException If no entry has been returned yet, or it was already removed.
+		 */
+		@Override
+		public void remove() {
+			if (lastReturnedEntry == null || lastReturnedEntry.deleted) {
+				throw new IllegalStateException("The remove operation must be called after an valid next operation.");
+			}
+
+			lastReturnedEntry.remove();
+		}
+
+		/** Returns the next entry, or {@code null} once all entries have been returned. */
+		final RocksDBMapEntry nextEntry() {
+			loadCache();
+
+			if (cacheIndex == cacheEntries.size()) {
+				if (!expired) {
+					throw new IllegalStateException();
+				}
+
+				lastReturnedEntry = null;
+				return null;
+			}
+
+			RocksDBMapEntry entry = cacheEntries.get(cacheIndex);
+			cacheIndex++;
+			lastReturnedEntry = entry;
+
+			return entry;
+		}
+
+		/** Loads the next batch of entries once all cached entries have been consumed. */
+		private void loadCache() {
+			if (cacheIndex > cacheEntries.size()) {
+				throw new IllegalStateException();
+			}
+
+			// Load only when every cached entry has been consumed and unread entries may remain.
+			if (cacheIndex < cacheEntries.size() || expired) {
+				return;
+			}
+
+			/*
+			 * The first load starts at the prefix bytes. Later loads restart from the last cached
+			 * entry, doubling the batch size each time up to CACHE_SIZE_LIMIT.
+			 */
+			RocksDBMapEntry lastEntry = cacheEntries.isEmpty() ? null : cacheEntries.get(cacheEntries.size() - 1);
+			byte[] startBytes = (lastEntry == null ? keyPrefixBytes : lastEntry.rawKeyBytes);
+			int numEntries = (lastEntry == null ? CACHE_SIZE_BASE : Math.min(cacheEntries.size() * 2, CACHE_SIZE_LIMIT));
+
+			cacheEntries.clear();
+			cacheIndex = 0;
+
+			RocksIterator iterator = db.newIterator(columnFamily);
+			try {
+				iterator.seek(startBytes);
+
+				/*
+				 * If the last cached entry still exists in RocksDB, the seek lands on it; skip it
+				 * so it is not returned twice. Comparing keys (rather than trusting the entry's
+				 * deleted flag) also covers keys removed via MapState#remove during iteration.
+				 */
+				if (lastEntry != null && iterator.isValid() && Arrays.equals(iterator.key(), lastEntry.rawKeyBytes)) {
+					iterator.next();
+				}
+
+				while (true) {
+					if (!iterator.isValid() || !underSameKey(iterator.key())) {
+						expired = true;
+						break;
+					}
+
+					if (cacheEntries.size() >= numEntries) {
+						break;
+					}
+
+					RocksDBMapEntry entry = new RocksDBMapEntry(db, iterator.key(), iterator.value());
+					cacheEntries.add(entry);
+
+					iterator.next();
+				}
+			} finally {
+				// Release the native iterator even if RocksDB access fails.
+				iterator.close();
+			}
+		}
+
+		/** Returns true if {@code rawKeyBytes} starts with {@link #keyPrefixBytes}. */
+		private boolean underSameKey(byte[] rawKeyBytes) {
+			if (rawKeyBytes.length < keyPrefixBytes.length) {
+				return false;
+			}
+
+			for (int i = 0; i < keyPrefixBytes.length; ++i) {
+				if (rawKeyBytes[i] != keyPrefixBytes[i]) {
+					return false;
+				}
+			}
+
+			return true;
+		}
+	}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
index 405e390..98ad018 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/RuntimeContext.java
@@ -31,6 +31,8 @@ import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
@@ -387,4 +389,44 @@ public interface RuntimeContext {
*/
@PublicEvolving
<T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value map state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * is composed of user-defined key-value pairs.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private MapState<MyType, Long> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getMapState(
+	 *                 new MapStateDescriptor<>("sum", MyType.class, Long.class));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         return new Tuple2<>(value, state.get(value));
+	 *     }
+	 * });
+	 *
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <UK> The type of the user keys stored in the state.
+	 * @param <UV> The type of the user values stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties);
}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
index 0eafeaa..2538799 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/functions/util/AbstractRuntimeUDFContext.java
@@ -33,6 +33,8 @@ import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
@@ -214,4 +216,11 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
+
+	/**
+	 * This context never provides keyed state, so map state is unavailable here.
+	 *
+	 * @param stateProperties Ignored.
+	 * @throws UnsupportedOperationException Always.
+	 */
+	@Override
+	@PublicEvolving
+	public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) {
+		final String message = "This state is only accessible by functions executed on a KeyedStream";
+		throw new UnsupportedOperationException(message);
+	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
index bbb4c67..2187f6c 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/KeyedStateStore.java
@@ -196,4 +196,44 @@ public interface KeyedStateStore {
*/
@PublicEvolving
<T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties);
+
+	/**
+	 * Gets a handle to the system's key/value map state. This state is similar to the state
+	 * accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
+	 * is composed of user-defined key-value pairs.
+	 *
+	 * <p>This state is only accessible if the function is executed on a KeyedStream.
+	 *
+	 * <pre>{@code
+	 * DataStream<MyType> stream = ...;
+	 * KeyedStream<MyType> keyedStream = stream.keyBy("id");
+	 *
+	 * keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
+	 *
+	 *     private MapState<MyType, Long> state;
+	 *
+	 *     public void open(Configuration cfg) {
+	 *         state = getRuntimeContext().getMapState(
+	 *                 new MapStateDescriptor<>("sum", MyType.class, Long.class));
+	 *     }
+	 *
+	 *     public Tuple2<MyType, Long> map(MyType value) {
+	 *         return new Tuple2<>(value, state.get(value));
+	 *     }
+	 * });
+	 *
+	 * }</pre>
+	 *
+	 * @param stateProperties The descriptor defining the properties of the state.
+	 *
+	 * @param <UK> The type of the user keys stored in the state.
+	 * @param <UV> The type of the user values stored in the state.
+	 *
+	 * @return The partitioned state object.
+	 *
+	 * @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
+	 *                                       function (function is not part of a KeyedStream).
+	 */
+	@PublicEvolving
+	<UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties);
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/state/MapState.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/MapState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/MapState.java
new file mode 100644
index 0000000..fa657ef
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/MapState.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * {@link State} interface for partitioned key-value state. The key-value pairs can be
+ * added, updated and retrieved.
+ *
+ * <p>The state is accessed and modified by user functions, and checkpointed consistently
+ * by the system as part of the distributed snapshots.
+ *
+ * <p>The state is only accessible by functions applied on a KeyedDataStream. The key is
+ * automatically supplied by the system, so the function always sees the mappings of the
+ * key of the current element. That way, the system can handle stream and state partitioning
+ * consistently together.
+ *
+ * @param <UK> Type of the keys in the state.
+ * @param <UV> Type of the values in the state.
+ */
+@PublicEvolving
+public interface MapState<UK, UV> extends State {
+
+	/**
+	 * Returns the current value associated with the given key.
+	 *
+	 * @param key The key of the mapping
+	 * @return The value of the mapping with the given key; may be {@code null} if there is
+	 *         no mapping for the key (or the key is mapped to {@code null})
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	UV get(UK key) throws Exception;
+
+	/**
+	 * Associates a new value with the given key.
+	 *
+	 * @param key The key of the mapping
+	 * @param value The new value of the mapping
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	void put(UK key, UV value) throws Exception;
+
+	/**
+	 * Copies all of the mappings from the given map into the state.
+	 *
+	 * @param map The mappings to be stored in this state
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	void putAll(Map<UK, UV> map) throws Exception;
+
+	/**
+	 * Deletes the mapping of the given key.
+	 *
+	 * @param key The key of the mapping
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	void remove(UK key) throws Exception;
+
+	/**
+	 * Returns whether there exists a mapping for the given key.
+	 *
+	 * @param key The key of the mapping
+	 * @return True if there exists a mapping whose key equals the given key
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	boolean contains(UK key) throws Exception;
+
+	/**
+	 * Returns the number of mappings in the state.
+	 *
+	 * @return The number of mappings in the state.
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	int size() throws Exception;
+
+	/**
+	 * Returns all the mappings in the state.
+	 *
+	 * <p>Callers should be prepared for {@code null}: some implementations return {@code null}
+	 * instead of an empty iterable when the state contains no mappings.
+	 *
+	 * @return An iterable view of all the key-value pairs in the state.
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	Iterable<Map.Entry<UK, UV>> entries() throws Exception;
+
+	/**
+	 * Returns all the keys in the state.
+	 *
+	 * @return An iterable view of all the keys in the state.
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	Iterable<UK> keys() throws Exception;
+
+	/**
+	 * Returns all the values in the state.
+	 *
+	 * @return An iterable view of all the values in the state.
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	Iterable<UV> values() throws Exception;
+
+	/**
+	 * Iterates over all the mappings in the state.
+	 *
+	 * @return An iterator over all the mappings in the state.
+	 *
+	 * @throws Exception Thrown if the system cannot access the state.
+	 */
+	Iterator<Map.Entry<UK, UV>> iterator() throws Exception;
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/state/MapStateDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/MapStateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/MapStateDescriptor.java
new file mode 100644
index 0000000..d4a49f8
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/MapStateDescriptor.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+
+import java.util.Map;
+
+/**
+ * A {@link StateDescriptor} for {@link MapState}. This can be used to create state where the type
+ * is a map that can be updated and iterated over.
+ *
+ * <p>Using {@code MapState} is typically more efficient than manually maintaining a map in a
+ * {@link ValueState}, because the backing implementation can support efficient updates, rather than
+ * replacing the full map on write.
+ *
+ * <p>To create keyed map state (on a KeyedStream), use
+ * {@link org.apache.flink.api.common.functions.RuntimeContext#getMapState(MapStateDescriptor)}.
+ *
+ * @param <UK> The type of the keys that can be added to the map state.
+ * @param <UV> The type of the values that can be added to the map state.
+ */
+@PublicEvolving
+public class MapStateDescriptor<UK, UV> extends StateDescriptor<MapState<UK, UV>, Map<UK, UV>> {
+
+	/**
+	 * Create a new {@code MapStateDescriptor} with the given name and the given type serializers.
+	 *
+	 * @param name The name of the {@code MapStateDescriptor}.
+	 * @param keySerializer The type serializer for the keys in the state.
+	 * @param valueSerializer The type serializer for the values in the state.
+	 */
+	public MapStateDescriptor(String name, TypeSerializer<UK> keySerializer, TypeSerializer<UV> valueSerializer) {
+		super(name, new MapSerializer<>(keySerializer, valueSerializer), null);
+	}
+
+	/**
+	 * Create a new {@code MapStateDescriptor} with the given name and the given type information.
+	 *
+	 * @param name The name of the {@code MapStateDescriptor}.
+	 * @param keyTypeInfo The type information for the keys in the state.
+	 * @param valueTypeInfo The type information for the values in the state.
+	 */
+	public MapStateDescriptor(String name, TypeInformation<UK> keyTypeInfo, TypeInformation<UV> valueTypeInfo) {
+		super(name, new MapTypeInfo<>(keyTypeInfo, valueTypeInfo), null);
+	}
+
+	/**
+	 * Create a new {@code MapStateDescriptor} with the given name and the given key and value classes.
+	 *
+	 * <p>If this constructor fails (because it is not possible to describe the type via a class),
+	 * consider using the {@link #MapStateDescriptor(String, TypeInformation, TypeInformation)} constructor.
+	 *
+	 * @param name The name of the {@code MapStateDescriptor}.
+	 * @param keyClass The class of the type of keys in the state.
+	 * @param valueClass The class of the type of values in the state.
+	 */
+	public MapStateDescriptor(String name, Class<UK> keyClass, Class<UV> valueClass) {
+		super(name, new MapTypeInfo<>(keyClass, valueClass), null);
+	}
+
+	@Override
+	public MapState<UK, UV> bind(StateBinder stateBinder) throws Exception {
+		return stateBinder.createMapState(this);
+	}
+
+	@Override
+	public Type getType() {
+		return Type.MAP;
+	}
+
+	/**
+	 * Gets the serializer for the keys in the state.
+	 *
+	 * @return The serializer for the keys in the state.
+	 * @throws IllegalStateException If the state's serializer is not a {@link MapSerializer}.
+	 */
+	public TypeSerializer<UK> getKeySerializer() {
+		return getMapSerializer().getKeySerializer();
+	}
+
+	/**
+	 * Gets the serializer for the values in the state.
+	 *
+	 * @return The serializer for the values in the state.
+	 * @throws IllegalStateException If the state's serializer is not a {@link MapSerializer}.
+	 */
+	public TypeSerializer<UV> getValueSerializer() {
+		return getMapSerializer().getValueSerializer();
+	}
+
+	/**
+	 * Returns the state's serializer as a {@link MapSerializer}.
+	 *
+	 * @throws IllegalStateException If the state's serializer is not a {@link MapSerializer}.
+	 */
+	private MapSerializer<UK, UV> getMapSerializer() {
+		final TypeSerializer<Map<UK, UV>> rawSerializer = getSerializer();
+		if (!(rawSerializer instanceof MapSerializer)) {
+			throw new IllegalStateException("Unexpected serializer type.");
+		}
+
+		// The instanceof check guards the raw type; the element types follow from
+		// the serializer being a TypeSerializer<Map<UK, UV>>.
+		@SuppressWarnings("unchecked")
+		final MapSerializer<UK, UV> mapSerializer = (MapSerializer<UK, UV>) rawSerializer;
+		return mapSerializer;
+	}
+
+	@Override
+	public int hashCode() {
+		int result = serializer.hashCode();
+		result = 31 * result + name.hashCode();
+		return result;
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o) {
+			return true;
+		}
+
+		if (o == null || getClass() != o.getClass()) {
+			return false;
+		}
+
+		MapStateDescriptor<?, ?> that = (MapStateDescriptor<?, ?>) o;
+		return serializer.equals(that.serializer) && name.equals(that.name);
+	}
+
+	@Override
+	public String toString() {
+		return "MapStateDescriptor{" +
+				"name=" + name +
+				", serializer=" + serializer +
+				'}';
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/state/StateBinder.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateBinder.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateBinder.java
index 08dfc90..9df7a47 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateBinder.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateBinder.java
@@ -70,4 +70,13 @@ public interface StateBinder {
* @param <ACC> Type of the value in the state
*/
<T, ACC> FoldingState<T, ACC> createFoldingState(FoldingStateDescriptor<T, ACC> stateDesc) throws Exception;
+
+ /**
+ * Creates and returns a new {@link MapState}.
+ *
+ * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+ * @param <UK> Type of the keys in the state
+ * @param <UV> Type of the values in the state
+ */
+ <UK, UV> MapState<UK, UV> createMapState(MapStateDescriptor<UK, UV> stateDesc) throws Exception;
}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
index 332e649..a52ea32 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
@@ -55,7 +55,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
*/
// IMPORTANT: Do not change the order of the elements in this enum, ordinal is used in serialization
public enum Type {
- @Deprecated UNKNOWN, VALUE, LIST, REDUCING, FOLDING, AGGREGATING
+ @Deprecated UNKNOWN, VALUE, LIST, REDUCING, FOLDING, AGGREGATING, MAP // new types are appended at the end so existing ordinals stay unchanged
}
private static final long serialVersionUID = 1L;
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MapSerializer.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MapSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MapSerializer.java
new file mode 100644
index 0000000..5e1a3bf
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/MapSerializer.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.typeutils.base;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.HashMap;
+
+/**
+ * A serializer for {@link Map}. The serializer relies on a key serializer and a value serializer
+ * for the serialization of the map's key-value pairs.
+ *
+ * <p>The serialization format for the map is as follows: four bytes for the length of the map,
+ * followed by the serialized representation of each key-value pair. To allow null values, each value
+ * is prefixed by a null marker.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+public class MapSerializer<K, V> extends TypeSerializer<Map<K, V>> {
+
+ private static final long serialVersionUID = -6885593032367050078L;
+
+ /** The serializer for the keys in the map. */
+ private final TypeSerializer<K> keySerializer;
+
+ /** The serializer for the values in the map. */
+ private final TypeSerializer<V> valueSerializer;
+
+ /**
+ * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map.
+ *
+ * @param keySerializer The serializer for the keys in the map
+ * @param valueSerializer The serializer for the values in the map
+ */
+ public MapSerializer(TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer) {
+ this.keySerializer = Preconditions.checkNotNull(keySerializer, "The key serializer cannot be null");
+ this.valueSerializer = Preconditions.checkNotNull(valueSerializer, "The value serializer cannot be null.");
+ }
+
+ // ------------------------------------------------------------------------
+ // MapSerializer specific properties
+ // ------------------------------------------------------------------------
+
+ public TypeSerializer<K> getKeySerializer() {
+ return keySerializer;
+ }
+
+ public TypeSerializer<V> getValueSerializer() {
+ return valueSerializer;
+ }
+
+ // ------------------------------------------------------------------------
+ // Type Serializer implementation
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Map<K, V>> duplicate() {
+ TypeSerializer<K> duplicateKeySerializer = keySerializer.duplicate();
+ TypeSerializer<V> duplicateValueSerializer = valueSerializer.duplicate();
+ // if both nested serializers returned themselves (no copy needed), this serializer can be shared as well
+ return (duplicateKeySerializer == keySerializer && duplicateValueSerializer == valueSerializer) ? this : new MapSerializer<K, V>(duplicateKeySerializer, duplicateValueSerializer);
+ }
+
+ @Override
+ public Map<K, V> createInstance() {
+ return new HashMap<>();
+ }
+
+ @Override
+ public Map<K, V> copy(Map<K, V> from) {
+ Map<K, V> newMap = new HashMap<>(from.size());
+
+ for (Map.Entry<K, V> entry : from.entrySet()) {
+ K newKey = keySerializer.copy(entry.getKey());
+ V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue());
+
+ newMap.put(newKey, newValue);
+ }
+
+ return newMap;
+ }
+
+ @Override
+ public Map<K, V> copy(Map<K, V> from, Map<K, V> reuse) {
+ return copy(from);
+ }
+
+ @Override
+ public int getLength() {
+ return -1; // var length
+ }
+
+ @Override
+ public void serialize(Map<K, V> map, DataOutputView target) throws IOException {
+ final int size = map.size();
+ target.writeInt(size);
+
+ for (Map.Entry<K, V> entry : map.entrySet()) {
+ keySerializer.serialize(entry.getKey(), target);
+
+ if (entry.getValue() == null) {
+ target.writeBoolean(true);
+ } else {
+ target.writeBoolean(false);
+ valueSerializer.serialize(entry.getValue(), target);
+ }
+ }
+ }
+
+ @Override
+ public Map<K, V> deserialize(DataInputView source) throws IOException {
+ final int size = source.readInt();
+
+ final Map<K, V> map = new HashMap<>(size);
+ for (int i = 0; i < size; ++i) {
+ K key = keySerializer.deserialize(source);
+
+ boolean isNull = source.readBoolean();
+ V value = isNull ? null : valueSerializer.deserialize(source);
+
+ map.put(key, value);
+ }
+
+ return map;
+ }
+
+ @Override
+ public Map<K, V> deserialize(Map<K, V> reuse, DataInputView source) throws IOException {
+ return deserialize(source);
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws IOException {
+ final int size = source.readInt();
+ target.writeInt(size);
+
+ for (int i = 0; i < size; ++i) {
+ keySerializer.copy(source, target);
+
+ boolean isNull = source.readBoolean();
+ target.writeBoolean(isNull);
+
+ if (!isNull) {
+ valueSerializer.copy(source, target);
+ }
+ }
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj == this ||
+ (obj != null && obj.getClass() == getClass() &&
+ keySerializer.equals(((MapSerializer<?, ?>) obj).getKeySerializer()) &&
+ valueSerializer.equals(((MapSerializer<?, ?>) obj).getValueSerializer()));
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return (obj != null && obj.getClass() == getClass());
+ }
+
+ @Override
+ public int hashCode() {
+ return keySerializer.hashCode() * 31 + valueSerializer.hashCode();
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
new file mode 100644
index 0000000..ca04e0c
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/MapTypeInfo.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.java.typeutils;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Special {@code TypeInformation} used by {@link org.apache.flink.api.common.state.MapStateDescriptor}.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+@PublicEvolving
+public class MapTypeInfo<K, V> extends TypeInformation<Map<K, V>> {
+
+ /** The type information for the keys in the map. */
+ private final TypeInformation<K> keyTypeInfo;
+
+ /** The type information for the values in the map. */
+ private final TypeInformation<V> valueTypeInfo;
+
+ public MapTypeInfo(TypeInformation<K> keyTypeInfo, TypeInformation<V> valueTypeInfo) {
+ this.keyTypeInfo = Preconditions.checkNotNull(keyTypeInfo, "The key type information cannot be null.");
+ this.valueTypeInfo = Preconditions.checkNotNull(valueTypeInfo, "The value type information cannot be null.");
+ }
+
+ public MapTypeInfo(Class<K> keyClass, Class<V> valueClass) {
+ this.keyTypeInfo = of(checkNotNull(keyClass, "The key class cannot be null."));
+ this.valueTypeInfo = of(checkNotNull(valueClass, "The value class cannot be null."));
+ }
+
+ // ------------------------------------------------------------------------
+ // MapTypeInfo specific properties
+ // ------------------------------------------------------------------------
+
+ /**
+ * Gets the type information for the keys in the map.
+ */
+ public TypeInformation<K> getKeyTypeInfo() {
+ return keyTypeInfo;
+ }
+
+ /**
+ * Gets the type information for the values in the map.
+ */
+ public TypeInformation<V> getValueTypeInfo() {
+ return valueTypeInfo;
+ }
+
+ // ------------------------------------------------------------------------
+ // TypeInformation implementation
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean isBasicType() {
+ return false;
+ }
+
+ @Override
+ public boolean isTupleType() {
+ return false;
+ }
+
+ @Override
+ public int getArity() {
+ return 0;
+ }
+
+ @Override
+ public int getTotalFields() {
+ return 1; // not a composite type: the map as a whole counts as a single logical field
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public Class<Map<K, V>> getTypeClass() {
+ return (Class<Map<K, V>>)(Class<?>)Map.class;
+ }
+
+ @Override
+ public boolean isKeyType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<Map<K, V>> createSerializer(ExecutionConfig config) {
+ TypeSerializer<K> keyTypeSerializer = keyTypeInfo.createSerializer(config);
+ TypeSerializer<V> valueTypeSerializer = valueTypeInfo.createSerializer(config);
+
+ return new MapSerializer<>(keyTypeSerializer, valueTypeSerializer);
+ }
+
+ @Override
+ public String toString() {
+ return "Map<" + keyTypeInfo + ", " + valueTypeInfo + ">";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == this) {
+ return true;
+ } else if (obj instanceof MapTypeInfo) {
+ // wildcard cast: no unchecked conversion is needed just to compare the fields
+ MapTypeInfo<?, ?> other = (MapTypeInfo<?, ?>) obj;
+
+ return (other.canEqual(this) &&
+ keyTypeInfo.equals(other.keyTypeInfo) && valueTypeInfo.equals(other.valueTypeInfo));
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * keyTypeInfo.hashCode() + valueTypeInfo.hashCode();
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return (obj != null && obj.getClass() == getClass());
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/test/java/org/apache/flink/api/common/state/MapStateDescriptorTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/state/MapStateDescriptorTest.java b/flink-core/src/test/java/org/apache/flink/api/common/state/MapStateDescriptorTest.java
new file mode 100644
index 0000000..9d1b105
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/common/state/MapStateDescriptorTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.testutils.CommonTestUtils;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class MapStateDescriptorTest {
+
+ @Test
+ public void testMapStateDescriptorEagerSerializer() throws Exception {
+
+ TypeSerializer<Integer> keySerializer = new KryoSerializer<>(Integer.class, new ExecutionConfig());
+ TypeSerializer<String> valueSerializer = new KryoSerializer<>(String.class, new ExecutionConfig());
+
+ MapStateDescriptor<Integer, String> descriptor =
+ new MapStateDescriptor<>("testName", keySerializer, valueSerializer);
+
+ assertEquals("testName", descriptor.getName());
+ assertNotNull(descriptor.getSerializer());
+ assertTrue(descriptor.getSerializer() instanceof MapSerializer);
+ assertNotNull(descriptor.getKeySerializer());
+ assertEquals(keySerializer, descriptor.getKeySerializer());
+ assertNotNull(descriptor.getValueSerializer());
+ assertEquals(valueSerializer, descriptor.getValueSerializer());
+
+ MapStateDescriptor<Integer, String> restored = CommonTestUtils.createCopySerializable(descriptor);
+
+ assertEquals("testName", restored.getName());
+ assertNotNull(restored.getSerializer());
+ assertTrue(restored.getSerializer() instanceof MapSerializer);
+
+ assertNotNull(restored.getKeySerializer());
+ assertEquals(keySerializer, restored.getKeySerializer());
+ assertNotNull(restored.getValueSerializer());
+ assertEquals(valueSerializer, restored.getValueSerializer());
+ }
+
+ @Test
+ public void testMapStateDescriptorLazySerializer() throws Exception {
+ // register an extra Kryo type so we can check the config reaches the lazily created key serializer
+ ExecutionConfig config = new ExecutionConfig();
+ config.registerKryoType(TaskInfo.class);
+
+ MapStateDescriptor<Path, String> descriptor =
+ new MapStateDescriptor<>("testName", Path.class, String.class);
+
+ try {
+ descriptor.getSerializer();
+ fail("should cause an exception");
+ } catch (IllegalStateException ignored) {} // expected: serializer not initialized yet
+
+ descriptor.initializeSerializerUnlessSet(config);
+
+ assertNotNull(descriptor.getSerializer());
+ assertTrue(descriptor.getSerializer() instanceof MapSerializer);
+
+ assertNotNull(descriptor.getKeySerializer());
+ assertTrue(descriptor.getKeySerializer() instanceof KryoSerializer);
+
+ assertTrue(((KryoSerializer<?>) descriptor.getKeySerializer()).getKryo().getRegistration(TaskInfo.class).getId() > 0);
+
+ assertNotNull(descriptor.getValueSerializer());
+ assertTrue(descriptor.getValueSerializer() instanceof StringSerializer);
+ }
+
+ @Test
+ public void testMapStateDescriptorAutoSerializer() throws Exception {
+
+ MapStateDescriptor<String, Long> descriptor =
+ new MapStateDescriptor<>("testName", String.class, Long.class);
+
+ MapStateDescriptor<String, Long> restored = CommonTestUtils.createCopySerializable(descriptor);
+
+ assertEquals("testName", restored.getName());
+
+ assertNotNull(restored.getSerializer());
+ assertTrue(restored.getSerializer() instanceof MapSerializer);
+
+ assertNotNull(restored.getKeySerializer());
+ assertEquals(StringSerializer.INSTANCE, restored.getKeySerializer());
+ assertNotNull(restored.getValueSerializer());
+ assertEquals(LongSerializer.INSTANCE, restored.getValueSerializer());
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MapSerializerTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MapSerializerTest.java b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MapSerializerTest.java
new file mode 100644
index 0000000..9ce7de1
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/common/typeutils/base/MapSerializerTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.typeutils.base;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.TreeMap;
+
+/**
+ * A test for the {@link MapSerializer}.
+ */
+public class MapSerializerTest extends SerializerTestBase<Map<Long, String>> {
+
+ @Override
+ protected TypeSerializer<Map<Long, String>> createSerializer() {
+ return new MapSerializer<>(LongSerializer.INSTANCE, StringSerializer.INSTANCE);
+ }
+
+ @Override
+ protected int getLength() {
+ return -1;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ protected Class<Map<Long, String>> getTypeClass() {
+ return (Class<Map<Long, String>>) (Class<?>) Map.class;
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ @Override
+ protected Map<Long, String>[] getTestData() {
+ final Random rnd = new Random(123654789);
+
+ // empty maps
+ final Map<Long, String> map1 = Collections.emptyMap();
+ final Map<Long, String> map2 = new HashMap<>();
+ final Map<Long, String> map3 = new TreeMap<>();
+
+ // single element maps
+ final Map<Long, String> map4 = Collections.singletonMap(0L, "hello");
+ final Map<Long, String> map5 = new HashMap<>();
+ map5.put(12345L, "12345L");
+ final Map<Long, String> map6 = new TreeMap<>();
+ map6.put(777888L, "777888L");
+
+ // longer maps: draw each size once (re-drawing the bound per iteration would keep the maps short)
+ final Map<Long, String> map7 = new HashMap<>();
+ for (int i = 0, size = rnd.nextInt(200); i < size; i++) {
+ map7.put(rnd.nextLong(), Long.toString(rnd.nextLong()));
+ }
+
+ final Map<Long, String> map8 = new TreeMap<>();
+ for (int i = 0, size = rnd.nextInt(200); i < size; i++) {
+ map8.put(rnd.nextLong(), Long.toString(rnd.nextLong()));
+ }
+
+ // null-value maps
+ final Map<Long, String> map9 = Collections.singletonMap(0L, null);
+ final Map<Long, String> map10 = new HashMap<>();
+ map10.put(999L, null);
+ final Map<Long, String> map11 = new TreeMap<>();
+ map11.put(666L, null);
+
+ return (Map<Long, String>[]) new Map[] {
+ map1, map2, map3, map4, map5, map6, map7, map8, map9, map10, map11
+ };
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
index 109d152..7f8eea8 100644
--- a/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
+++ b/flink-fs-tests/src/test/java/org/apache/flink/hdfstests/FileStateBackendTest.java
@@ -118,6 +118,10 @@ public class FileStateBackendTest extends StateBackendTestBase<FsStateBackend> {
@Override
@Test
public void testReducingStateRestoreWithWrongSerializers() {}
+
+ @Override
+ @Test
+ public void testMapStateRestoreWithWrongSerializers() {} // intentionally a no-op, like the other *RestoreWithWrongSerializers overrides in this class
@Test
public void testStateOutputStream() {
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
index 2f32861..bc830e1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializer.java
@@ -36,7 +36,9 @@ import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
/**
* Serialization and deserialization of messages exchanged between
@@ -484,6 +486,71 @@ public final class KvStateRequestSerializer {
return null;
}
}
+
+ /**
+ * Serializes all key-value pairs of the Iterable with the given serializers.
+ *
+ * @param entries Key-value pairs to serialize
+ * @param keySerializer Serializer for UK
+ * @param valueSerializer Serializer for UV
+ * @param <UK> Type of the keys
+ * @param <UV> Type of the values
+ * @return Serialized entries, or <code>null</code> only if entries is <code>null</code> (an empty Iterable is still serialized)
+ * @throws IOException On failure during serialization
+ */
+ public static <UK, UV> byte[] serializeMap(Iterable<Map.Entry<UK, UV>> entries, TypeSerializer<UK> keySerializer, TypeSerializer<UV> valueSerializer) throws IOException {
+ if (entries == null) {
+ return null;
+ }
+
+ // Per entry: the key, a null marker, then the value only if non-null (the layout deserializeMap reads back)
+ DataOutputSerializer dos = new DataOutputSerializer(32);
+
+ for (Map.Entry<UK, UV> entry : entries) {
+ keySerializer.serialize(entry.getKey(), dos);
+
+ if (entry.getValue() == null) {
+ dos.writeBoolean(true);
+ } else {
+ dos.writeBoolean(false);
+ valueSerializer.serialize(entry.getValue(), dos);
+ }
+ }
+
+ return dos.getCopyOfBuffer();
+ }
+
+ /**
+ * Deserializes all key-value pairs with the given serializers.
+ *
+ * @param serializedValue Serialized value of type {@code Map<UK, UV>}, as written by {@link #serializeMap}
+ * @param keySerializer Serializer for UK
+ * @param valueSerializer Serializer for UV
+ * @param <UK> Type of the key
+ * @param <UV> Type of the value.
+ * @return Deserialized map or <code>null</code> if the serialized value
+ * is <code>null</code>
+ * @throws IOException On failure during deserialization
+ */
+ public static <UK, UV> Map<UK, UV> deserializeMap(byte[] serializedValue, TypeSerializer<UK> keySerializer, TypeSerializer<UV> valueSerializer) throws IOException {
+ if (serializedValue == null) {
+ return null;
+ }
+
+ DataInputDeserializer in = new DataInputDeserializer(serializedValue, 0, serializedValue.length);
+ Map<UK, UV> result = new HashMap<>();
+
+ // Read entries until the buffer is exhausted: key, null marker, then the value if non-null
+ while (in.available() > 0) {
+ UK key = keySerializer.deserialize(in);
+
+ boolean isNull = in.readBoolean();
+ UV value = isNull ? null : valueSerializer.deserialize(in);
+
+ result.put(key, value);
+ }
+ return result;
+ }
// ------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index fe5d1cc..3ed49f1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -25,6 +25,8 @@ import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
@@ -39,6 +41,7 @@ import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.runtime.state.internal.InternalFoldingState;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.runtime.state.internal.InternalReducingState;
import org.apache.flink.runtime.state.internal.InternalValueState;
import org.apache.flink.util.Preconditions;
@@ -189,6 +192,20 @@ public abstract class AbstractKeyedStateBackend<K>
FoldingStateDescriptor<T, ACC> stateDesc) throws Exception;
/**
+ * Creates and returns a new {@link MapState}.
+ *
+ * @param namespaceSerializer TypeSerializer for the state namespace.
+ * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+ *
+ * @param <N> The type of the namespace.
+ * @param <UK> Type of the keys in the state
+ * @param <UV> Type of the values in the state
+ */
+ protected abstract <N, UK, UV> InternalMapState<N, UK, UV> createMapState(
+ TypeSerializer<N> namespaceSerializer,
+ MapStateDescriptor<UK, UV> stateDesc) throws Exception;
+
+ /**
* @see KeyedStateBackend
*/
@Override
@@ -285,12 +302,16 @@ public abstract class AbstractKeyedStateBackend<K>
AggregatingStateDescriptor<T, ACC, R> stateDesc) throws Exception {
return AbstractKeyedStateBackend.this.createAggregatingState(namespaceSerializer, stateDesc);
}
-
@Override
public <T, ACC> FoldingState<T, ACC> createFoldingState(FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
return AbstractKeyedStateBackend.this.createFoldingState(namespaceSerializer, stateDesc);
}
+
+ @Override
+ public <UK, UV> MapState<UK, UV> createMapState(MapStateDescriptor<UK, UV> stateDesc) throws Exception {
+ return AbstractKeyedStateBackend.this.createMapState(namespaceSerializer, stateDesc); // outer-qualified: an unqualified call would only see this one-argument overload
+ }
});
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
index d8b8aa8..a32cebd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultKeyedStateStore.java
@@ -25,6 +25,8 @@ import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapState;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
@@ -93,6 +95,18 @@ public class DefaultKeyedStateStore implements KeyedStateStore {
}
}
+ @Override
+ public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) {
+ requireNonNull(stateProperties, "The state properties must not be null");
+ try {
+ stateProperties.initializeSerializerUnlessSet(executionConfig); // resolve lazily-set serializers before the lookup
+ final MapState<UK, UV> backendState = getPartitionedState(stateProperties);
+ return new UserFacingMapState<>(backendState);
+ } catch (Exception e) {
+ throw new RuntimeException("Error while getting state", e);
+ }
+ }
+
private <S extends State> S getPartitionedState(StateDescriptor<S, ?> stateDescriptor) throws Exception {
return keyedStateBackend.getPartitionedState(
VoidNamespace.INSTANCE,
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/HashMapSerializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/HashMapSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HashMapSerializer.java
new file mode 100644
index 0000000..61cc58c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HashMapSerializer.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A serializer for {@link HashMap}. The serializer relies on a key serializer and a value serializer
+ * for the serialization of the map's key-value pairs.
+ *
+ * <p>The serialization format for the map is as follows: four bytes for the length of the map,
+ * followed by the serialized representation of each key-value pair. To allow null values, each value
+ * is prefixed by a null marker.
+ *
+ * @param <K> The type of the keys in the map.
+ * @param <V> The type of the values in the map.
+ */
+public class HashMapSerializer<K, V> extends TypeSerializer<HashMap<K, V>> {
+
+ private static final long serialVersionUID = -6885593032367050078L;
+
+ /** The serializer for the keys in the map */
+ private final TypeSerializer<K> keySerializer;
+
+ /** The serializer for the values in the map */
+ private final TypeSerializer<V> valueSerializer;
+
+ /**
+ * Creates a map serializer that uses the given serializers to serialize the key-value pairs in the map.
+ *
+ * @param keySerializer The serializer for the keys in the map
+ * @param valueSerializer The serializer for the values in the map
+ */
+ public HashMapSerializer(TypeSerializer<K> keySerializer, TypeSerializer<V> valueSerializer) {
+ this.keySerializer = Preconditions.checkNotNull(keySerializer, "The key serializer cannot be null");
+ this.valueSerializer = Preconditions.checkNotNull(valueSerializer, "The value serializer cannot be null.");
+ }
+
+ // ------------------------------------------------------------------------
+ // HashMapSerializer specific properties
+ // ------------------------------------------------------------------------
+
+ public TypeSerializer<K> getKeySerializer() {
+ return keySerializer;
+ }
+
+ public TypeSerializer<V> getValueSerializer() {
+ return valueSerializer;
+ }
+
+ // ------------------------------------------------------------------------
+ // Type Serializer implementation
+ // ------------------------------------------------------------------------
+
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ @Override
+ public TypeSerializer<HashMap<K, V>> duplicate() {
+ TypeSerializer<K> duplicateKeySerializer = keySerializer.duplicate();
+ TypeSerializer<V> duplicateValueSerializer = valueSerializer.duplicate();
+
+ return new HashMapSerializer<>(duplicateKeySerializer, duplicateValueSerializer);
+ }
+
+ @Override
+ public HashMap<K, V> createInstance() {
+ return new HashMap<>();
+ }
+
+ @Override
+ public HashMap<K, V> copy(HashMap<K, V> from) {
+ HashMap<K, V> newHashMap = new HashMap<>(from.size());
+
+ for (Map.Entry<K, V> entry : from.entrySet()) {
+ K newKey = keySerializer.copy(entry.getKey());
+ V newValue = entry.getValue() == null ? null : valueSerializer.copy(entry.getValue());
+
+ newHashMap.put(newKey, newValue);
+ }
+
+ return newHashMap;
+ }
+
+ @Override
+ public HashMap<K, V> copy(HashMap<K, V> from, HashMap<K, V> reuse) {
+ return copy(from);
+ }
+
+ @Override
+ public int getLength() {
+ return -1; // var length
+ }
+
+ @Override
+ public void serialize(HashMap<K, V> map, DataOutputView target) throws IOException {
+ final int size = map.size();
+ target.writeInt(size);
+
+ for (Map.Entry<K, V> entry : map.entrySet()) {
+ keySerializer.serialize(entry.getKey(), target);
+
+ if (entry.getValue() == null) {
+ target.writeBoolean(true);
+ } else {
+ target.writeBoolean(false);
+ valueSerializer.serialize(entry.getValue(), target);
+ }
+ }
+ }
+
+ @Override
+ public HashMap<K, V> deserialize(DataInputView source) throws IOException {
+ final int size = source.readInt();
+
+ final HashMap<K, V> map = new HashMap<>(size);
+ for (int i = 0; i < size; ++i) {
+ K key = keySerializer.deserialize(source);
+
+ boolean isNull = source.readBoolean();
+ V value = isNull ? null : valueSerializer.deserialize(source);
+
+ map.put(key, value);
+ }
+
+ return map;
+ }
+
+ @Override
+ public HashMap<K, V> deserialize(HashMap<K, V> reuse, DataInputView source) throws IOException {
+ return deserialize(source);
+ }
+
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws IOException {
+ final int size = source.readInt();
+ target.writeInt(size);
+
+ for (int i = 0; i < size; ++i) {
+ keySerializer.copy(source, target);
+
+ boolean isNull = source.readBoolean();
+ target.writeBoolean(isNull);
+
+ if (!isNull) {
+ valueSerializer.copy(source, target);
+ }
+ }
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj == this ||
+ (obj != null && obj.getClass() == getClass() &&
+ keySerializer.equals(((HashMapSerializer<?, ?>) obj).getKeySerializer()) &&
+ valueSerializer.equals(((HashMapSerializer<?, ?>) obj).getValueSerializer()));
+ }
+
+ @Override
+ public boolean canEqual(Object obj) {
+ return (obj != null && obj.getClass() == getClass());
+ }
+
+ @Override
+ public int hashCode() {
+ return keySerializer.hashCode() * 31 + valueSerializer.hashCode();
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/30c9e2b6/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingMapState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingMapState.java
new file mode 100644
index 0000000..6cddf6d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/UserFacingMapState.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.MapState;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Wrapper around a {@link MapState} that forwards every call to the wrapped state. The only
+ * exception is the iteration methods ({@link #entries()}, {@link #keys()}, {@link #values()},
+ * {@link #iterator()}): when the wrapped state returns {@code null}, they return views of an empty
+ * map instead. All other results, including a {@code null} from {@link #get(Object)}, are passed
+ * through unchanged.
+ *
+ * @param <K> The type of keys in the map state.
+ * @param <V> The type of values in the map state.
+ */
+class UserFacingMapState<K, V> implements MapState<K, V> {
+
+	/** The state that all calls are forwarded to. */
+	private final MapState<K, V> delegate;
+
+	/** Empty map whose views replace {@code null} results of the iteration methods. */
+	private final Map<K, V> emptyMap = Collections.<K, V>emptyMap();
+
+	UserFacingMapState(MapState<K, V> originalState) {
+		this.delegate = originalState;
+	}
+
+	// ------------------------------------------------------------------------
+	//  Pure forwarding
+	// ------------------------------------------------------------------------
+
+	@Override
+	public V get(K key) throws Exception {
+		return delegate.get(key);
+	}
+
+	@Override
+	public boolean contains(K key) throws Exception {
+		return delegate.contains(key);
+	}
+
+	@Override
+	public int size() throws Exception {
+		return delegate.size();
+	}
+
+	@Override
+	public void put(K key, V value) throws Exception {
+		delegate.put(key, value);
+	}
+
+	@Override
+	public void putAll(Map<K, V> value) throws Exception {
+		delegate.putAll(value);
+	}
+
+	@Override
+	public void remove(K key) throws Exception {
+		delegate.remove(key);
+	}
+
+	@Override
+	public void clear() {
+		delegate.clear();
+	}
+
+	// ------------------------------------------------------------------------
+	//  Iteration: a null result from the delegate becomes an empty view
+	// ------------------------------------------------------------------------
+
+	@Override
+	public Iterable<Map.Entry<K, V>> entries() throws Exception {
+		Iterable<Map.Entry<K, V>> entries = delegate.entries();
+		if (entries == null) {
+			return emptyMap.entrySet();
+		}
+		return entries;
+	}
+
+	@Override
+	public Iterable<K> keys() throws Exception {
+		Iterable<K> keys = delegate.keys();
+		if (keys == null) {
+			return emptyMap.keySet();
+		}
+		return keys;
+	}
+
+	@Override
+	public Iterable<V> values() throws Exception {
+		Iterable<V> values = delegate.values();
+		if (values == null) {
+			return emptyMap.values();
+		}
+		return values;
+	}
+
+	@Override
+	public Iterator<Map.Entry<K, V>> iterator() throws Exception {
+		Iterator<Map.Entry<K, V>> entryIterator = delegate.iterator();
+		if (entryIterator == null) {
+			return emptyMap.entrySet().iterator();
+		}
+		return entryIterator;
+	}
+}