You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by "zhengbuqian (via GitHub)" <gi...@apache.org> on 2023/04/19 23:16:01 UTC

[GitHub] [beam] zhengbuqian commented on a diff in pull request #23492: Add Windmill support for MultimapState

zhengbuqian commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1151347185


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.

Review Comment:
   Done swapping the two cases.
   
   Note that such a key could be KNOWN_EXIST if it has local additions.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1610,10 +1610,21 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
     private final Coder<K> keyCoder;
     private final Coder<V> valueCoder;
 
+    private enum KeyExistence {
+      // this key is known to exist, it has at least 1 value in either localAdditions or windmill
+      KNOWN_EXIST,
+      // this key is known to be nonexistent, it has 0 value in both localAdditions and windmill

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java:
##########
@@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException {
     return output.toByteString();
   }
 
+  @Test
+  public void testReadMultimapSingleEntry() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.containsInAnyOrder(5, 6));
+    assertNoReader(future);
+  }
+
+  @Test
+  public void testReadMultimapSingleEntryPaginated() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(500)));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(500)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setContinuationPosition(800)
+                            .setRequestPosition(500)));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(9), intData(10)))
+                            .setRequestPosition(800)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());

Review Comment:
   Done. I think using `when` to set up the mock's responses and then calling `verify` afterwards to check the expected requests is a pretty standard pattern, consistent with the other test cases.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java:
##########
@@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException {
     return output.toByteString();
   }
 
+  @Test
+  public void testReadMultimapSingleEntry() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.containsInAnyOrder(5, 6));
+    assertNoReader(future);
+  }
+
+  @Test
+  public void testReadMultimapSingleEntryPaginated() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(500)));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(500)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setContinuationPosition(800)
+                            .setRequestPosition(500)));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(9), intData(10)))
+                            .setRequestPosition(800)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.contains(5, 6, 7, 8, 9, 10));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // check whether the two TagMultimapFetchRequests equal to each other, ignoring the order of
+  // entries and the order of values in each entry.
+  private static void assertMultimapFetchRequestEqual(
+      Windmill.TagMultimapFetchRequest req1, Windmill.TagMultimapFetchRequest req2) {
+    assertMultimapEntriesEqual(req1.getEntriesToFetchList(), req2.getEntriesToFetchList());
+    assertEquals(
+        req1.toBuilder().clearEntriesToFetch().build(),
+        req2.toBuilder().clearEntriesToFetch().build());
+  }
+
+  private static void assertMultimapEntriesEqual(
+      List<Windmill.TagMultimapEntry> left, List<Windmill.TagMultimapEntry> right) {
+    Map<ByteString, Windmill.TagMultimapEntry> map = Maps.newHashMap();
+    for (Windmill.TagMultimapEntry entry : left) {
+      map.put(entry.getEntryName(), entry);
+    }
+    for (Windmill.TagMultimapEntry entry : right) {
+      assertTrue(map.containsKey(entry.getEntryName()));
+      Windmill.TagMultimapEntry that = map.remove(entry.getEntryName());
+      if (entry.getValuesCount() == 0) {
+        assertEquals(0, that.getValuesCount());
+      } else {
+        assertThat(entry.getValuesList(), Matchers.containsInAnyOrder(that.getValuesList()));
+      }
+      assertEquals(entry.toBuilder().clearValues().build(), that.toBuilder().clearValues().build());
+    }
+    assertTrue(map.isEmpty());
+  }
+
+  @Test
+  public void testReadMultimapMultipleEntries() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6))))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill)
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest.build().getMultimapsToFetch(0),
+        requestCaptor.getValue().getMultimapsToFetch(0));
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    assertNoReader(future1);
+    assertNoReader(future2);
+  }
+
+  @Test
+  public void testReadMultimapMultipleEntriesWithPagination() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(800))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setRequestPosition(800)));
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response1.build())
+        .thenReturn(response2.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill, times(2))
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest1.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(0).getMultimapsToFetch(0));
+    assertMultimapFetchRequestEqual(
+        expectedRequest2.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(1).getMultimapsToFetch(0));
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6, 7, 8));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  @Test
+  public void testReadMultimapKeys() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    List<ByteString> keys = Lists.newArrayList();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      keys.add(entry.getKey());
+      assertEquals(0, Iterables.size(entry.getValue()));
+    }
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(keys, Matchers.containsInAnyOrder(STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2));
+    assertNoReader(future);
+  }
+
+  @Test
+  public void testReadMultimapKeysPaginated() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1))
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1)
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_2));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_3))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    List<ByteString> keys = Lists.newArrayList();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      keys.add(entry.getKey());
+      assertEquals(0, Iterables.size(entry.getValue()));
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(
+        keys,
+        Matchers.containsInAnyOrder(
+            STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2, STATE_MULTIMAP_KEY_3));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  @Test
+  public void testReadMultimapAllEntries() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(false, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addValues(intData(1))
+                            .addValues(intData(2)))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addValues(intData(10))
+                            .addValues(intData(20))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    int foundEntries = 0;
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      if (entry.getKey().equals(STATE_MULTIMAP_KEY_1)) {
+        foundEntries++;
+        assertThat(entry.getValue(), Matchers.containsInAnyOrder(1, 2));
+      } else {
+        foundEntries++;
+        assertEquals(STATE_MULTIMAP_KEY_2, entry.getKey());
+        assertThat(entry.getValue(), Matchers.containsInAnyOrder(10, 20));
+      }
+    }
+    assertEquals(2, foundEntries);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+    assertNoReader(future);
+  }
+
+  private static void assertMultimapEntries(
+      Iterable<Map.Entry<ByteString, Iterable<Integer>>> expected,
+      List<Map.Entry<ByteString, List<Integer>>> actual) {
+    Map<ByteString, List<Integer>> expectedMap = Maps.newHashMap();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : expected) {
+      ByteString key = entry.getKey();
+      if (!expectedMap.containsKey(key)) expectedMap.put(key, new ArrayList<>());
+      entry.getValue().forEach(expectedMap.get(key)::add);
+    }
+    for (Map.Entry<ByteString, List<Integer>> entry : actual) {
+      assertThat(

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java:
##########
@@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException {
     return output.toByteString();
   }
 
+  // Fetching a single multimap entry should issue exactly one KeyedGetDataRequest naming that
+  // entry (fetchEntryNamesOnly=false, INITIAL_MAX_MULTIMAP_BYTES). The response carries no
+  // continuation position, so iterating the values must not trigger any further requests.
+  @Test
+  public void testReadMultimapSingleEntry() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the future must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results = future.get();
+    // By the time get() returns, exactly the expected request has been sent.
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    // Single page only: iteration must not have issued any additional request.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.containsInAnyOrder(5, 6));
+    // NOTE(review): assertNoReader is defined elsewhere in this class; presumably it checks that
+    // the fully-read future no longer holds a reference to the reader.
+    assertNoReader(future);
+  }
+
+  // A single-entry fetch whose values span three pages. The first page is requested eagerly by
+  // get(); subsequent pages are requested only while iterating, each resuming from the previous
+  // response's continuation position (500, then 800) with the continuation byte limits.
+  @Test
+  public void testReadMultimapSingleEntryPaginated() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the future must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // Page 1: initial request with the initial byte limits.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(500)));
+    // Page 2: continuation request resuming at position 500 with the continuation byte limits.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(500)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setContinuationPosition(800)
+                            .setRequestPosition(500)));
+    // Page 3: continuation request resuming at position 800; its response has no continuation.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(9), intData(10)))
+                            .setRequestPosition(800)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Integer> results = future.get();
+    // Only the first page is fetched by get(); later pages are fetched lazily below.
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // contains() (not containsInAnyOrder) checks that page order is preserved.
+    assertThat(results, Matchers.contains(5, 6, 7, 8, 9, 10));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // Asserts that two TagMultimapFetchRequests are equal, ignoring the order of the entries to
+  // fetch and the order of the values within each entry. All other fields must match exactly.
+  private static void assertMultimapFetchRequestEqual(
+      Windmill.TagMultimapFetchRequest req1, Windmill.TagMultimapFetchRequest req2) {
+    // Order-insensitive comparison of the repeated entries field.
+    List<Windmill.TagMultimapEntry> leftEntries = req1.getEntriesToFetchList();
+    List<Windmill.TagMultimapEntry> rightEntries = req2.getEntriesToFetchList();
+    assertMultimapEntriesEqual(leftEntries, rightEntries);
+
+    // Exact comparison of everything else, with the entries stripped from both sides.
+    Windmill.TagMultimapFetchRequest leftRest = req1.toBuilder().clearEntriesToFetch().build();
+    Windmill.TagMultimapFetchRequest rightRest = req2.toBuilder().clearEntriesToFetch().build();
+    assertEquals(leftRest, rightRest);
+  }
+
+  // Asserts that two lists of TagMultimapEntry describe the same entries. Both lists must name the
+  // same entries; entries with the same name must have the same values (in any order) and agree
+  // on every other field. Assumes entry names are unique within each list (later duplicates in
+  // `left` overwrite earlier ones).
+  private static void assertMultimapEntriesEqual(
+      List<Windmill.TagMultimapEntry> left, List<Windmill.TagMultimapEntry> right) {
+    Map<ByteString, Windmill.TagMultimapEntry> map = Maps.newHashMap();
+    for (Windmill.TagMultimapEntry entry : left) {
+      map.put(entry.getEntryName(), entry);
+    }
+    for (Windmill.TagMultimapEntry entry : right) {
+      assertTrue(map.containsKey(entry.getEntryName()));
+      Windmill.TagMultimapEntry that = map.remove(entry.getEntryName());
+      if (entry.getValuesCount() == 0) {
+        assertEquals(0, that.getValuesCount());
+      } else {
+        // Spread the values as varargs: passing the List itself would make the matcher expect a
+        // single element equal to the whole list.
+        assertThat(
+            entry.getValuesList(), Matchers.containsInAnyOrder(that.getValuesList().toArray()));
+      }
+      assertEquals(entry.toBuilder().clearValues().build(), that.toBuilder().clearValues().build());
+    }
+    assertTrue(map.isEmpty());
+  }
+
+  // Two single-entry fetches on the same tag must be batched into one TagMultimapFetchRequest that
+  // lists both entries. The order of the entries in that request is not pinned, so the request is
+  // captured and compared with assertMultimapFetchRequestEqual (order-insensitive).
+  @Test
+  public void testReadMultimapMultipleEntries() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the futures must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6))))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    // Stub on any request, because the exact entry order in the batched request is not fixed.
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    // Exactly one request was issued for both futures; check its content ignoring entry order.
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill)
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest.build().getMultimapsToFetch(0),
+        requestCaptor.getValue().getMultimapsToFetch(0));
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    assertNoReader(future1);
+    assertNoReader(future2);
+  }
+
+  // Two entries are fetched in one batched request, but only KEY_1's values are paginated
+  // (continuation position 800). The follow-up request must ask only for KEY_1 from position 800;
+  // KEY_2 is complete after the first response.
+  @Test
+  public void testReadMultimapMultipleEntriesWithPagination() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the futures must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(800))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    // Continuation request: only KEY_1, resuming at position 800 with continuation byte limits.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setRequestPosition(800)));
+    // Responses are returned in call order: first the batched page, then the continuation page.
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response1.build())
+        .thenReturn(response2.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+
+    // Exactly two requests in total; compare each one ignoring entry order.
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill, times(2))
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest1.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(0).getMultimapsToFetch(0));
+    assertMultimapFetchRequestEqual(
+        expectedRequest2.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(1).getMultimapsToFetch(0));
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6, 7, 8));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // multimapFetchAllFuture(true, ...) fetches only entry names. The request sets
+  // fetchEntryNamesOnly=true with a tag-level byte limit, and every returned entry has an empty
+  // value iterable.
+  @Test
+  public void testReadMultimapKeys() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the future must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    // Collect the keys; a names-only fetch must yield no values for any key.
+    List<ByteString> keys = Lists.newArrayList();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      keys.add(entry.getKey());
+      assertEquals(0, Iterables.size(entry.getValue()));
+    }
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(keys, Matchers.containsInAnyOrder(STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2));
+    assertNoReader(future);
+  }
+
+  // Names-only fetch spread over three pages. Pagination here is at the tag level: the
+  // TagMultimapFetchResponse carries the continuation position, and each follow-up
+  // TagMultimapFetchRequest sets requestPosition to it together with the continuation byte limits.
+  @Test
+  public void testReadMultimapKeysPaginated() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the future must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1))
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1)
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_2));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(true)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+
+    // Last page: no continuation position, so no further request follows.
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_3))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    // Only the first page is fetched by get(); iterating below pulls in pages 2 and 3.
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    List<ByteString> keys = Lists.newArrayList();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      keys.add(entry.getKey());
+      assertEquals(0, Iterables.size(entry.getValue()));
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(
+        keys,
+        Matchers.containsInAnyOrder(
+            STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2, STATE_MULTIMAP_KEY_3));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // multimapFetchAllFuture(false, ...) fetches all entries with their values in one unpaginated
+  // response; each key must come back with exactly its own values.
+  // NOTE(review): the loop counts entries but does not detect a key returned twice; the
+  // assertMultimapEntries helper in this class would give a stricter check.
+  @Test
+  public void testReadMultimapAllEntries() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(false, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Merely creating the future must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addValues(intData(1))
+                            .addValues(intData(2)))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addValues(intData(10))
+                            .addValues(intData(20))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    // Every entry must be KEY_1 or KEY_2 with the matching values.
+    int foundEntries = 0;
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : results) {
+      if (entry.getKey().equals(STATE_MULTIMAP_KEY_1)) {
+        foundEntries++;
+        assertThat(entry.getValue(), Matchers.containsInAnyOrder(1, 2));
+      } else {
+        foundEntries++;
+        assertEquals(STATE_MULTIMAP_KEY_2, entry.getKey());
+        assertThat(entry.getValue(), Matchers.containsInAnyOrder(10, 20));
+      }
+    }
+    assertEquals(2, foundEntries);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+    assertNoReader(future);
+  }
+
+  // Asserts that the entries produced by a multimap fetch match the expected key -> values
+  // mapping. A key's values may be split across several pages of the result (the same key can
+  // appear in more than one entry), so all values seen for a key are merged before comparing.
+  // Values are compared ignoring order; every expected key must be present and no extra keys may
+  // appear.
+  private static void assertMultimapEntries(
+      Iterable<Map.Entry<ByteString, Iterable<Integer>>> actual,
+      List<Map.Entry<ByteString, List<Integer>>> expected) {
+    Map<ByteString, List<Integer>> actualMap = Maps.newHashMap();
+    for (Map.Entry<ByteString, Iterable<Integer>> entry : actual) {
+      List<Integer> merged = actualMap.computeIfAbsent(entry.getKey(), k -> new ArrayList<>());
+      entry.getValue().forEach(merged::add);
+    }
+    for (Map.Entry<ByteString, List<Integer>> entry : expected) {
+      List<Integer> actualValues = actualMap.remove(entry.getKey());
+      // Fail with a clear message instead of a NullPointerException when a key is missing.
+      assertTrue("Missing entry for key " + entry.getKey(), actualValues != null);
+      assertThat(actualValues, Matchers.containsInAnyOrder(entry.getValue().toArray()));
+    }
+    assertTrue("Unexpected entries for keys " + actualMap.keySet(), actualMap.isEmpty());
+  }
+
+  @Test
+  public void testReadMultimapEntriesPaginated() throws Exception {
+    Future<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> future =
+        underTest.multimapFetchAllFuture(false, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addValues(intData(1))
+                            .addValues(intData(2)))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addValues(intData(3))
+                            .addValues(intData(3)))
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addValues(intData(2)))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_1)
+                    .setContinuationPosition(STATE_MULTIMAP_CONT_2));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addValues(intData(4)))
+                    .setRequestPosition(STATE_MULTIMAP_CONT_2));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    Iterable<Map.Entry<ByteString, Iterable<Integer>>> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    assertMultimapEntries(
+        results,
+        Arrays.asList(
+            new AbstractMap.SimpleEntry<>(STATE_MULTIMAP_KEY_1, Arrays.asList(1, 2)),
+            new AbstractMap.SimpleEntry<>(STATE_MULTIMAP_KEY_2, Arrays.asList(3, 3, 2, 4))));
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+    // NOTE: The future will still contain a reference to the underlying reader.

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java:
##########
@@ -645,6 +655,1163 @@ public void testMapComplexPersist() throws Exception {
     assertEquals(0, commitBuilder.getValueUpdatesCount());
   }
 
+  /**
+   * Serializes {@code key} with {@code coder} using {@code Context.OUTER} and returns the
+   * resulting bytes as a {@link ByteString}.
+   *
+   * @throws IOException if {@code coder} fails to encode the value
+   */
+  private static <T> ByteString encodeWithCoder(T key, Coder<T> coder) throws IOException {
+    ByteStringOutputStream buffer = new ByteStringOutputStream();
+    coder.encode(key, buffer, Context.OUTER);
+    return buffer.toByteString();
+  }
+
+  // We use the structural value of the Multimap keys to differentiate between different keys, so
+  // the tests mix the original key object with an equal-but-distinct copy to make sure lookups are
+  // correct regardless of object identity.
+  /** Returns a new array with the same contents as {@code key} but a different object identity. */
+  private static byte[] dup(byte[] key) {
+    // Array clone() is a shallow element copy; for byte[] that is a full independent copy.
+    return key.clone();
+  }
+
+  // get() on a key with no local modifications should return exactly the values served by the
+  // reader's single-entry fetch. The lookup passes dup(key), an equal-but-distinct array, so this
+  // also checks that keys are matched by content rather than by identity.
+  @Test
+  public void testMultimapGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    // Stub the single-entry fetch for the encoded key; the future is completed by waitAndSet below.
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    // readLater() is requested before the stubbed future has been given its value.
+    ReadableState<Iterable<Integer>> result = multimapState.get(dup(key)).readLater();
+    waitAndSet(future, Arrays.asList(1, 2, 3), 30);
+    assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3));
+  }
+
+  // A value put locally before the backend values arrive is combined with the fetched values.
+  // Duplicates are kept: the local 1 and the fetched 1 both appear in the result.
+  @Test
+  public void testMultimapPutAndGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    // Local put uses the original array; the read uses an equal-but-distinct copy.
+    multimapState.put(key, 1);
+    ReadableState<Iterable<Integer>> result = multimapState.get(dup(key)).readLater();
+    waitAndSet(future, Arrays.asList(1, 2, 3), 30);
+    assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 3));
+  }
+
+  // After remove(), the key no longer exists and get() is empty. result2 was obtained via
+  // readLater() before the remove, but it is read afterwards and must reflect the removal.
+  @Test
+  public void testMultimapRemoveAndGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    // Two pending reads of the same logical key, one via the original array and one via a copy.
+    ReadableState<Iterable<Integer>> result1 = multimapState.get(key).readLater();
+    ReadableState<Iterable<Integer>> result2 = multimapState.get(dup(key)).readLater();
+    waitAndSet(future, Arrays.asList(1, 2, 3), 30);
+
+    assertTrue(multimapState.containsKey(key).read());
+    assertThat(result1.read(), Matchers.containsInAnyOrder(1, 2, 3));
+
+    multimapState.remove(key);
+    assertFalse(multimapState.containsKey(dup(key)).read());
+    assertThat(result2.read(), Matchers.emptyIterable());
+  }
+
+  // Removing a key discards the fetched values; puts made after the remove are the only values
+  // visible through the previously created ReadableState.
+  @Test
+  public void testMultimapRemoveThenPut() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    ReadableState<Iterable<Integer>> result = multimapState.get(key).readLater();
+    waitAndSet(future, Arrays.asList(1, 2, 3), 30);
+    // Mix the original array and copies to check content-based key matching.
+    multimapState.remove(dup(key));
+    multimapState.put(key, 4);
+    multimapState.put(dup(key), 5);
+    assertThat(result.read(), Matchers.containsInAnyOrder(4, 5));
+  }
+
+  // remove() followed by put() and persist() should emit a single multimap update for the key
+  // carrying only the post-remove value. Afterwards, get() combines the persisted value with a new
+  // local put. No reader call is stubbed in this test, so the key's values must come from local
+  // state.
+  @Test
+  public void testMultimapRemovePersistPut() {
+    final String tag = "multimap";
+    StateTag<MultimapState<String, Integer>> addr =
+        StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final String key = "key";
+    multimapState.put(key, 1);
+    multimapState.put(key, 2);
+
+    Windmill.WorkItemCommitRequest.Builder commitBuilder =
+        Windmill.WorkItemCommitRequest.newBuilder();
+
+    // After the key is removed, it is cache complete, so there is no need to read the backend.
+    multimapState.remove(key);
+    multimapState.put(key, 4);
+    // Since the key is cache complete, value 4 in localAdditions should be added to the cached
+    // values instead of being cleared from the cache after it is persisted.
+    underTest.persist(commitBuilder);
+    // NOTE(review): the third MultimapEntryUpdate argument (true) presumably marks that existing
+    // values for the key are deleted before the new ones are written; confirm against the helper.
+    assertTagMultimapUpdates(
+        Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()),
+        new MultimapEntryUpdate(key, Arrays.asList(4), true));
+
+    multimapState.put(key, 5);
+    assertThat(multimapState.get(key).read(), Matchers.containsInAnyOrder(4, 5));
+  }
+
+  // Values put locally after the backend fetch completes are combined with the fetched values in
+  // a ReadableState that was created before the puts.
+  @Test
+  public void testMultimapGetLocalCombineStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    ReadableState<Iterable<Integer>> result = multimapState.get(dup(key)).readLater();
+    waitAndSet(future, Arrays.asList(1, 2), 30);
+    multimapState.put(key, 3);
+    multimapState.put(dup(key), 4);
+    assertFalse(multimapState.isEmpty().read());
+    assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3, 4));
+  }
+
+  // A local remove() hides the fetched backend values. Later puts become visible through the same
+  // ReadableState instance, without the backend values reappearing.
+  @Test
+  public void testMultimapLocalRemoveOverrideStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    ReadableState<Iterable<Integer>> result = multimapState.get(key).readLater();
+    waitAndSet(future, Arrays.asList(1, 2), 30);
+    multimapState.remove(dup(key));
+    assertThat(result.read(), Matchers.emptyIterable());
+    // The same ReadableState is read again after further puts; it reflects the current state.
+    multimapState.put(key, 3);
+    multimapState.put(dup(key), 4);
+    assertFalse(multimapState.isEmpty().read());
+    assertThat(result.read(), Matchers.containsInAnyOrder(3, 4));
+  }
+
+  // clear() hides all backend data: per-key reads, keys(), entries() and isEmpty() all report an
+  // empty multimap, even for reads requested with readLater() before the clear.
+  @Test
+  public void testMultimapLocalClearOverrideStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+    SettableFuture<Iterable<Integer>> future2 = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key2, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future2);
+
+    ReadableState<Iterable<Integer>> result1 = multimapState.get(key1).readLater();
+    ReadableState<Iterable<Integer>> result2 = multimapState.get(dup(key2)).readLater();
+    multimapState.clear();
+    waitAndSet(future, Arrays.asList(1, 2), 30);
+    // NOTE(review): future2 is never completed, so result2.read() must not block on the backend
+    // after clear(); if it did, this test would hang rather than fail.
+    assertThat(result1.read(), Matchers.emptyIterable());
+    assertThat(result2.read(), Matchers.emptyIterable());
+    assertThat(multimapState.keys().read(), Matchers.emptyIterable());
+    assertThat(multimapState.entries().read(), Matchers.emptyIterable());
+    assertTrue(multimapState.isEmpty().read());
+  }
+
+  /**
+   * Builds a fake backend multimap entry: the key encoded with {@link ByteArrayCoder} mapped to a
+   * fixed-size list view of {@code values}.
+   *
+   * @throws IOException if encoding the key fails
+   */
+  private static Map.Entry<ByteString, Iterable<Integer>> multimapEntry(
+      byte[] key, Integer... values) throws IOException {
+    ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    Iterable<Integer> valueView = Arrays.asList(values);
+    return new AbstractMap.SimpleEntry<>(encodedKey, valueView);
+  }
+
+  /**
+   * Wraps {@code entries} in a {@link WindmillStateReader.WeightedList} backed by a fresh
+   * {@link ArrayList}, adding each entry via {@code addWeighted} with weight 1.
+   */
+  @SafeVarargs
+  private static <T> List<T> weightedList(T... entries) {
+    WindmillStateReader.WeightedList<T> weighted =
+        new WindmillStateReader.WeightedList<>(new ArrayList<>());
+    for (int i = 0; i < entries.length; i++) {
+      weighted.addWeighted(entries[i], 1);
+    }
+    return weighted;
+  }
+
+  // With no local modifications, entries() and keys() return exactly what the reader's fetch-all
+  // futures supply. entries() is flattened into one (key, value) pair per value.
+  @Test
+  public void testMultimapBasicEntriesAndKeys() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+
+    // The boolean argument distinguishes the two fetches: false is stubbed for entries() and true
+    // for keys() (presumably "keys only" — confirm against WindmillStateReader).
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(6, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertEquals(2, Iterables.size(keys));
+    assertThat(keys, Matchers.containsInAnyOrder(key1, key2));
+  }
+
+  /**
+   * Matches an object whose {@code key} bean property equals {@code key} and whose {@code value}
+   * bean property equals {@code value}. Hamcrest's {@code equalTo} compares arrays element by
+   * element, so byte[] keys match by content.
+   */
+  private static CombinableMatcher<Object> multimapEntryMatcher(byte[] key, Integer value) {
+    CombinableMatcher.CombinableBothMatcher<Object> keyMatches =
+        Matchers.both(Matchers.hasProperty("key", Matchers.equalTo(key)));
+    return keyMatches.and(Matchers.hasProperty("value", Matchers.equalTo(value)));
+  }
+
+  // Local puts made after the fetch-all futures complete are merged into entries() and keys().
+  // This covers both existing keys (key1, key2) and a key that exists only locally (key3).
+  @Test
+  public void testMultimapEntriesAndKeysMergeLocalAdd() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    // These puts happen after readLater(), but before read(), so they must be reflected.
+    multimapState.put(key1, 7);
+    multimapState.put(dup(key2), 8);
+    multimapState.put(dup(key3), 8);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(9, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key1, 7),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertEquals(3, Iterables.size(keys));
+    assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
+  // A local remove() of a backend key (key1) drops all of its values from entries() and drops it
+  // from keys(). Local puts to key2 and key3 are still merged in.
+  @Test
+  public void testMultimapEntriesAndKeysMergeLocalRemove() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    multimapState.remove(dup(key1));
+    multimapState.put(key2, 8);
+    multimapState.put(dup(key3), 8);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(5, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+  }
+
+  // The backend may return the same key in several consecutive entries, as happens with
+  // pagination. All of their values must be concatenated under that key and merged with local
+  // changes.
+  @Test
+  public void testMultimapEntriesPaginated() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        weightedList(
+            multimapEntry(key1, 1, 2, 3),
+            // entry key2 is returned in 2 separate responses due to pagination.
+            multimapEntry(key2, 2, 3, 4),
+            multimapEntry(key2, 4, 5)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    multimapState.remove(dup(key1));
+    multimapState.put(key2, 8);
+    multimapState.put(dup(key3), 8);
+
+    // key2 keeps both pages (including the duplicate value 4) plus the local 8.
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(7, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 5),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+  }
+
+  // After one full entries() read delivered as a WeightedList, the multimap is expected to be
+  // cache complete. Later entries(), keys() and get() calls must be served entirely from the cache.
+  // This is enforced by re-stubbing every reader fetch to throw.
+  @Test
+  public void testMultimapCacheComplete() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+
+    // to set up the multimap as cache complete
+    waitAndSet(entriesFuture, weightedList(multimapEntry(key, 1, 2, 3)), 30);
+    multimapState.entries().read();
+
+    multimapState.put(key, 2);
+
+    // From here on, any reader fetch for this state throws, so the assertions below can only pass
+    // if the reads are answered from local state.
+    when(mockReader.multimapFetchAllFuture(
+            anyBoolean(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete and should not perform any windmill read."));
+    when(mockReader.multimapFetchSingleEntryFuture(
+            any(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete and should not perform any windmill read."));
+
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+    assertEquals(4, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key, 1),
+            multimapEntryMatcher(key, 2),
+            multimapEntryMatcher(key, 3),
+            multimapEntryMatcher(key, 2)));
+
+    Iterable<byte[]> keys = multimapState.keys().read();
+    assertThat(keys, Matchers.containsInAnyOrder(key));
+
+    Iterable<Integer> values = multimapState.get(dup(key)).read();
+    assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3));
+  }
+
+  // After one get() of a key delivered as a WeightedList, that key's values are expected to be
+  // cached. A second get() and containsKey() must not re-fetch it; a re-stubbed fetch that throws
+  // enforces this.
+  @Test
+  public void testMultimapCachedSingleEntry() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Integer>> entryFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(entryFuture);
+
+    // to set up the entry key as cache complete and add some local changes
+    waitAndSet(entryFuture, weightedList(1, 2, 3), 30);
+    multimapState.get(key).read();
+    multimapState.put(key, 2);
+
+    // Replace the stub: any further single-entry fetch of this key now throws.
+    when(mockReader.multimapFetchSingleEntryFuture(
+            eq(encodeWithCoder(key, ByteArrayCoder.of())),
+            eq(key(NAMESPACE, tag)),
+            eq(STATE_FAMILY),
+            eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete for "
+                    + Arrays.toString(key)
+                    + " and should not perform any windmill read."));
+
+    Iterable<Integer> values = multimapState.get(dup(key)).read();
+    assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3));
+    assertTrue(multimapState.containsKey(key).read());
+  }
+
+  // Only key1 has been cached via get(). A subsequent entries() read must merge three sources:
+  // key1's cached values plus its local additions, key2 which exists only in the backend, and key3
+  // which exists only locally.
+  @Test
+  public void testMultimapCachedPartialEntry() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Integer>> entryFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(entryFuture);
+
+    // to set up the entry key1 as cache complete and add some local changes
+    waitAndSet(entryFuture, weightedList(1, 2, 3), 30);
+    multimapState.get(key1).read();
+    multimapState.put(key1, 2);
+    multimapState.put(key3, 20);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+
+    // windmill contains extra entry key2
+    waitAndSet(
+        entriesFuture,
+        weightedList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 4, 5, 6)),
+        30);
+
+    // key1 exists in both cache and windmill; key2 exists only in windmill; key3 exists only in
+    // cache. They should all be merged.
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+
+    assertEquals(8, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 5),
+            multimapEntryMatcher(key2, 6),
+            multimapEntryMatcher(key3, 20)));
+
+    // NOTE(review): no keys fetch is stubbed here, so keys() is presumably served from the state
+    // cached by the entries() read above — confirm against WindmillMultimap.
+    assertThat(multimapState.keys().read(), Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
+  // Same merge scenario as the partial-entry case, but the backend entries/keys are delivered as
+  // plain Arrays.asList lists instead of WeightedLists. The test comment says these results should
+  // not be cached, so a separate keys fetch is stubbed. The merged results must still be the same.
+  @Test
+  public void testMultimapCachedPartialEntryCannotCachePolled() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Integer>> entryFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(entryFuture);
+
+    // to set up the entry key1 as cache complete and add some local changes
+    waitAndSet(entryFuture, weightedList(1, 2, 3), 30);
+    multimapState.get(key1).read();
+    multimapState.put(dup(key1), 2);
+    multimapState.put(dup(key3), 20);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    // windmill contains extra entry key2, and this time the entries returned should not be cached.
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 4, 5, 6)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    // key1 exists in both cache and windmill; key2 exists only in windmill; key3 exists only in
+    // cache. They should all be merged.
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+
+    assertEquals(8, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 5),
+            multimapEntryMatcher(key2, 6),
+            multimapEntryMatcher(key3, 20)));
+
+    assertThat(multimapState.keys().read(), Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
+  // Iterables returned by read() are snapshots. Mutations made after read() (remove key1, put to
+  // key2 and key4) must not change results already obtained. Mutations made between readLater()
+  // and read() must be included.
+  @Test
+  public void testMultimapModifyAfterReadDoesNotAffectResult() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+    final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+    SettableFuture<Iterable<Integer>> getKey1Future = SettableFuture.create();
+    SettableFuture<Iterable<Integer>> getKey2Future = SettableFuture.create();
+    SettableFuture<Iterable<Integer>> getKey4Future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(getKey1Future);
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key2, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(getKey2Future);
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key4, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(getKey4Future);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        200);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 200);
+
+    // Remove key4 so that it is known not to exist.
+    multimapState.remove(key4);
+
+    ReadableState<Iterable<Integer>> key1Future = multimapState.get(key1).readLater();
+    waitAndSet(getKey1Future, Arrays.asList(1, 2, 3), 200);
+    ReadableState<Iterable<Integer>> key2Future = multimapState.get(key2).readLater();
+    waitAndSet(getKey2Future, Arrays.asList(2, 3, 4), 200);
+    ReadableState<Iterable<Integer>> key4Future = multimapState.get(key4).readLater();
+    waitAndSet(getKey4Future, Collections.emptyList(), 200);
+
+    // Changes made before read(): these must be visible in the results.
+    multimapState.put(key1, 7);
+    multimapState.put(dup(key2), 8);
+    multimapState.put(dup(key3), 8);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    Iterable<byte[]> keys = keysResult.read();
+    Iterable<Integer> key1Values = key1Future.read();
+    Iterable<Integer> key2Values = key2Future.read();
+    Iterable<Integer> key4Values = key4Future.read();
+
+    // values added/removed after read should not be reflected in result
+    multimapState.remove(key1);
+    multimapState.put(key2, 9);
+    multimapState.put(key4, 10);
+
+    assertEquals(9, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key1, 7),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    assertEquals(3, Iterables.size(keys));
+    assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+
+    assertEquals(4, Iterables.size(key1Values));
+    assertThat(key1Values, Matchers.containsInAnyOrder(1, 2, 3, 7));
+
+    assertEquals(4, Iterables.size(key2Values));
+    assertThat(key2Values, Matchers.containsInAnyOrder(2, 3, 4, 8));
+
+    assertTrue(Iterables.isEmpty(key4Values));
+  }
+
+  @Test
+  public void testMultimapLazyIterateHugeEntriesResult() {
+    // A multimap with 1 million keys and a total of 10 GB of data. The fetched entries are
+    // generated on the fly by an unweighted iterator rather than held in memory, so this exercises
+    // entries() on a result far larger than the test could reasonably materialize.
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+
+    waitAndSet(
+        entriesFuture,
+        () ->
+            new Iterator<Map.Entry<ByteString, Iterable<Integer>>>() {
+              int returnedEntries = 0;
+              final byte[] entryKey = new byte[10_000]; // each key is 10KB
+              final int targetEntries = 1_000_000; // return 1 million entries, which is 10 GB
+              // Fixed seed: keys are still (practically) distinct, but reproducible across runs.
+              final Random rand = new Random(0);
+
+              @Override
+              public boolean hasNext() {
+                return returnedEntries < targetEntries;
+              }
+
+              @Override
+              public Map.Entry<ByteString, Iterable<Integer>> next() {
+                // Honor the Iterator contract instead of producing entries past the end.
+                if (!hasNext()) {
+                  throw new java.util.NoSuchElementException();
+                }
+                returnedEntries++;
+                // entryKey is reused between calls; multimapEntry is assumed to copy it into the
+                // returned ByteString key.
+                rand.nextBytes(entryKey);
+                try {
+                  return multimapEntry(entryKey, 1);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+              }
+            },
+        200);
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+    assertEquals(1_000_000, Iterables.size(entries));
+  }
+
+  @Test
+  public void testMultimapLazyIterateHugeKeysResult() {
+    // A multimap with 1 million keys and a total of 10 GB of data. The fetched key names are
+    // generated on the fly by an unweighted iterator rather than held in memory, so this exercises
+    // keys() on a result far larger than the test could reasonably materialize.
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    waitAndSet(
+        keysFuture,
+        () ->
+            new Iterator<Map.Entry<ByteString, Iterable<Integer>>>() {
+              int returnedEntries = 0;
+              final byte[] entryKey = new byte[10_000]; // each key is 10KB
+              final int targetEntries = 1_000_000; // return 1 million entries, which is 10 GB
+              // Fixed seed: keys are still (practically) distinct, but reproducible across runs.
+              final Random rand = new Random(0);
+
+              @Override
+              public boolean hasNext() {
+                return returnedEntries < targetEntries;
+              }
+
+              @Override
+              public Map.Entry<ByteString, Iterable<Integer>> next() {
+                // Honor the Iterator contract instead of producing entries past the end.
+                if (!hasNext()) {
+                  throw new java.util.NoSuchElementException();
+                }
+                returnedEntries++;
+                // entryKey is reused between calls; multimapEntry is assumed to copy it into the
+                // returned ByteString key.
+                rand.nextBytes(entryKey);
+                try {
+                  return multimapEntry(entryKey);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+              }
+            },
+        200);
+    Iterable<byte[]> keys = multimapState.keys().read();
+    assertEquals(1_000_000, Iterables.size(keys));
+  }
+
+  @Test
+  public void testMultimapLazyIterateHugeEntriesResultSingleEntry() throws IOException {
+    // A multimap with 1 key, 1 million values and a total of 10 GB of data.
+    final String tag = "multimap";
+    final Integer key = 100;
+    StateTag<MultimapState<Integer, byte[]>> addr =
+        StateTags.multimap(tag, VarIntCoder.of(), ByteArrayCoder.of());
+    MultimapState<Integer, byte[]> multimapState = underTest.state(NAMESPACE, addr);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<byte[]>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, ByteArrayCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<byte[]>> getKeyFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, VarIntCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            ByteArrayCoder.of()))
+        .thenReturn(getKeyFuture);
+
+    // An unweighted iterable that lazily yields 1 million 10KB values. Every next() call returns
+    // the same, re-randomized buffer; that is acceptable here because the test only counts values.
+    Iterable<byte[]> values =
+        () ->
+            new Iterator<byte[]>() {
+              int returnedValues = 0;
+              final byte[] value = new byte[10_000]; // each value is 10KB
+              final int targetValues = 1_000_000; // return 1 million values, which is 10 GB
+              // Fixed seed keeps the generated data reproducible across runs.
+              final Random rand = new Random(0);
+
+              @Override
+              public boolean hasNext() {
+                return returnedValues < targetValues;
+              }
+
+              @Override
+              public byte[] next() {
+                // Honor the Iterator contract instead of producing values past the end.
+                if (!hasNext()) {
+                  throw new java.util.NoSuchElementException();
+                }
+                returnedValues++;
+                rand.nextBytes(value);
+                return value;
+              }
+            };
+
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(
+            new AbstractMap.SimpleEntry<>(encodeWithCoder(key, VarIntCoder.of()), values)),
+        200);
+    waitAndSet(getKeyFuture, values, 200);
+
+    Iterable<Map.Entry<Integer, byte[]>> entries = multimapState.entries().read();
+    assertEquals(1_000_000, Iterables.size(entries));
+
+    Iterable<byte[]> valueResult = multimapState.get(key).read();
+    assertEquals(1_000_000, Iterables.size(valueResult));
+  }
+
+  /**
+   * Decoded form of one per-key multimap update (key, values, deleteAll flag), compared by value so
+   * tests can match expected and actual updates with {@code equals}.
+   */
+  private static class MultimapEntryUpdate {
+    final String key;
+    // Copied into a List so equals/hashCode compare element-wise for any Iterable passed in, rather
+    // than depending on the concrete Iterable type's own equals.
+    final List<Integer> values;
+    final boolean deleteAll;
+
+    public MultimapEntryUpdate(String key, Iterable<Integer> values, boolean deleteAll) {
+      this.key = key;
+      if (values == null) {
+        this.values = null;
+      } else {
+        List<Integer> copy = new ArrayList<>();
+        for (Integer value : values) {
+          copy.add(value);
+        }
+        this.values = copy;
+      }
+      this.deleteAll = deleteAll;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof MultimapEntryUpdate)) {
+        return false;
+      }
+      MultimapEntryUpdate that = (MultimapEntryUpdate) o;
+      return deleteAll == that.deleteAll
+          && Objects.equals(key, that.key)
+          && Objects.equals(values, that.values);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(key, values, deleteAll);
+    }
+
+    // Readable form for assertion failure messages; Object.toString would only show an identity
+    // hash.
+    @Override
+    public String toString() {
+      return "MultimapEntryUpdate{key="
+          + key
+          + ", values="
+          + values
+          + ", deleteAll="
+          + deleteAll
+          + "}";
+    }
+  }
+
+  // Decodes a Windmill TagMultimapEntry (UTF-8 string entry name, VarInt-encoded values) into a
+  // MultimapEntryUpdate. Coder IOExceptions are rethrown unchecked so the method can be used as a
+  // method reference.
+  private static MultimapEntryUpdate decodeTagMultimapEntry(Windmill.TagMultimapEntry entryProto) {
+    StringUtf8Coder keyCoder = StringUtf8Coder.of();
+    VarIntCoder valueCoder = VarIntCoder.of();
+    try {
+      String decodedKey = keyCoder.decode(entryProto.getEntryName().newInput(), Context.OUTER);
+      List<Integer> decodedValues = new ArrayList<>();
+      for (ByteString encodedValue : entryProto.getValuesList()) {
+        decodedValues.add(valueCoder.decode(encodedValue.newInput(), Context.OUTER));
+      }
+      boolean deleteAll = entryProto.getDeleteAll();
+      return new MultimapEntryUpdate(decodedKey, decodedValues, deleteAll);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  // Asserts that the per-key updates in a TagMultimapUpdateRequest, once decoded, match the
+  // expected updates in any order.
+  private static void assertTagMultimapUpdates(
+      Windmill.TagMultimapUpdateRequest.Builder updates, MultimapEntryUpdate... expected) {
+    List<MultimapEntryUpdate> decoded = new ArrayList<>();
+    for (Windmill.TagMultimapEntry entry : updates.getUpdatesList()) {
+      decoded.add(decodeTagMultimapEntry(entry));
+    }
+    assertThat(decoded, Matchers.containsInAnyOrder(expected));
+  }
+
+  @Test
+  public void testMultimapPutAndPersist() {
+    // Puts on a multimap that was never read or cleared are persisted as one update per key, with
+    // deleteAll left false.
+    StateTag<MultimapState<String, Integer>> multimapTag =
+        StateTags.multimap("multimap", StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> state = underTest.state(NAMESPACE, multimapTag);
+    final String firstKey = "key1";
+    final String secondKey = "key2";
+
+    state.put(firstKey, 1);
+    state.put(firstKey, 2);
+    state.put(secondKey, 2);
+
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    // A single multimap update request carries the entries for both keys.
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    Windmill.TagMultimapUpdateRequest.Builder multimapUpdate =
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList());
+    assertTagMultimapUpdates(
+        multimapUpdate,
+        new MultimapEntryUpdate(firstKey, Arrays.asList(1, 2), false),
+        new MultimapEntryUpdate(secondKey, Collections.singletonList(2), false));
+  }
+
+  @Test
+  public void testMultimapRemovePutAndPersist() {
+    StateTag<MultimapState<String, Integer>> multimapTag =
+        StateTags.multimap("multimap", StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> state = underTest.state(NAMESPACE, multimapTag);
+    final String firstKey = "key1";
+    final String secondKey = "key2";
+
+    // remove then put: key1 is deleted, then 1 and 2 are added.
+    state.remove(firstKey);
+    state.put(firstKey, 1);
+    state.put(firstKey, 2);
+    // put, remove, put: the 2 put before the remove must not be sent; key2 is deleted and then
+    // gets only 4.
+    state.put(secondKey, 2);
+    state.remove(secondKey);
+    state.put(secondKey, 4);
+
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    assertTagMultimapUpdates(
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList()),
+        new MultimapEntryUpdate(firstKey, Arrays.asList(1, 2), true),
+        new MultimapEntryUpdate(secondKey, Collections.singletonList(4), true));
+  }
+
+  @Test
+  public void testMultimapRemoveAndPersist() {
+    // Removing keys of a multimap that was never read yields a per-key deleteAll with no values.
+    StateTag<MultimapState<String, Integer>> multimapTag =
+        StateTags.multimap("multimap", StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> state = underTest.state(NAMESPACE, multimapTag);
+    final String firstKey = "key1";
+    final String secondKey = "key2";
+
+    state.remove(firstKey);
+    state.remove(secondKey);
+
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    assertTagMultimapUpdates(
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList()),
+        new MultimapEntryUpdate(firstKey, Collections.emptyList(), true),
+        new MultimapEntryUpdate(secondKey, Collections.emptyList(), true));
+  }
+
+  @Test
+  public void testMultimapPutRemoveClearAndPersist() {
+    StateTag<MultimapState<String, Integer>> multimapTag =
+        StateTags.multimap("multimap", StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> state = underTest.state(NAMESPACE, multimapTag);
+    final String firstKey = "key1";
+    final String secondKey = "key2";
+    final String thirdKey = "key3";
+
+    // clear() supersedes these earlier put/remove calls, so none of them needs to be sent.
+    state.put(firstKey, 1);
+    state.put(secondKey, 2);
+    state.remove(secondKey);
+    state.clear();
+    // After clear() the multimap is cache complete, so removing a key that was not put again is
+    // known to be a no-op and is skipped as well.
+    state.remove(thirdKey);
+
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    Windmill.TagMultimapUpdateRequest.Builder multimapUpdate =
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList());
+    // Only the request-level deleteAll is sent; there are no per-key updates.
+    assertEquals(0, multimapUpdate.getUpdatesCount());
+    assertTrue(multimapUpdate.getDeleteAll());
+  }
+
+  @Test
+  public void testMultimapPutRemoveAndPersistWhenComplete() {
+    final String tag = "multimap";
+    StateTag<MultimapState<String, Integer>> multimapTag =
+        StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> state = underTest.state(NAMESPACE, multimapTag);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> allEntries =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(allEntries);
+
+    // Reading every entry of an (empty) multimap makes it cache complete.
+    waitAndSet(allEntries, Collections.emptyList(), 30);
+    state.entries().read();
+
+    final String firstKey = "key1";
+    final String secondKey = "key2";
+
+    // Once complete, a plain put is still sent ...
+    state.put(firstKey, 4);
+
+    // ... but a put followed by a remove of the same key cancels out and is not sent.
+    state.put(secondKey, 5);
+    state.remove(secondKey);
+
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    assertTagMultimapUpdates(
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList()),
+        new MultimapEntryUpdate(firstKey, Collections.singletonList(4), false));
+  }
+
+  @Test
+  public void testMultimapRemoveAndKeysAndPersist() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> multimapTag =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> state = underTest.state(NAMESPACE, multimapTag);
+
+    final byte[] firstKey = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] secondKey = "key2".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    // Start a keys() read, then remove one key locally before reading the result.
+    ReadableState<Iterable<byte[]>> pendingKeys = state.keys().readLater();
+    waitAndSet(
+        keysFuture,
+        new WindmillStateReader.WeightedList<>(
+            Arrays.asList(multimapEntry(firstKey), multimapEntry(secondKey))),
+        30);
+
+    state.remove(firstKey);
+
+    // The local remove is reflected in the keys() result ...
+    Iterable<byte[]> remainingKeys = pendingKeys.read();
+    assertEquals(1, Iterables.size(remainingKeys));
+    assertThat(remainingKeys, Matchers.containsInAnyOrder(secondKey));
+
+    // ... and is persisted as a per-key deleteAll, not a multimap-wide deleteAll.
+    Windmill.WorkItemCommitRequest.Builder commit = Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commit);
+
+    assertEquals(1, commit.getMultimapUpdatesCount());
+    Windmill.TagMultimapUpdateRequest.Builder multimapUpdate =
+        Iterables.getOnlyElement(commit.getMultimapUpdatesBuilderList());
+    assertEquals(1, multimapUpdate.getUpdatesCount());
+    assertFalse(multimapUpdate.getDeleteAll());
+    Windmill.TagMultimapEntry removedEntry =
+        Iterables.getOnlyElement(multimapUpdate.getUpdatesList());
+    byte[] removedKey =
+        ByteArrayCoder.of().decode(removedEntry.getEntryName().newInput(), Context.OUTER);
+    assertTrue(Arrays.equals(firstKey, removedKey));
+    assertTrue(removedEntry.getDeleteAll());
+  }

Review Comment:
   Added a fuzz test that applies a variety of modifications and verifies the resulting state.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java:
##########
@@ -645,6 +655,1163 @@ public void testMapComplexPersist() throws Exception {
     assertEquals(0, commitBuilder.getValueUpdatesCount());
   }
 
+  private static <T> ByteString encodeWithCoder(T key, Coder<T> coder) throws IOException {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1723,12 +1729,17 @@ public ReadableState<Iterable<V>> get(K key) {
 
         @Override
         public Iterable<V> read() {
-          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
-          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
-            return Collections.emptyList();
+          KeyState keyState = null;
+          if (allKeysKnown) {
+            keyState = keyStateMap.get(structuralKey);
+            if (keyState == null || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+              if (keyState != null) keyStateMap.remove(structuralKey);

Review Comment:
   UNKNOWN_EXISTENCE is needed when, for example, a key is not in the map and we have to query Windmill to find out whether it exists. Until the Windmill read returns a result, the key must be marked UNKNOWN_EXISTENCE.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java:
##########
@@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException {
     return output.toByteString();
   }
 
+  @Test
+  public void testReadMultimapSingleEntry() throws Exception {
+    Future<Iterable<Integer>> valuesFuture =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    // Creating the future by itself must not contact Windmill.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // The fetch asks for exactly this one entry, with the initial multimap byte budget.
+    Windmill.TagMultimapEntry entryToFetch =
+        Windmill.TagMultimapEntry.newBuilder()
+            .setEntryName(STATE_MULTIMAP_KEY_1)
+            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+            .build();
+    Windmill.KeyedGetDataRequest request =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(entryToFetch))
+            .build();
+
+    Windmill.KeyedGetDataResponse response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))))
+            .build();
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, request)).thenReturn(response);
+
+    Iterable<Integer> values = valuesFuture.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, request);
+    for (Integer ignored : values) {
+      // Iterate over the results to force loading all the pages.
+    }
+    // A response without a continuation position means no further fetches.
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(values, Matchers.containsInAnyOrder(5, 6));
+    assertNoReader(valuesFuture);
+  }
+
+  @Test
+  public void testReadMultimapSingleEntryPaginated() throws Exception {
+    // A single-entry fetch whose values span three pages. The first request uses the initial byte
+    // budgets; each follow-up request resumes at the previous response's continuation position
+    // and uses the continuation budgets.
+    Future<Iterable<Integer>> valuesFuture =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // Page 1: values 5 and 6, continuing at position 500.
+    Windmill.KeyedGetDataRequest firstPageRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()))
+            .build();
+    Windmill.KeyedGetDataResponse firstPageResponse =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(500)))
+            .build();
+
+    // Page 2: resumes at 500 with values 7 and 8, continuing at position 800.
+    Windmill.KeyedGetDataRequest secondPageRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(500)
+                            .build()))
+            .build();
+    Windmill.KeyedGetDataResponse secondPageResponse =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setContinuationPosition(800)
+                            .setRequestPosition(500)))
+            .build();
+
+    // Page 3: resumes at 800 with values 9 and 10; no continuation position, so it is the last.
+    Windmill.KeyedGetDataRequest thirdPageRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()))
+            .build();
+    Windmill.KeyedGetDataResponse thirdPageResponse =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(9), intData(10)))
+                            .setRequestPosition(800)))
+            .build();
+
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, firstPageRequest))
+        .thenReturn(firstPageResponse);
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, secondPageRequest))
+        .thenReturn(secondPageResponse);
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, thirdPageRequest))
+        .thenReturn(thirdPageResponse);
+
+    Iterable<Integer> values = valuesFuture.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, firstPageRequest);
+    for (Integer ignored : values) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, secondPageRequest);
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, thirdPageRequest);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // Values are returned in page order.
+    assertThat(values, Matchers.contains(5, 6, 7, 8, 9, 10));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // Asserts that two TagMultimapFetchRequests are equal, ignoring the order of their entries and
+  // the order of the values within each entry.
+  private static void assertMultimapFetchRequestEqual(
+      Windmill.TagMultimapFetchRequest req1, Windmill.TagMultimapFetchRequest req2) {
+    // Entries are compared order-insensitively by the helper ...
+    assertMultimapEntriesEqual(req1.getEntriesToFetchList(), req2.getEntriesToFetchList());
+    // ... and every other field of the request must match exactly.
+    Windmill.TagMultimapFetchRequest req1Rest = req1.toBuilder().clearEntriesToFetch().build();
+    Windmill.TagMultimapFetchRequest req2Rest = req2.toBuilder().clearEntriesToFetch().build();
+    assertEquals(req1Rest, req2Rest);
+  }
+
+  // Asserts that two lists of TagMultimapEntry hold the same entries, matched by entry name,
+  // ignoring the order of the entries and the order of the values within each entry.
+  private static void assertMultimapEntriesEqual(
+      List<Windmill.TagMultimapEntry> left, List<Windmill.TagMultimapEntry> right) {
+    Map<ByteString, Windmill.TagMultimapEntry> map = Maps.newHashMap();
+    for (Windmill.TagMultimapEntry entry : left) {
+      // Reject duplicate names in `left`; otherwise a later entry would silently replace an
+      // earlier one and the mismatch would go unnoticed.
+      assertTrue(map.put(entry.getEntryName(), entry) == null);
+    }
+    for (Windmill.TagMultimapEntry entry : right) {
+      assertTrue(map.containsKey(entry.getEntryName()));
+      Windmill.TagMultimapEntry that = map.remove(entry.getEntryName());
+      // Pass the expected values as an array: containsInAnyOrder(List) would bind to the varargs
+      // overload with the whole list as a single expected element, so any non-empty list would
+      // never match. An empty array matches only an empty iterable, covering entries with no
+      // values as well.
+      assertThat(
+          entry.getValuesList(), Matchers.containsInAnyOrder(that.getValuesList().toArray()));
+      assertEquals(entry.toBuilder().clearValues().build(), that.toBuilder().clearValues().build());
+    }
+    assertTrue(map.isEmpty());
+  }
+
+  @Test
+  public void testReadMultimapMultipleEntries() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6))))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill)
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest.build().getMultimapsToFetch(0),
+        requestCaptor.getValue().getMultimapsToFetch(0));
+
+    // Iterate over the results to force loading all the pages.

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    /** Local bookkeeping for a single multimap key. */
+    private class KeyState {
+      // The key object supplied by the user. keyStateMap is indexed by its structural value; this
+      // original object is what gets encoded when the key is written to windmill.
+      final K originalKey;
+      KeyExistence existence;
+      // When true, `values` together with `localAdditions` hold every value of this key, so reads
+      // do not need to go to windmill. This is NOT limited to existence == KNOWN_EXIST: it starts
+      // out equal to `complete` (while existence is UNKNOWN_EXISTENCE), and remove() sets it while
+      // marking the key KNOWN_NONEXISTENT with an empty value list.
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values are appended to values (only
+      // when valuesCached is true) after they are persisted into windmill, and are then removed
+      // from localAdditions.
+      ConcatIterables<V> values;
+      // Number of elements in `values`; reads take a view of Iterables.limit(values, valuesSize).
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    /** What is known locally about whether a multimap key is present. */
+    private enum KeyExistence {
+      /** The key is known to exist. */
+      KNOWN_EXIST,
+      /** The key is known not to exist. */
+      KNOWN_NONEXISTENT,
+      /**
+       * Whether the key is in the multimap has not been determined; the entry only records the
+       * mapping between the original key and its structural key.
+       */
+      UNKNOWN_EXISTENCE
+    }
+
+    /** Set by clear(); the next persist then asks windmill to delete every entry of this state. */
+    private boolean cleared = false;
+    /**
+     * Local per-key state, indexed by the key coder's structural value rather than by the raw key
+     * object, so distinct Java objects with the same content resolve to the same multimap key.
+     */
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    /** When true, every existing key is present in keyStateMap with existence == KNOWN_EXIST. */
+    private boolean allKeysKnown = false;
+
+    /** When true, no windmill read is needed: getFuture() answers with an empty result. */
+    private boolean complete = false;
+    /** Whether some key has values in localAdditions that still have to be persisted. */
+    private boolean hasLocalAdditions = false;
+    /** Whether some key has a local removal that still has to be persisted. */
+    private boolean hasLocalRemovals = false;
+
+    /**
+     * Creates the multimap state for one (namespace, tag) pair.
+     *
+     * <p>When {@code isNewShardingKey} is true the state starts out both complete and with all keys
+     * known; presumably windmill holds nothing for a brand new sharding key (confirm with callers).
+     */
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    /**
+     * Returns the persisted entries of this multimap. A fetch-all read is issued to windmill unless
+     * the local state is already complete, in which case an empty result is returned immediately.
+     *
+     * @param omitValues forwarded to the reader; keys() passes true and entries() passes false.
+     */
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    /**
+     * Issues a windmill read for the values of a single key. The key is encoded with the key coder
+     * in the OUTER context; an encoding failure is rethrown as an unchecked exception.
+     */
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encoded = new ByteStringOutputStream();
+        keyCoder.encode(key, encoded, Context.OUTER);
+        ByteString encodedKey = encoded.toByteString();
+        return reader.multimapFetchSingleEntryFuture(
+            encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    /**
+     * Returns a lazily evaluated view of the values stored under {@code key}.
+     *
+     * <p>{@code read()} combines the values persisted in windmill (served from the local cache when
+     * they are cached, otherwise fetched for this single key) with the values added locally via
+     * put() since the last persist. A pending local remove() hides everything persisted in windmill.
+     */
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          // Every existing key is cached as KNOWN_EXIST, so a key of unknown existence cannot
+          // exist; drop its UNKNOWN_EXISTENCE placeholder entry.
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          // Cap the view at the current number of local additions, so values that put() appends
+          // after this read are not visible through the returned iterable.
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill (if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          // The persisted values are already held locally; no windmill read is needed.
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          // Fetch this key's persisted values. readLater() issues the same request earlier,
+          // presumably so this call can reuse an in-flight read (confirm in WindmillStateReader).
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() only inspects the first element; it does not walk all values.
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            // Only a Weighted result is cached; later reads of this key are then served locally.
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              // Preserve the interrupt status for callers up the stack.
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Starts the single-key read early. It is issued unconditionally, even when the values
+        // are already cached locally.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds the windmill commit for all local changes to this multimap and refreshes the cache.
+     *
+     * <p>The update sets delete-all when clear() was called, and adds one entry per key that has a
+     * pending removal and/or unpersisted local additions. Afterwards the local change markers are
+     * reset and this object is put into the state cache with weight 1.
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      // Nothing changed locally: only refresh the cache entry and return an empty commit.
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        // Both streams are reused across keys/values; toByteStringAndReset() resets them.
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          // No pending change for this key, so there is nothing to send. An entry that is known
+          // not to exist is also dropped from the local map here.
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          // A pending removal deletes the key's old values in windmill; the local additions sent in
+          // this same entry are the values written after that delete.
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // When this key's values are cached, the additions just sent are now persisted too, so
+          // append them to the cached windmill values.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Replace (rather than clear) localAdditions: keyState.values may now reference the old
+          // list, and clearing it would also wipe those cached values.
+          keyState.localAdditions = Lists.newArrayList();
+          // A key whose values are not cached and that is not known to exist carries no useful
+          // local information any more, so drop it from the cache.
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    /**
+     * Removes every value stored under {@code key}.
+     *
+     * <p>If windmill may still hold values for the key, or an earlier removal of it has not been
+     * persisted yet, the key is marked {@code removedLocally} so the next persist deletes it in
+     * windmill. Only when windmill is known to hold nothing for the key is the local entry dropped.
+     */
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      // When all keys are known, an absent entry means the key does not exist; do not insert a
+      // placeholder entry for it (a null mapping result records no mapping).
+      KeyState keyState =
+          keyStateMap.computeIfAbsent(structuralKey, k -> allKeysKnown ? null : new KeyState(key));
+      if (keyState == null
+          || keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) {
+        // No data in windmill and no earlier removal waiting to be persisted: deleting from the
+        // local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      } else {
+        // Windmill may still hold values for this key, or an earlier remove() of it has not been
+        // persisted yet (e.g. remove -> put -> remove). Keep the entry so persist deletes it.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    /**
+     * Drops every key locally. The next persist sends a delete-all to windmill, and until then
+     * nothing needs to be fetched, so the state counts as complete with all keys known.
+     */
+    @Override
+    public void clear() {
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a lazily evaluated view of the distinct keys currently in this multimap.
+     *
+     * <p>If all keys are already known locally, only the local map is consulted. Otherwise the keys
+     * are fetched from windmill (values omitted) and merged with local state: keys removed locally
+     * are hidden and keys added locally are included.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        // Locally known-to-exist keys, indexed by structural key.
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              // Keep keys known to exist, plus keys whose local removal still has to be persisted:
+              // dropping the latter would lose the pending delete, and the key's old values would
+              // survive in windmill.
+              keyStateMap
+                  .values()
+                  .removeIf(
+                      keyState ->
+                          keyState.existence != KeyExistence.KNOWN_EXIST
+                              && !keyState.removedLocally);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded. Keys with a local existence verdict are
+              // skipped here: existing ones come from cachedExistKeys, removed ones are hidden.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more

Review Comment:
   Ack, done; I also removed an incorrect unit test.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java:
##########
@@ -645,6 +655,1163 @@ public void testMapComplexPersist() throws Exception {
     assertEquals(0, commitBuilder.getValueUpdatesCount());
   }
 
+  /** Encodes {@code value} with {@code coder} in the OUTER context and returns the bytes. */
+  private static <T> ByteString encodeWithCoder(T value, Coder<T> coder) throws IOException {
+    ByteStringOutputStream stream = new ByteStringOutputStream();
+    coder.encode(value, stream, Context.OUTER);
+    return stream.toByteString();
+  }
+
+  // Multimap keys are told apart by their structural value, so the tests deliberately alternate
+  // between the original key object and an equal-content copy of it to make sure both resolve to
+  // the same multimap key.
+  /** Returns a new array with the same content as {@code key}. */
+  private static byte[] dup(byte[] key) {
+    return Arrays.copyOf(key, key.length);
+  }
+
+  @Test
+  public void testMultimapGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    // Reading through an equal-content copy of the key must hit the same backend entry.
+    ReadableState<Iterable<Integer>> pending = multimapState.get(dup(key)).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2, 3), 30);
+    assertThat(pending.read(), Matchers.containsInAnyOrder(1, 2, 3));
+  }
+
+  @Test
+  public void testMultimapPutAndGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    // A locally put value is reported alongside the fetched values, even when it duplicates one.
+    multimapState.put(key, 1);
+    ReadableState<Iterable<Integer>> pending = multimapState.get(dup(key)).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2, 3), 30);
+    assertThat(pending.read(), Matchers.containsInAnyOrder(1, 1, 2, 3));
+  }
+
+  @Test
+  public void testMultimapRemoveAndGet() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    // Two pending reads of the same key, one of them through an equal-content copy.
+    ReadableState<Iterable<Integer>> viaOriginal = multimapState.get(key).readLater();
+    ReadableState<Iterable<Integer>> viaCopy = multimapState.get(dup(key)).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2, 3), 30);
+
+    assertTrue(multimapState.containsKey(key).read());
+    assertThat(viaOriginal.read(), Matchers.containsInAnyOrder(1, 2, 3));
+
+    // After the removal, the not-yet-read view must observe the key as empty.
+    multimapState.remove(key);
+    assertFalse(multimapState.containsKey(dup(key)).read());
+    assertThat(viaCopy.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMultimapRemoveThenPut() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    ReadableState<Iterable<Integer>> pending = multimapState.get(key).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2, 3), 30);
+    // Removing the key hides the fetched values; only values put afterwards remain visible.
+    multimapState.remove(dup(key));
+    multimapState.put(key, 4);
+    multimapState.put(dup(key), 5);
+    assertThat(pending.read(), Matchers.containsInAnyOrder(4, 5));
+  }
+
+  @Test
+  public void testMultimapRemovePersistPut() {
+    final String tag = "multimap";
+    StateTag<MultimapState<String, Integer>> addr =
+        StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of());
+    MultimapState<String, Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final String key = "key";
+    multimapState.put(key, 1);
+    multimapState.put(key, 2);
+
+    // Once the key is removed its values are fully cached, so no backend read is needed.
+    multimapState.remove(key);
+    multimapState.put(key, 4);
+
+    // Since the key is cache complete, persisting moves value 4 from localAdditions into the
+    // cached values instead of clearing it from the cache.
+    Windmill.WorkItemCommitRequest.Builder commitBuilder =
+        Windmill.WorkItemCommitRequest.newBuilder();
+    underTest.persist(commitBuilder);
+    assertTagMultimapUpdates(
+        Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()),
+        new MultimapEntryUpdate(key, Arrays.asList(4), true));
+
+    multimapState.put(key, 5);
+    assertThat(multimapState.get(key).read(), Matchers.containsInAnyOrder(4, 5));
+  }
+
+  @Test
+  public void testMultimapGetLocalCombineStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    ReadableState<Iterable<Integer>> pending = multimapState.get(dup(key)).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2), 30);
+    // Values put locally (through either key object) are combined with the stored ones.
+    multimapState.put(key, 3);
+    multimapState.put(dup(key), 4);
+    assertFalse(multimapState.isEmpty().read());
+    assertThat(pending.read(), Matchers.containsInAnyOrder(1, 2, 3, 4));
+  }
+
+  @Test
+  public void testMultimapLocalRemoveOverrideStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch);
+
+    ReadableState<Iterable<Integer>> pending = multimapState.get(key).readLater();
+    waitAndSet(fetch, Arrays.asList(1, 2), 30);
+    // A local removal hides the stored values ...
+    multimapState.remove(dup(key));
+    assertThat(pending.read(), Matchers.emptyIterable());
+    // ... and values put afterwards are the only ones visible.
+    multimapState.put(key, 3);
+    multimapState.put(dup(key), 4);
+    assertFalse(multimapState.isEmpty().read());
+    assertThat(pending.read(), Matchers.containsInAnyOrder(3, 4));
+  }
+
+  @Test
+  public void testMultimapLocalClearOverrideStorage() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final ByteString encodedKey1 = encodeWithCoder(key1, ByteArrayCoder.of());
+    final ByteString encodedKey2 = encodeWithCoder(key2, ByteArrayCoder.of());
+    SettableFuture<Iterable<Integer>> fetch1 = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey1, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch1);
+    SettableFuture<Iterable<Integer>> fetch2 = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodedKey2, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(fetch2);
+
+    // clear() overrides storage: reads started before it, keys(), entries() and isEmpty() all
+    // observe an empty multimap, even after the backend read for key1 completes.
+    ReadableState<Iterable<Integer>> pending1 = multimapState.get(key1).readLater();
+    ReadableState<Iterable<Integer>> pending2 = multimapState.get(dup(key2)).readLater();
+    multimapState.clear();
+    waitAndSet(fetch1, Arrays.asList(1, 2), 30);
+    assertThat(pending1.read(), Matchers.emptyIterable());
+    assertThat(pending2.read(), Matchers.emptyIterable());
+    assertThat(multimapState.keys().read(), Matchers.emptyIterable());
+    assertThat(multimapState.entries().read(), Matchers.emptyIterable());
+    assertTrue(multimapState.isEmpty().read());
+  }
+
+  /** Builds a fetch-all result entry: the ByteArrayCoder-encoded key paired with its values. */
+  private static Map.Entry<ByteString, Iterable<Integer>> multimapEntry(
+      byte[] key, Integer... values) throws IOException {
+    ByteString encodedKey = encodeWithCoder(key, ByteArrayCoder.of());
+    Iterable<Integer> valueList = Arrays.asList(values);
+    return new AbstractMap.SimpleEntry<>(encodedKey, valueList);
+  }
+
+  /** Returns a WeightedList holding {@code entries} in order, each added with weight 1. */
+  @SafeVarargs
+  private static <T> List<T> weightedList(T... entries) {
+    WindmillStateReader.WeightedList<T> weighted =
+        new WindmillStateReader.WeightedList<>(new ArrayList<>());
+    for (int i = 0; i < entries.length; i++) {
+      weighted.addWeighted(entries[i], 1);
+    }
+    return weighted;
+  }
+
+  @Test
+  public void testMultimapBasicEntriesAndKeys() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+
+    // entries() fetches with values; keys() fetches with values omitted.
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFetch =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFetch);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFetch =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFetch);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> pendingEntries =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> pendingKeys = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFetch,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFetch, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    Iterable<Map.Entry<byte[], Integer>> actualEntries = pendingEntries.read();
+    assertEquals(6, Iterables.size(actualEntries));
+    assertThat(
+        actualEntries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 4)));
+
+    Iterable<byte[]> actualKeys = pendingKeys.read();
+    assertEquals(2, Iterables.size(actualKeys));
+    assertThat(actualKeys, Matchers.containsInAnyOrder(key1, key2));
+  }
+
+  /**
+   * Matches a Map.Entry whose key equals {@code key} and whose value equals {@code value}. Hamcrest
+   * equalTo compares arrays element by element, so byte[] keys are matched by content.
+   */
+  private static CombinableMatcher<Object> multimapEntryMatcher(byte[] key, Integer value) {
+    return CombinableMatcher.both(Matchers.hasProperty("key", Matchers.equalTo(key)))
+        .and(Matchers.hasProperty("value", Matchers.equalTo(value)));
+  }
+
+  @Test
+  public void testMultimapEntriesAndKeysMergeLocalAdd() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFetch =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFetch);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFetch =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFetch);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> pendingEntries =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> pendingKeys = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFetch,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFetch, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    // Local puts extend two stored keys and introduce a third, local-only key.
+    multimapState.put(key1, 7);
+    multimapState.put(dup(key2), 8);
+    multimapState.put(dup(key3), 8);
+
+    Iterable<Map.Entry<byte[], Integer>> actualEntries = pendingEntries.read();
+    assertEquals(9, Iterables.size(actualEntries));
+    assertThat(
+        actualEntries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key1, 7),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> actualKeys = pendingKeys.read();
+    assertEquals(3, Iterables.size(actualKeys));
+    assertThat(actualKeys, Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
+  // Removing a key that windmill returned (key1, removed through a separate byte[] passed via
+  // dup()) must drop all of its values from entries() and the key from keys(). Puts to key2 and
+  // key3 are still merged in.
+  @Test
+  public void testMultimapEntriesAndKeysMergeLocalRemove() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    multimapState.remove(dup(key1));
+    multimapState.put(key2, 8);
+    multimapState.put(dup(key3), 8);
+
+    // key1 must be gone entirely: none of its windmill values (1, 2, 3) may appear.
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(5, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+  }
+
+  // Same local edits as the MergeLocalRemove test, but windmill returns key2's values split
+  // across two list items (simulating pagination). Both parts must be concatenated for key2,
+  // including the value 4, which appears in each part and must be kept twice.
+  @Test
+  public void testMultimapEntriesPaginated() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        multimapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = multimapState.keys().readLater();
+    waitAndSet(
+        entriesFuture,
+        weightedList(
+            multimapEntry(key1, 1, 2, 3),
+            // entry key2 is returned in 2 separate responses due to pagination.
+            multimapEntry(key2, 2, 3, 4),
+            multimapEntry(key2, 4, 5)),
+        30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30);
+
+    multimapState.remove(dup(key1));
+    multimapState.put(key2, 8);
+    multimapState.put(dup(key3), 8);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(7, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key2, 2),
+            multimapEntryMatcher(key2, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 5),
+            multimapEntryMatcher(key2, 8),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+  }
+
+  // After one full entries() read, the multimap must answer entries(), keys() and get() from its
+  // cache, including a later local put. Every windmill fetch is re-stubbed to throw, so any read
+  // that reaches the reader fails the test.
+  @Test
+  public void testMultimapCacheComplete() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+
+    // to set up the multimap as cache complete
+    waitAndSet(entriesFuture, weightedList(multimapEntry(key, 1, 2, 3)), 30);
+    multimapState.entries().read();
+
+    multimapState.put(key, 2);
+
+    // From here on, any windmill read (full or single-entry) throws.
+    when(mockReader.multimapFetchAllFuture(
+            anyBoolean(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete and should not perform any windmill read."));
+    when(mockReader.multimapFetchSingleEntryFuture(
+            any(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete and should not perform any windmill read."));
+
+    // The local put of 2 sits alongside the cached 2, so the value 2 appears twice.
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+    assertEquals(4, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key, 1),
+            multimapEntryMatcher(key, 2),
+            multimapEntryMatcher(key, 3),
+            multimapEntryMatcher(key, 2)));
+
+    Iterable<byte[]> keys = multimapState.keys().read();
+    assertThat(keys, Matchers.containsInAnyOrder(key));
+
+    Iterable<Integer> values = multimapState.get(dup(key)).read();
+    assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3));
+  }
+
+  // After get(key).read(), the values of that key are cached. A later get() and containsKey() must
+  // be answered without refetching (the single-entry fetch is re-stubbed to throw) and must
+  // include the local put.
+  @Test
+  public void testMultimapCachedSingleEntry() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Integer>> entryFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(entryFuture);
+
+    // to set up the entry key as cache complete and add some local changes
+    waitAndSet(entryFuture, weightedList(1, 2, 3), 30);
+    multimapState.get(key).read();
+    multimapState.put(key, 2);
+
+    when(mockReader.multimapFetchSingleEntryFuture(
+            eq(encodeWithCoder(key, ByteArrayCoder.of())),
+            eq(key(NAMESPACE, tag)),
+            eq(STATE_FAMILY),
+            eq(VarIntCoder.of())))
+        .thenThrow(
+            new RuntimeException(
+                "The multimap is cache complete for "
+                    + Arrays.toString(key)
+                    + " and should not perform any windmill read."));
+
+    Iterable<Integer> values = multimapState.get(dup(key)).read();
+    assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3));
+    assertTrue(multimapState.containsKey(key).read());
+  }
+
+  // Only key1 is cached (via get()), and there are local puts to key1 and key3. entries() must
+  // still fetch everything from windmill and merge it. key1's windmill values must not be
+  // counted twice on top of the cached copy, key2 comes only from windmill, and key3 comes only
+  // from local state.
+  @Test
+  public void testMultimapCachedPartialEntry() throws IOException {
+    final String tag = "multimap";
+    StateTag<MultimapState<byte[], Integer>> addr =
+        StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MultimapState<byte[], Integer> multimapState = underTest.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Integer>> entryFuture = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(entryFuture);
+
+    // to set up the entry key1 as cache complete and add some local changes
+    waitAndSet(entryFuture, weightedList(1, 2, 3), 30);
+    multimapState.get(key1).read();
+    multimapState.put(key1, 2);
+    multimapState.put(key3, 20);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+
+    // windmill contains extra entry key2
+    waitAndSet(
+        entriesFuture,
+        weightedList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 4, 5, 6)),
+        30);
+
+    // key1 exist in both cache and windmill; key2 exists only in windmill; key3 exists only in
+    // cache. They should all be merged.
+    Iterable<Map.Entry<byte[], Integer>> entries = multimapState.entries().read();
+
+    assertEquals(8, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 1),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 2),
+            multimapEntryMatcher(key1, 3),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key2, 5),
+            multimapEntryMatcher(key2, 6),
+            multimapEntryMatcher(key3, 20)));
+
+    assertThat(multimapState.keys().read(), Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
+  @Test
+  public void testMultimapCachedPartialEntryCannotCachePolled() throws IOException {

Review Comment:
   Renamed the test.



##########
runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java:
##########
@@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException {
     return output.toByteString();
   }
 
+  // Creating a single-entry future must not touch windmill. Reading it must issue exactly one
+  // KeyedGetDataRequest with the initial byte limits. A response entry without a continuation
+  // position needs no further requests.
+  @Test
+  public void testReadMultimapSingleEntry() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.containsInAnyOrder(5, 6));
+    assertNoReader(future);
+  }
+
+  // A response entry carrying a continuation position must make the reader send a follow-up
+  // request from that position, using the continuation byte limits. This repeats until a page
+  // without a continuation position arrives. Pages are concatenated in order (5..10).
+  @Test
+  public void testReadMultimapSingleEntryPaginated() throws Exception {
+    Future<Iterable<Integer>> future =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    // Page 1: initial limits, no request position; the response continues at 500.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(500)));
+    // Page 2: continuation limits, request position 500; the response continues at 800.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(500)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setContinuationPosition(800)
+                            .setRequestPosition(500)));
+    // Page 3: request position 800; the response has no continuation, so reading stops here.
+    Windmill.KeyedGetDataRequest.Builder expectedRequest3 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response3 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(9), intData(10)))
+                            .setRequestPosition(800)));
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build()))
+        .thenReturn(response1.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build()))
+        .thenReturn(response2.build());
+    Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build()))
+        .thenReturn(response3.build());
+
+    // Only the first page is fetched by get(); later pages are fetched while iterating.
+    Iterable<Integer> results = future.get();
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build());
+    for (Integer unused : results) {
+      // Iterate over the results to force loading all the pages.
+    }
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build());
+    Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build());
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results, Matchers.contains(5, 6, 7, 8, 9, 10));
+    // NOTE: The future will still contain a reference to the underlying reader.
+  }
+
+  // Asserts that two TagMultimapFetchRequests are equal, ignoring the order of their entries
+  // and the order of the values within each entry. The entries are compared first, then every
+  // remaining field of the request.
+  private static void assertMultimapFetchRequestEqual(
+      Windmill.TagMultimapFetchRequest req1, Windmill.TagMultimapFetchRequest req2) {
+    List<Windmill.TagMultimapEntry> entries1 = req1.getEntriesToFetchList();
+    List<Windmill.TagMultimapEntry> entries2 = req2.getEntriesToFetchList();
+    assertMultimapEntriesEqual(entries1, entries2);
+
+    Windmill.TagMultimapFetchRequest req1WithoutEntries =
+        req1.toBuilder().clearEntriesToFetch().build();
+    Windmill.TagMultimapFetchRequest req2WithoutEntries =
+        req2.toBuilder().clearEntriesToFetch().build();
+    assertEquals(req1WithoutEntries, req2WithoutEntries);
+  }
+
+  // Asserts that two lists of TagMultimapEntry hold the same entries, pairing entries by entry
+  // name. Entry order and the order of values within an entry are ignored; every other field of
+  // each paired entry must match exactly. Entry names are expected to be unique in each list.
+  private static void assertMultimapEntriesEqual(
+      List<Windmill.TagMultimapEntry> left, List<Windmill.TagMultimapEntry> right) {
+    // A name duplicated in left collapses into a single map slot below; comparing sizes up front
+    // keeps such a duplicate from going unnoticed.
+    assertEquals("number of multimap entries", left.size(), right.size());
+    Map<ByteString, Windmill.TagMultimapEntry> map = Maps.newHashMap();
+    for (Windmill.TagMultimapEntry entry : left) {
+      map.put(entry.getEntryName(), entry);
+    }
+    for (Windmill.TagMultimapEntry entry : right) {
+      assertTrue(
+          "unexpected multimap entry " + entry.getEntryName(),
+          map.containsKey(entry.getEntryName()));
+      Windmill.TagMultimapEntry that = map.remove(entry.getEntryName());
+      if (entry.getValuesCount() == 0) {
+        assertEquals(0, that.getValuesCount());
+      } else {
+        // Pass the values as an array. Handing the List itself to the varargs
+        // containsInAnyOrder(T...) builds a matcher for a single element equal to the whole
+        // list, so an entry with any values could never match.
+        assertThat(
+            entry.getValuesList(), Matchers.containsInAnyOrder(that.getValuesList().toArray()));
+      }
+      assertEquals(entry.toBuilder().clearValues().build(), that.toBuilder().clearValues().build());
+    }
+    assertTrue("missing multimap entries " + map.keySet(), map.isEmpty());
+  }
+
+  // Two single-entry futures for the same tag and state family must be served by a single
+  // windmill request that fetches both entries; each future gets its own values.
+  @Test
+  public void testReadMultimapMultipleEntries() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6))))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    // Answer any request here; the request actually sent is captured and checked below.
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    // verify() without times() expects exactly one call: both entries share a single request.
+    // NOTE(review): the order-insensitive comparison presumably exists because the reader does
+    // not guarantee the order of entries in a batched request; confirm.
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill)
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest.build().getMultimapsToFetch(0),
+        requestCaptor.getValue().getMultimapsToFetch(0));
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    assertNoReader(future1);
+    assertNoReader(future2);
+  }
+
+  @Test
+  public void testReadMultimapMultipleEntriesWithPagination() throws Exception {
+    Future<Iterable<Integer>> future1 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Future<Iterable<Integer>> future2 =
+        underTest.multimapFetchSingleEntryFuture(
+            STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER);
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    Windmill.KeyedGetDataRequest.Builder expectedRequest1 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build())
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)
+                            .build()));
+
+    Windmill.KeyedGetDataResponse.Builder response1 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(5), intData(6)))
+                            .setContinuationPosition(800))
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_2)
+                            .addAllValues(Arrays.asList(intData(15), intData(16)))));
+    Windmill.KeyedGetDataRequest.Builder expectedRequest2 =
+        Windmill.KeyedGetDataRequest.newBuilder()
+            .setKey(DATA_KEY)
+            .setShardingKey(SHARDING_KEY)
+            .setWorkToken(WORK_TOKEN)
+            .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES)
+            .addMultimapsToFetch(
+                Windmill.TagMultimapFetchRequest.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .setFetchEntryNamesOnly(false)
+                    .addEntriesToFetch(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES)
+                            .setRequestPosition(800)
+                            .build()));
+    Windmill.KeyedGetDataResponse.Builder response2 =
+        Windmill.KeyedGetDataResponse.newBuilder()
+            .setKey(DATA_KEY)
+            .addTagMultimaps(
+                Windmill.TagMultimapFetchResponse.newBuilder()
+                    .setTag(STATE_KEY_1)
+                    .setStateFamily(STATE_FAMILY)
+                    .addEntries(
+                        Windmill.TagMultimapEntry.newBuilder()
+                            .setEntryName(STATE_MULTIMAP_KEY_1)
+                            .addAllValues(Arrays.asList(intData(7), intData(8)))
+                            .setRequestPosition(800)));
+    when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any()))
+        .thenReturn(response1.build())
+        .thenReturn(response2.build());
+
+    Iterable<Integer> results1 = future1.get();
+    Iterable<Integer> results2 = future2.get();
+
+    // Iterate over the results to force loading all the pages.
+    for (Integer unused : results1) {}
+    for (Integer unused : results2) {}
+
+    final ArgumentCaptor<Windmill.KeyedGetDataRequest> requestCaptor =
+        ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class);
+    Mockito.verify(mockWindmill, times(2))
+        .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture());
+    assertMultimapFetchRequestEqual(
+        expectedRequest1.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(0).getMultimapsToFetch(0));
+    assertMultimapFetchRequestEqual(
+        expectedRequest2.build().getMultimapsToFetch(0),
+        requestCaptor.getAllValues().get(1).getMultimapsToFetch(0));
+    Mockito.verifyNoMoreInteractions(mockWindmill);
+
+    assertThat(results1, Matchers.containsInAnyOrder(5, 6, 7, 8));
+    assertThat(results2, Matchers.containsInAnyOrder(15, 16));
+    // NOTE: The future will still contain a reference to the underlying reader.

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org