Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2020/10/20 06:24:58 UTC

[GitHub] [flink] dianfu commented on a change in pull request #13665: [FLINK-19232][python] Support MapState and MapView in PyFlink.

dianfu commented on a change in pull request #13665:
URL: https://github.com/apache/flink/pull/13665#discussion_r508192677



##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -119,13 +121,331 @@ def clear(self):
         self._internal_state.clear()
 
 
+class CachedMapState(LRUCache):
+
+    def __init__(self, max_entries):
+        super(CachedMapState, self).__init__(max_entries, None)
+        self._all_data_cached = False
+        self._existed_keys = set()
+
+        def on_evict(key, value):
+            self._existed_keys.remove(key)
+            self._all_data_cached = False
+
+        self.set_on_evict(on_evict)
+
+    def set_all_data_cached(self):
+        self._all_data_cached = True
+
+    def is_all_data_cached(self):
+        return self._all_data_cached
+
+    def put(self, key, existed_and_value):
+        if existed_and_value[0]:
+            self._existed_keys.add(key)
+        super(CachedMapState, self).put(key, existed_and_value)
+
+    def get_cached_existed_keys(self):
+        return self._existed_keys
+
+
+class CachingMapStateHandler(object):
+    # GET request flags
+    GET_FLAG = 0
+    ITERATE_FLAG = 1
+    CHECK_EMPTY_FLAG = 2
+    # GET response flags
+    EXIST_FLAG = 0
+    IS_NONE_FLAG = 1
+    NOT_EXIST_FLAG = 2
+    IS_EMPTY_FLAG = 3
+    NOT_EMPTY_FLAG = 4
+    # APPEND request flags
+    DELETE = 0
+    SET_NONE = 1
+    SET_VALUE = 2
+
+    def __init__(self, caching_state_handler, max_cached_map_key_entries):
+        self._state_cache = caching_state_handler._state_cache
+        self._underlying = caching_state_handler._underlying
+        self._context = caching_state_handler._context
+        self._max_cached_map_key_entries = max_cached_map_key_entries
+
+    def _get_cache_token(self):
+        if not self._state_cache.is_cache_enabled():
+            return None
+        if self._context.user_state_cache_token:
+            return self._context.user_state_cache_token
+        else:
+            return self._context.bundle_cache_token
+
+    def blocking_get(self, state_key, map_key, map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if not cache_token:
+            # Cache disabled / no cache token. Can't do a lookup/store in the cache.
+            return self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+        # Cache lookup
+        cache_state_key = self._convert_to_cache_key(state_key)
+        cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+        if cached_map_state is None:
+            existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+            cached_map_state.put(map_key, (existed, value))
+            self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            return existed, value
+        else:
+            cached_value = cached_map_state.get(map_key)
+            if cached_value is None:
+                existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+                cached_map_state.put(map_key, (existed, value))
+                return existed, value
+            else:
+                return cached_value
+
+    def extend(self, state_key, items: List[Tuple[int, Any, Any]], map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is None:
+                cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+                self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            for request_flag, map_key, map_value in items:
+                if request_flag == self.DELETE:
+                    cached_map_state.put(map_key, (False, None))
+                elif request_flag == self.SET_NONE:
+                    cached_map_state.put(map_key, (True, None))
+                elif request_flag == self.SET_VALUE:
+                    cached_map_state.put(map_key, (True, map_value))
+                else:
+                    raise Exception("Unknown flag: " + str(request_flag))
+        self._append_raw(
+            state_key,
+            items,
+            map_key_coder,
+            map_value_coder)
+
+    def check_empty(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is not None:
+                if cached_map_state.is_all_data_cached() and \
+                        len(cached_map_state.get_cached_existed_keys()) == 0:
+                    return True
+                elif len(cached_map_state.get_cached_existed_keys()) > 0:
+                    return False
+        return self._check_empty_raw(state_key)
+
+    def clear(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            cache_key = self._convert_to_cache_key(state_key)
+            self._state_cache.evict(cache_key, cache_token)
+        return self._underlying.clear(state_key)
+
+    def _check_empty_raw(self, state_key):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.CHECK_EMPTY_FLAG)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        if data[0] == self.IS_EMPTY_FLAG:
+            return True
+        elif data[0] == self.NOT_EMPTY_FLAG:
+            return False
+        else:
+            raise Exception("Unknown response flag: " + str(data[0]))
+
+    def _get_raw(self, state_key, map_key, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.GET_FLAG)
+        map_key_coder.encode_to_stream(map_key, output_stream, True)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        input_stream = coder_impl.create_InputStream(data)
+        result_flag = input_stream.read_byte()
+        if result_flag == self.EXIST_FLAG:
+            return True, map_value_coder.decode_from_stream(input_stream, True)
+        elif result_flag == self.IS_NONE_FLAG:
+            return True, None
+        elif result_flag == self.NOT_EXIST_FLAG:
+            return False, None
+        else:
+            raise Exception("Unknown response flag: " + str(result_flag))
+
+    def _append_raw(self, state_key, items, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_bigendian_int32(len(items))
+        for request_flag, map_key, map_value in items:
+            output_stream.write_byte(request_flag)
+            # Not all the coder impls will serialize the length of bytes when we set the "nested"
+            # param to "True", so we need to encode the length of bytes manually.
+            tmp_out = coder_impl.create_OutputStream()
+            map_key_coder.encode_to_stream(map_key, tmp_out, True)
+            serialized_data = tmp_out.get()
+            output_stream.write_bigendian_int32(len(serialized_data))
+            output_stream.write(serialized_data)
+            if request_flag == self.SET_VALUE:
+                tmp_out = coder_impl.create_OutputStream()
+                map_value_coder.encode_to_stream(map_value, tmp_out, True)
+                serialized_data = tmp_out.get()
+                output_stream.write_bigendian_int32(len(serialized_data))
+                output_stream.write(serialized_data)
+        return self._underlying.append_raw(state_key, output_stream.get())
+
+    @staticmethod
+    def _convert_to_cache_key(state_key):
+        return state_key.SerializeToString()
+
+
+class InternalSynchronousMapRuntimeState(object):
+
+    def __init__(self,
+                 map_state_handler: CachingMapStateHandler,
+                 state_key,
+                 map_key_coder,
+                 map_value_coder,
+                 max_write_cache_entries):
+        self._map_state_handler = map_state_handler
+        self._state_key = state_key
+        self._map_key_coder = map_key_coder
+        self._map_key_coder_impl = map_key_coder._create_impl()
+        self._map_value_coder = map_value_coder
+        self._map_value_coder_impl = map_value_coder._create_impl()
+        self._write_cache = dict()
+        self._max_write_cache_entries = max_write_cache_entries
+        self._is_empty = None
+        self._cleared = False
+
+    def get(self, map_key):
+        if map_key in self._write_cache:
+            existed_and_value = self._write_cache[map_key]
+            if existed_and_value[0]:
+                return existed_and_value[1]
+            else:
+                raise KeyError("Mapping key %s not found!" % map_key)

Review comment:
       Mapping key -> Map key? We also need to check the other places.
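
       For context, `CachedMapState` builds on an `LRUCache` base class with an eviction hook that is not shown in this hunk. A minimal sketch of the pattern it implies, assuming an OrderedDict-backed store (the real LRUCache in state_impl.py may differ):

       ```python
       from collections import OrderedDict


       class SimpleLRUCache(object):
           """Bounded key/value store with an eviction hook.

           A hypothetical stand-in for the LRUCache base class that
           CachedMapState extends above; illustration only.
           """

           def __init__(self, max_entries, on_evict=None):
               self._max_entries = max_entries
               self._on_evict = on_evict
               self._data = OrderedDict()

           def set_on_evict(self, callback):
               self._on_evict = callback

           def get(self, key):
               if key not in self._data:
                   return None
               self._data.move_to_end(key)  # mark as most recently used
               return self._data[key]

           def put(self, key, value):
               if key in self._data:
                   self._data.move_to_end(key)
               self._data[key] = value
               while len(self._data) > self._max_entries:
                   # Evict the least recently used entry and notify the hook,
                   # mirroring how CachedMapState invalidates _existed_keys.
                   evicted_key, evicted_value = self._data.popitem(last=False)
                   if self._on_evict is not None:
                       self._on_evict(evicted_key, evicted_value)
       ```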

##########
File path: flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
##########
@@ -136,7 +136,14 @@
 	/**
 	 * The maximum NUMBER of the states cached in Python side.
 	 */
-	private int stateCacheSize;
+	private final int stateCacheSize;
+
+	/**
+	 * The maximum number of cached items which read from Java side in a Python MapState.

Review comment:
       ```suggestion
   	 * The maximum number of cached entries in a single Python MapState.
   ```

##########
File path: flink-python/src/main/java/org/apache/flink/streaming/api/runners/python/beam/BeamPythonFunctionRunner.java
##########
@@ -636,5 +670,173 @@ private static StateRequestHandler getStateRequestHandler(
 				VoidNamespaceSerializer.INSTANCE,
 				listStateDescriptor);
 		}
+
+		private CompletionStage<BeamFnApi.StateResponse.Builder> handleMapState(
+			BeamFnApi.StateRequest request) throws Exception {
+			// Currently the `beam_fn_api.proto` does not support MapState, so we use
+			// the `MultimapSideInput` message to mark the state as a MapState for now.
+			if (request.getStateKey().hasMultimapSideInput()) {
+				BeamFnApi.StateKey.MultimapSideInput mapUserState = request.getStateKey().getMultimapSideInput();
+				// get key
+				byte[] keyBytes = mapUserState.getKey().toByteArray();
+				bais.setBuffer(keyBytes, 0, keyBytes.length);
+				Object key = keySerializer.deserialize(baisWrapper);
+				keyedStateBackend.setCurrentKey(
+					((RowDataSerializer) keyedStateBackend.getKeySerializer()).toBinaryRow((RowData) key));
+			} else {
+				throw new RuntimeException("Unsupported bag state request: " + request);
+			}
+
+			switch (request.getRequestCase()) {
+				case GET:
+					return handleMapGetRequest(request);
+				case APPEND:
+					return handleMapAppendRequest(request);
+				case CLEAR:
+					return handleMapClearRequest(request);
+				default:
+					throw new RuntimeException(
+						String.format(
+							"Unsupported request type %s for user state.", request.getRequestCase()));
+			}
+		}
+
+		private CompletionStage<BeamFnApi.StateResponse.Builder> handleMapGetRequest(BeamFnApi.StateRequest request)
+				throws Exception {
+			// The structure of get request bytes is:
+			// [flag(1 byte)][serialized map key(if needed)]
+			byte[] getRequest = request.getGet().getContinuationToken().toByteArray();
+			byte getFlag = getRequest[0];
+			if (reuseByteArrayWrapper == null) {
+				reuseByteArrayWrapper = new ByteArrayWrapper(getRequest, 1);

Review comment:
       What about initializing `reuseByteArrayWrapper` in the constructor?
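
       The byte layout described in the code comment above, [flag(1 byte)][serialized map key(if needed)], is easy to illustrate. A hedged Python sketch of the framing, reusing the flag values from CachingMapStateHandler (the helper functions are hypothetical, not part of the PR):

       ```python
       GET_FLAG = 0          # request carries a serialized map key
       CHECK_EMPTY_FLAG = 2  # request carries no map key


       def encode_get_request(flag, serialized_map_key=b""):
           # Layout: [flag(1 byte)][serialized map key(if needed)]
           return bytes([flag]) + serialized_map_key


       def decode_get_request(request):
           flag = request[0]
           map_key_bytes = request[1:]  # empty for CHECK_EMPTY requests
           return flag, map_key_bytes


       # Round trip:
       flag, key = decode_get_request(encode_get_request(GET_FLAG, b"\x01abc"))
       assert flag == GET_FLAG and key == b"\x01abc"
       ```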

##########
File path: flink-python/src/main/java/org/apache/flink/streaming/api/utils/ByteArrayWrapper.java
##########
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.typeutils;
+
+import java.io.Serializable;
+
+/**
+ * A wrapper of the byte array. This class is used to calculate a deterministic hash code of a byte array.
+ */

Review comment:
       @Internal

##########
File path: flink-python/src/main/java/org/apache/flink/streaming/api/utils/ByteArrayWrapperSerializer.java
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.typeutils;

Review comment:
       The package name is incorrect: the file lives under `org.apache.flink.streaming.api.utils` but declares `package org.apache.flink.streaming.api.typeutils`.

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -191,19 +550,57 @@ def _create_state(self, state_spec: userstate.StateSpec) -> userstate.Accumulati
         else:
             raise NotImplementedError(state_spec)
 
+    def _create_internal_map_state(self, name, map_key_coder, map_value_coder):
+        # Currently the `beam_fn_api.proto` does not support MapState, so we use
+        # the `MultimapSideInput` message to mark the state as a MapState for now.
+        state_key = beam_fn_api_pb2.StateKey(
+            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
+                transform_id="",
+                side_input_id=name,
+                key=self._encoded_current_key))
+        return InternalSynchronousMapRuntimeState(
+            self._map_state_handler,
+            state_key,
+            map_key_coder,
+            map_value_coder,
+            self._map_state_write_cache_size)
+
     def set_current_key(self, key):
+        if key == self._current_key:
+            return
+        encoded_old_key = self._encoded_current_key
         self._current_key = key
         self._encoded_current_key = self._key_coder_impl.encode_nested(self._current_key)
         for state_name, state_obj in self._all_states.items():
-            state_obj._internal_state = \
-                self._get_internal_bag_state(state_name, state_obj._internal_state._value_coder)
+            if self._state_cache_size > 0:
+                # cache old internal state
+                self._internal_state_cache.put(
+                    (state_name, encoded_old_key), state_obj._internal_state)
+            if isinstance(state_obj, (SynchronousValueRuntimeState, SynchronousListRuntimeState)):
+                state_obj._internal_state = self._get_internal_bag_state(
+                    state_name, state_obj._internal_state._value_coder)
+            elif isinstance(state_obj, SynchronousMapRuntimeState):
+                state_obj._internal_state = self._get_internal_map_state(
+                    state_name,
+                    state_obj._internal_state._map_key_coder,
+                    state_obj._internal_state._map_value_coder)
+            else:
+                raise Exception("Unknown internal state '%s': %s" % (state_name, state_obj))
 
     def get_current_key(self):
         return self._current_key
 
     def commit(self):
-        for internal_state in self._all_internal_states:
-            internal_state.commit()
-            # reset the status of the internal state to reuse the object cross bundle
+        for internal_state in self._internal_state_cache:
+            self.commit_internal_state(None, internal_state)
+        for name, state in self._all_states.items():
+            if (name, self._encoded_current_key) not in self._internal_state_cache:
+                self.commit_internal_state(None, state._internal_state)
+
+    @staticmethod
+    def commit_internal_state(key, internal_state):

Review comment:
       The `key` parameter is not used; both call sites in commit() pass None.
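
       A sketch of the fix this comment implies (hedged: the method body is truncated in this hunk, so the reset step is paraphrased from the old commit() above):

       ```python
       @staticmethod
       def commit_internal_state(internal_state):
           internal_state.commit()
           # reset the status of the internal state so the object can be
           # reused across bundles (details elided in this hunk)
       ```

       The two call sites in commit() would then drop the leading None argument.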

##########
File path: flink-python/src/main/java/org/apache/flink/streaming/api/utils/ByteArrayWrapperSerializer.java
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.typeutils;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * The serializer of {@link ByteArrayWrapper}.
+ */

Review comment:
       @Internal

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -119,13 +121,331 @@ def clear(self):
         self._internal_state.clear()
 
 
+class CachedMapState(LRUCache):
+
+    def __init__(self, max_entries):
+        super(CachedMapState, self).__init__(max_entries, None)
+        self._all_data_cached = False
+        self._existed_keys = set()
+
+        def on_evict(key, value):
+            self._existed_keys.remove(key)
+            self._all_data_cached = False
+
+        self.set_on_evict(on_evict)
+
+    def set_all_data_cached(self):
+        self._all_data_cached = True
+
+    def is_all_data_cached(self):
+        return self._all_data_cached
+
+    def put(self, key, existed_and_value):
+        if existed_and_value[0]:
+            self._existed_keys.add(key)
+        super(CachedMapState, self).put(key, existed_and_value)
+
+    def get_cached_existed_keys(self):
+        return self._existed_keys
+
+
+class CachingMapStateHandler(object):
+    # GET request flags
+    GET_FLAG = 0
+    ITERATE_FLAG = 1
+    CHECK_EMPTY_FLAG = 2
+    # GET response flags
+    EXIST_FLAG = 0
+    IS_NONE_FLAG = 1
+    NOT_EXIST_FLAG = 2
+    IS_EMPTY_FLAG = 3
+    NOT_EMPTY_FLAG = 4
+    # APPEND request flags
+    DELETE = 0
+    SET_NONE = 1
+    SET_VALUE = 2
+
+    def __init__(self, caching_state_handler, max_cached_map_key_entries):
+        self._state_cache = caching_state_handler._state_cache
+        self._underlying = caching_state_handler._underlying
+        self._context = caching_state_handler._context
+        self._max_cached_map_key_entries = max_cached_map_key_entries
+
+    def _get_cache_token(self):
+        if not self._state_cache.is_cache_enabled():
+            return None
+        if self._context.user_state_cache_token:
+            return self._context.user_state_cache_token
+        else:
+            return self._context.bundle_cache_token
+
+    def blocking_get(self, state_key, map_key, map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if not cache_token:
+            # Cache disabled / no cache token. Can't do a lookup/store in the cache.
+            return self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+        # Cache lookup
+        cache_state_key = self._convert_to_cache_key(state_key)
+        cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+        if cached_map_state is None:
+            existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+            cached_map_state.put(map_key, (existed, value))
+            self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            return existed, value
+        else:
+            cached_value = cached_map_state.get(map_key)
+            if cached_value is None:
+                existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+                cached_map_state.put(map_key, (existed, value))
+                return existed, value
+            else:
+                return cached_value
+
+    def extend(self, state_key, items: List[Tuple[int, Any, Any]], map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is None:
+                cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+                self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            for request_flag, map_key, map_value in items:
+                if request_flag == self.DELETE:
+                    cached_map_state.put(map_key, (False, None))
+                elif request_flag == self.SET_NONE:
+                    cached_map_state.put(map_key, (True, None))
+                elif request_flag == self.SET_VALUE:
+                    cached_map_state.put(map_key, (True, map_value))
+                else:
+                    raise Exception("Unknown flag: " + str(request_flag))
+        self._append_raw(
+            state_key,
+            items,
+            map_key_coder,
+            map_value_coder)
+
+    def check_empty(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is not None:
+                if cached_map_state.is_all_data_cached() and \
+                        len(cached_map_state.get_cached_existed_keys()) == 0:
+                    return True
+                elif len(cached_map_state.get_cached_existed_keys()) > 0:
+                    return False
+        return self._check_empty_raw(state_key)
+
+    def clear(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            cache_key = self._convert_to_cache_key(state_key)
+            self._state_cache.evict(cache_key, cache_token)
+        return self._underlying.clear(state_key)
+
+    def _check_empty_raw(self, state_key):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.CHECK_EMPTY_FLAG)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        if data[0] == self.IS_EMPTY_FLAG:
+            return True
+        elif data[0] == self.NOT_EMPTY_FLAG:
+            return False
+        else:
+            raise Exception("Unknown response flag: " + str(data[0]))
+
+    def _get_raw(self, state_key, map_key, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.GET_FLAG)
+        map_key_coder.encode_to_stream(map_key, output_stream, True)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        input_stream = coder_impl.create_InputStream(data)
+        result_flag = input_stream.read_byte()
+        if result_flag == self.EXIST_FLAG:
+            return True, map_value_coder.decode_from_stream(input_stream, True)
+        elif result_flag == self.IS_NONE_FLAG:
+            return True, None
+        elif result_flag == self.NOT_EXIST_FLAG:
+            return False, None
+        else:
+            raise Exception("Unknown response flag: " + str(result_flag))
+
+    def _append_raw(self, state_key, items, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_bigendian_int32(len(items))
+        for request_flag, map_key, map_value in items:
+            output_stream.write_byte(request_flag)
+            # Not all the coder impls will serialize the length of bytes when we set the "nested"
+            # param to "True", so we need to encode the length of bytes manually.
+            tmp_out = coder_impl.create_OutputStream()
+            map_key_coder.encode_to_stream(map_key, tmp_out, True)
+            serialized_data = tmp_out.get()
+            output_stream.write_bigendian_int32(len(serialized_data))
+            output_stream.write(serialized_data)
+            if request_flag == self.SET_VALUE:
+                tmp_out = coder_impl.create_OutputStream()
+                map_value_coder.encode_to_stream(map_value, tmp_out, True)
+                serialized_data = tmp_out.get()
+                output_stream.write_bigendian_int32(len(serialized_data))
+                output_stream.write(serialized_data)
+        return self._underlying.append_raw(state_key, output_stream.get())
+
+    @staticmethod
+    def _convert_to_cache_key(state_key):
+        return state_key.SerializeToString()
+
+
+class InternalSynchronousMapRuntimeState(object):
+
+    def __init__(self,
+                 map_state_handler: CachingMapStateHandler,
+                 state_key,
+                 map_key_coder,
+                 map_value_coder,
+                 max_write_cache_entries):
+        self._map_state_handler = map_state_handler
+        self._state_key = state_key
+        self._map_key_coder = map_key_coder
+        self._map_key_coder_impl = map_key_coder._create_impl()
+        self._map_value_coder = map_value_coder
+        self._map_value_coder_impl = map_value_coder._create_impl()
+        self._write_cache = dict()
+        self._max_write_cache_entries = max_write_cache_entries
+        self._is_empty = None
+        self._cleared = False
+
+    def get(self, map_key):
+        if map_key in self._write_cache:
+            existed_and_value = self._write_cache[map_key]
+            if existed_and_value[0]:
+                return existed_and_value[1]
+            else:
+                raise KeyError("Mapping key %s not found!" % map_key)
+        if self._cleared:
+            raise KeyError("Mapping key %s not found!" % map_key)
+        existed, value = self._map_state_handler.blocking_get(
+            self._state_key, map_key, self._map_key_coder_impl, self._map_value_coder_impl)
+        if existed:
+            return value
+        else:
+            raise KeyError("Mapping key %s not found!" % map_key)
+
+    def put(self, map_key, map_value):
+        self._write_cache[map_key] = (True, map_value)
+        self._is_empty = False
+        if len(self._write_cache) >= self._max_write_cache_entries:
+            self.commit()
+
+    def put_all(self, dict_value):
+        for map_key, map_value in dict_value:
+            self._write_cache[map_key] = (True, map_value)
+        self._is_empty = False
+        if len(self._write_cache) >= self._max_write_cache_entries:
+            self.commit()
+
+    def remove(self, map_key):
+        self._write_cache[map_key] = (False, None)
+        self._is_empty = None
+        if len(self._write_cache) >= self._max_write_cache_entries:
+            self.commit()
+
+    def contains(self, map_key):
+        try:
+            self.get(map_key)
+            return True
+        except KeyError:
+            return False
+
+    def is_empty(self):
+        if self._is_empty is None:
+            self._is_empty = self._map_state_handler.check_empty(self._state_key)
+        return self._is_empty
+
+    def clear(self):
+        self._cleared = True
+        self._is_empty = True
+        self._write_cache.clear()
+
+    def commit(self):
+        to_await = None
+        if self._cleared:
+            to_await = self._map_state_handler.clear(self._state_key)
+        if self._write_cache:
+            append_items = []
+            for map_key, existed_and_value in self._write_cache.items():

Review comment:
       ```suggestion
               for map_key, (exists, value) in self._write_cache.items():
   ```
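
       The hunk is truncated inside this loop, but the APPEND flags defined on CachingMapStateHandler suggest how each cached entry maps to an append request. A hedged sketch of that translation, with the suggested tuple unpacking applied (the actual loop body may differ):

       ```python
       for map_key, (exists, value) in self._write_cache.items():
           if not exists:
               append_items.append((CachingMapStateHandler.DELETE, map_key, None))
           elif value is None:
               append_items.append((CachingMapStateHandler.SET_NONE, map_key, None))
           else:
               append_items.append((CachingMapStateHandler.SET_VALUE, map_key, value))
       ```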

##########
File path: flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
##########
@@ -173,4 +173,24 @@
 		.defaultValue(1000)
 		.withDescription("The maximum number of states cached in a Python UDF worker. Note that this " +
 			"is an experimental flag and might not be available in future releases.");
+
+	/**
+	 * The maximum number of cached items which read from Java side in a Python MapState.
+	 */
+	@Experimental
+	public static final ConfigOption<Integer> MAP_STATE_READ_CACHE_SIZE = ConfigOptions
+		.key("python.map-state.read.cache.size")
+		.defaultValue(1000)
+		.withDescription("The maximum number of cached items which read from Java side in a Python MapState. "
+			+ "Note that this is an experimental flag and might not be available in future releases.");
+
+	/**
+	 * The maximum number of write requests cached in a Python MapState.
+	 */
+	@Experimental
+	public static final ConfigOption<Integer> MAP_STATE_WRITE_CACHE_SIZE = ConfigOptions
+		.key("python.map-state.write.cache.size")
+		.defaultValue(1000)
+		.withDescription("The maximum number of write requests cached in a Python MapState. Note that this " +

Review comment:
       ```suggestion
   		.withDescription("The maximum number of cached write requests for a single Python MapState. The write requests will be flushed to the state backend (managed in the Java operator) when the number of cached write requests exceed this limit. Note that this " +
   ```

##########
File path: flink-python/src/main/java/org/apache/flink/python/PythonOptions.java
##########
@@ -173,4 +173,24 @@
 		.defaultValue(1000)
 		.withDescription("The maximum number of states cached in a Python UDF worker. Note that this " +
 			"is an experimental flag and might not be available in future releases.");
+
+	/**
+	 * The maximum number of cached items which read from Java side in a Python MapState.
+	 */
+	@Experimental
+	public static final ConfigOption<Integer> MAP_STATE_READ_CACHE_SIZE = ConfigOptions
+		.key("python.map-state.read.cache.size")
+		.defaultValue(1000)
+		.withDescription("The maximum number of cached items which read from Java side in a Python MapState. "

Review comment:
       ```suggestion
   		.withDescription("The maximum number of cached entries for a single Python MapState. "
   ```
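
       For reference, a hedged usage sketch of the two new options from PyFlink, assuming a StreamTableEnvironment named t_env and the usual Configuration.set_string idiom:

       ```python
       config = t_env.get_config().get_configuration()
       # Cap the per-MapState read cache (entries read from the Java side).
       config.set_string("python.map-state.read.cache.size", "2000")
       # Flush cached write requests to the Java state backend once 500 are pending.
       config.set_string("python.map-state.write.cache.size", "500")
       ```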

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -119,13 +121,331 @@ def clear(self):
         self._internal_state.clear()
 
 
+class CachedMapState(LRUCache):
+
+    def __init__(self, max_entries):
+        super(CachedMapState, self).__init__(max_entries, None)
+        self._all_data_cached = False
+        self._existed_keys = set()
+
+        def on_evict(key, value):
+            self._existed_keys.remove(key)
+            self._all_data_cached = False
+
+        self.set_on_evict(on_evict)
+
+    def set_all_data_cached(self):
+        self._all_data_cached = True
+
+    def is_all_data_cached(self):
+        return self._all_data_cached
+
+    def put(self, key, existed_and_value):
+        if existed_and_value[0]:
+            self._existed_keys.add(key)
+        super(CachedMapState, self).put(key, existed_and_value)
+
+    def get_cached_existed_keys(self):
+        return self._existed_keys
+
+
+class CachingMapStateHandler(object):
+    # GET request flags
+    GET_FLAG = 0
+    ITERATE_FLAG = 1
+    CHECK_EMPTY_FLAG = 2
+    # GET response flags
+    EXIST_FLAG = 0
+    IS_NONE_FLAG = 1
+    NOT_EXIST_FLAG = 2
+    IS_EMPTY_FLAG = 3
+    NOT_EMPTY_FLAG = 4
+    # APPEND request flags
+    DELETE = 0
+    SET_NONE = 1
+    SET_VALUE = 2
+
+    def __init__(self, caching_state_handler, max_cached_map_key_entries):
+        self._state_cache = caching_state_handler._state_cache
+        self._underlying = caching_state_handler._underlying
+        self._context = caching_state_handler._context
+        self._max_cached_map_key_entries = max_cached_map_key_entries
+
+    def _get_cache_token(self):
+        if not self._state_cache.is_cache_enabled():
+            return None
+        if self._context.user_state_cache_token:
+            return self._context.user_state_cache_token
+        else:
+            return self._context.bundle_cache_token
+
+    def blocking_get(self, state_key, map_key, map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if not cache_token:
+            # Cache disabled / no cache token. Can't do a lookup/store in the cache.
+            return self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+        # Cache lookup
+        cache_state_key = self._convert_to_cache_key(state_key)
+        cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+        if cached_map_state is None:
+            existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+            cached_map_state.put(map_key, (existed, value))
+            self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            return existed, value
+        else:
+            cached_value = cached_map_state.get(map_key)
+            if cached_value is None:
+                existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+                cached_map_state.put(map_key, (existed, value))
+                return existed, value
+            else:
+                return cached_value
+
+    def extend(self, state_key, items: List[Tuple[int, Any, Any]], map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is None:
+                cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+                self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            for request_flag, map_key, map_value in items:
+                if request_flag == self.DELETE:
+                    cached_map_state.put(map_key, (False, None))
+                elif request_flag == self.SET_NONE:
+                    cached_map_state.put(map_key, (True, None))
+                elif request_flag == self.SET_VALUE:
+                    cached_map_state.put(map_key, (True, map_value))
+                else:
+                    raise Exception("Unknown flag: " + str(request_flag))
+        self._append_raw(
+            state_key,
+            items,
+            map_key_coder,
+            map_value_coder)
+
+    def check_empty(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is not None:
+                if cached_map_state.is_all_data_cached() and \
+                        len(cached_map_state.get_cached_existed_keys()) == 0:
+                    return True
+                elif len(cached_map_state.get_cached_existed_keys()) > 0:
+                    return False
+        return self._check_empty_raw(state_key)
+
+    def clear(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            cache_key = self._convert_to_cache_key(state_key)
+            self._state_cache.evict(cache_key, cache_token)
+        return self._underlying.clear(state_key)
+
+    def _check_empty_raw(self, state_key):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.CHECK_EMPTY_FLAG)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        if data[0] == self.IS_EMPTY_FLAG:
+            return True
+        elif data[0] == self.NOT_EMPTY_FLAG:
+            return False
+        else:
+            raise Exception("Unknown response flag: " + str(data[0]))
+
+    def _get_raw(self, state_key, map_key, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.GET_FLAG)
+        map_key_coder.encode_to_stream(map_key, output_stream, True)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        input_stream = coder_impl.create_InputStream(data)
+        result_flag = input_stream.read_byte()
+        if result_flag == self.EXIST_FLAG:
+            return True, map_value_coder.decode_from_stream(input_stream, True)
+        elif result_flag == self.IS_NONE_FLAG:
+            return True, None
+        elif result_flag == self.NOT_EXIST_FLAG:
+            return False, None
+        else:
+            raise Exception("Unknown response flag: " + str(result_flag))
+
+    def _append_raw(self, state_key, items, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_bigendian_int32(len(items))
+        for request_flag, map_key, map_value in items:
+            output_stream.write_byte(request_flag)
+            # Not all the coder impls will serialize the length of bytes when we set the "nested"
+            # param to "True", so we need to encode the length of bytes manually.
+            tmp_out = coder_impl.create_OutputStream()
+            map_key_coder.encode_to_stream(map_key, tmp_out, True)
+            serialized_data = tmp_out.get()
+            output_stream.write_bigendian_int32(len(serialized_data))
+            output_stream.write(serialized_data)
+            if request_flag == self.SET_VALUE:
+                tmp_out = coder_impl.create_OutputStream()
+                map_value_coder.encode_to_stream(map_value, tmp_out, True)
+                serialized_data = tmp_out.get()
+                output_stream.write_bigendian_int32(len(serialized_data))
+                output_stream.write(serialized_data)
+        return self._underlying.append_raw(state_key, output_stream.get())
+
+    @staticmethod
+    def _convert_to_cache_key(state_key):
+        return state_key.SerializeToString()
+
+
+class InternalSynchronousMapRuntimeState(object):
+
+    def __init__(self,
+                 map_state_handler: CachingMapStateHandler,
+                 state_key,
+                 map_key_coder,
+                 map_value_coder,
+                 max_write_cache_entries):
+        self._map_state_handler = map_state_handler
+        self._state_key = state_key
+        self._map_key_coder = map_key_coder
+        self._map_key_coder_impl = map_key_coder._create_impl()
+        self._map_value_coder = map_value_coder
+        self._map_value_coder_impl = map_value_coder._create_impl()
+        self._write_cache = dict()
+        self._max_write_cache_entries = max_write_cache_entries
+        self._is_empty = None
+        self._cleared = False
+
+    def get(self, map_key):
+        if map_key in self._write_cache:
+            existed_and_value = self._write_cache[map_key]
+            if existed_and_value[0]:
+                return existed_and_value[1]
+            else:
+                raise KeyError("Mapping key %s not found!" % map_key)
+        if self._cleared:
+            raise KeyError("Mapping key %s not found!" % map_key)
+        existed, value = self._map_state_handler.blocking_get(

Review comment:
       ```suggestion
           exists, value = self._map_state_handler.blocking_get(
   ```

##########
File path: flink-python/pyflink/fn_execution/state_impl.py
##########
@@ -119,13 +121,331 @@ def clear(self):
         self._internal_state.clear()
 
 
+class CachedMapState(LRUCache):
+
+    def __init__(self, max_entries):
+        super(CachedMapState, self).__init__(max_entries, None)
+        self._all_data_cached = False
+        self._existed_keys = set()
+
+        def on_evict(key, value):
+            self._existed_keys.remove(key)
+            self._all_data_cached = False
+
+        self.set_on_evict(on_evict)
+
+    def set_all_data_cached(self):
+        self._all_data_cached = True
+
+    def is_all_data_cached(self):
+        return self._all_data_cached
+
+    def put(self, key, existed_and_value):
+        if existed_and_value[0]:
+            self._existed_keys.add(key)
+        super(CachedMapState, self).put(key, existed_and_value)
+
+    def get_cached_existed_keys(self):
+        return self._existed_keys
+
+
+class CachingMapStateHandler(object):
+    # GET request flags
+    GET_FLAG = 0
+    ITERATE_FLAG = 1
+    CHECK_EMPTY_FLAG = 2
+    # GET response flags
+    EXIST_FLAG = 0
+    IS_NONE_FLAG = 1
+    NOT_EXIST_FLAG = 2
+    IS_EMPTY_FLAG = 3
+    NOT_EMPTY_FLAG = 4
+    # APPEND request flags
+    DELETE = 0
+    SET_NONE = 1
+    SET_VALUE = 2
+
+    def __init__(self, caching_state_handler, max_cached_map_key_entries):
+        self._state_cache = caching_state_handler._state_cache
+        self._underlying = caching_state_handler._underlying
+        self._context = caching_state_handler._context
+        self._max_cached_map_key_entries = max_cached_map_key_entries
+
+    def _get_cache_token(self):
+        if not self._state_cache.is_cache_enabled():
+            return None
+        if self._context.user_state_cache_token:
+            return self._context.user_state_cache_token
+        else:
+            return self._context.bundle_cache_token
+
+    def blocking_get(self, state_key, map_key, map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if not cache_token:
+            # Cache disabled / no cache token. Can't do a lookup/store in the cache.
+            return self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+        # Cache lookup
+        cache_state_key = self._convert_to_cache_key(state_key)
+        cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+        if cached_map_state is None:
+            existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+            cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+            cached_map_state.put(map_key, (existed, value))
+            self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            return existed, value
+        else:
+            cached_value = cached_map_state.get(map_key)
+            if cached_value is None:
+                existed, value = self._get_raw(state_key, map_key, map_key_coder, map_value_coder)
+                cached_map_state.put(map_key, (existed, value))
+                return existed, value
+            else:
+                return cached_value
+
+    def extend(self, state_key, items: List[Tuple[int, Any, Any]], map_key_coder, map_value_coder):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is None:
+                cached_map_state = CachedMapState(self._max_cached_map_key_entries)
+                self._state_cache.put(cache_state_key, cache_token, cached_map_state)
+            for request_flag, map_key, map_value in items:
+                if request_flag == self.DELETE:
+                    cached_map_state.put(map_key, (False, None))
+                elif request_flag == self.SET_NONE:
+                    cached_map_state.put(map_key, (True, None))
+                elif request_flag == self.SET_VALUE:
+                    cached_map_state.put(map_key, (True, map_value))
+                else:
+                    raise Exception("Unknown flag: " + str(request_flag))
+        self._append_raw(
+            state_key,
+            items,
+            map_key_coder,
+            map_value_coder)
+
+    def check_empty(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            # Cache lookup
+            cache_state_key = self._convert_to_cache_key(state_key)
+            cached_map_state = self._state_cache.get(cache_state_key, cache_token)
+            if cached_map_state is not None:
+                if cached_map_state.is_all_data_cached() and \
+                        len(cached_map_state.get_cached_existed_keys()) == 0:
+                    return True
+                elif len(cached_map_state.get_cached_existed_keys()) > 0:
+                    return False
+        return self._check_empty_raw(state_key)
+
+    def clear(self, state_key):
+        cache_token = self._get_cache_token()
+        if cache_token:
+            cache_key = self._convert_to_cache_key(state_key)
+            self._state_cache.evict(cache_key, cache_token)
+        return self._underlying.clear(state_key)
+
+    def _check_empty_raw(self, state_key):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.CHECK_EMPTY_FLAG)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        if data[0] == self.IS_EMPTY_FLAG:
+            return True
+        elif data[0] == self.NOT_EMPTY_FLAG:
+            return False
+        else:
+            raise Exception("Unknown response flag: " + str(data[0]))
+
+    def _get_raw(self, state_key, map_key, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_byte(self.GET_FLAG)
+        map_key_coder.encode_to_stream(map_key, output_stream, True)
+        continuation_token = output_stream.get()
+        data, response_token = self._underlying.get_raw(state_key, continuation_token)
+        input_stream = coder_impl.create_InputStream(data)
+        result_flag = input_stream.read_byte()
+        if result_flag == self.EXIST_FLAG:
+            return True, map_value_coder.decode_from_stream(input_stream, True)
+        elif result_flag == self.IS_NONE_FLAG:
+            return True, None
+        elif result_flag == self.NOT_EXIST_FLAG:
+            return False, None
+        else:
+            raise Exception("Unknown response flag: " + str(result_flag))
+
+    def _append_raw(self, state_key, items, map_key_coder, map_value_coder):
+        output_stream = coder_impl.create_OutputStream()
+        output_stream.write_bigendian_int32(len(items))
+        for request_flag, map_key, map_value in items:
+            output_stream.write_byte(request_flag)
+            # Not all the coder impls will serialize the length of bytes when we set the "nested"
+            # param to "True", so we need to encode the length of bytes manually.
+            tmp_out = coder_impl.create_OutputStream()
+            map_key_coder.encode_to_stream(map_key, tmp_out, True)
+            serialized_data = tmp_out.get()
+            output_stream.write_bigendian_int32(len(serialized_data))
+            output_stream.write(serialized_data)
+            if request_flag == self.SET_VALUE:
+                tmp_out = coder_impl.create_OutputStream()
+                map_value_coder.encode_to_stream(map_value, tmp_out, True)
+                serialized_data = tmp_out.get()
+                output_stream.write_bigendian_int32(len(serialized_data))
+                output_stream.write(serialized_data)
+        return self._underlying.append_raw(state_key, output_stream.get())
+
+    @staticmethod
+    def _convert_to_cache_key(state_key):
+        return state_key.SerializeToString()
+
+
+class InternalSynchronousMapRuntimeState(object):
+
+    def __init__(self,
+                 map_state_handler: CachingMapStateHandler,
+                 state_key,
+                 map_key_coder,
+                 map_value_coder,
+                 max_write_cache_entries):
+        self._map_state_handler = map_state_handler
+        self._state_key = state_key
+        self._map_key_coder = map_key_coder
+        self._map_key_coder_impl = map_key_coder._create_impl()
+        self._map_value_coder = map_value_coder
+        self._map_value_coder_impl = map_value_coder._create_impl()
+        self._write_cache = dict()
+        self._max_write_cache_entries = max_write_cache_entries
+        self._is_empty = None
+        self._cleared = False
+
+    def get(self, map_key):
+        if map_key in self._write_cache:
+            existed_and_value = self._write_cache[map_key]

Review comment:
       ```suggestion
               (exists, value) = self._write_cache[map_key]
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org