You are viewing a plain-text version of this content. The canonical link for it was a hyperlink in the original message and did not survive conversion to plain text.
Posted to github@beam.apache.org by "zhengbuqian (via GitHub)" <gi...@apache.org> on 2023/03/24 03:08:51 UTC

[GitHub] [beam] zhengbuqian commented on a diff in pull request #23492: Add Windmill support for MultimapState

zhengbuqian commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1146977349


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Identity of this state cell: namespace/address are the key used for cache.put(), and
+    // stateKey (derived from them in the constructor) is the tag sent to Windmill.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // keyCoder encodes multimap keys and also yields their structural values; valueCoder encodes
+    // the values.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    /** Cached knowledge about a single multimap key, plus its uncommitted local mutations. */
+    private class KeyState {
+      /** The key as supplied by the user; re-encoded with keyCoder when committing. */
+      final K originalKey;
+
+      /** Whether the key is known to exist, known to be absent, or not known. */
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+
+      /**
+       * True only when every value of this key (persisted {@code values} plus
+       * {@code localAdditions}) is held in memory. Starts out true when the multimap is complete.
+       */
+      boolean valuesCached = complete;
+
+      /**
+       * Values already in Windmill. New values go to {@code localAdditions} first and are moved
+       * here only after they have been committed.
+       */
+      ConcatIterables<V> values = new ConcatIterables<>();
+
+      /** How many elements of {@code values} are valid to read. */
+      int valuesSize = 0;
+
+      /**
+       * Values added during processing that still have to be written to Windmill. When
+       * {@code removedLocally} is also set, they are written after the key's existing Windmill
+       * values have been deleted.
+       */
+      List<V> localAdditions = Lists.newArrayList();
+
+      /** Set when the key was removed locally and that deletion has not been committed yet. */
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+      }
+    }
+
+    // What the cache knows about whether a key is present in the multimap.
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent (this includes keys whose local removal has not been
+      // committed yet)
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    // Per-key cache, indexed by each key's structural value so that distinct Java objects with the
+    // same encoded content are treated as the same multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+
+    // When true, every key present in the multimap is cached in keyStateMap as KNOWN_EXIST, so a
+    // key that is not cached that way can be treated as absent.
+    private boolean allKeysKnown = false;
+    // When true, the cached state is the entire multimap and getFuture() skips the Windmill fetch.
+    private boolean complete = false;
+
+    // Local changes that still need to be propagated to Windmill by persistDirectly().
+    private boolean cleared = false;
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      // isNewShardingKey presumably means nothing is persisted yet (TODO confirm), so the empty
+      // cache starts out complete with every key known.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Starts a Windmill read of every entry (keys only when omitValues is set). When the cache is
+    // already complete no read is issued and an empty result is returned immediately.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      return complete
+          ? Futures.immediateFuture(Collections.emptyList())
+          : reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+    }
+
+    // Starts a Windmill read of the values stored under one key; encoding failures are rethrown
+    // unchecked.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKey = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKey, Context.OUTER);
+        ByteString keyBytes = encodedKey.toByteString();
+        return reader.multimapFetchSingleEntryFuture(keyBytes, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    // Returns a lazy view of the values stored under `key`. Cached values and uncommitted local
+    // additions are used when possible; otherwise the key is fetched from Windmill and, if the
+    // result is Weighted, cached.
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          // May insert a placeholder cache entry for the key as a side effect.
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            // Every existing key is cached as KNOWN_EXIST, so this key is absent; drop the
+            // placeholder entry created just above.
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          // Capped at the current size so values appended by later put() calls are not visible
+          // through the returned iterable.
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // This key has been removed locally but the removal hasn't been sent to Windmill yet,
+            // so values in Windmill (if any) are obsolete; only local values matter.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            // Everything is in memory; no Windmill read needed.
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            // Only a Weighted result (a known amount of data) is cached; otherwise the values
+            // are returned without being cached.
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              // Preserve the interrupt status for callers further up the stack.
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Issues the Windmill read early. read() calls getFutureForKey again; presumably the
+        // reader returns the already-pending read (NOTE(review): confirm).
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    // Converts the local changes made since the last commit into a Windmill multimap update,
+    // folds committed additions into the cached values, resets the pending-change flags and puts
+    // this state object back into the cache.
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        // Nothing to write; still re-cache this object and return an empty request.
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        // clear() was called since the last commit: ask Windmill to delete the whole multimap.
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        // Both streams are reused for every key/value via toByteStringAndReset().
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            // No pending change for this key; cached "known absent" entries are unloaded here.
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          // A pending remove() becomes a per-key deleteAll; values added after that removal are
+          // appended to the same entry below.
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If the key's values are not fully cached and it is not
+          // KNOWN_EXIST, remove it from the cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    // Removes every value stored under `key`. The deletion is recorded locally and sent to
+    // Windmill by persistDirectly(), unless the cache proves that Windmill holds nothing for the
+    // key and no earlier deletion is still pending.
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.get(structuralKey);
+      if (keyState == null) {
+        if (allKeysKnown) {
+          // Every existing key is cached, so an uncached key is absent. Do not insert a
+          // placeholder entry: it would linger in the cache indefinitely.
+          return;
+        }
+        keyState = new KeyState(key);
+        keyStateMap.put(structuralKey, keyState);
+      }
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+        return;
+      }
+      if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+        // A stale placeholder for an absent key; unload it.
+        keyStateMap.remove(structuralKey);
+        return;
+      }
+      if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) {
+        // Windmill holds no values for this key and no deletion is pending, so dropping the cache
+        // entry (with any uncommitted additions) is sufficient.
+        keyStateMap.remove(structuralKey);
+      } else {
+        // Windmill may hold values, or an earlier remove() is still uncommitted (e.g. remove,
+        // put, remove): record a deletion so it is sent on the next commit instead of being lost.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      }
+      // Discard uncommitted values. Replace the list rather than clearing it, because iterables
+      // previously returned by get() still reference the old list.
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      // After a clear the (now empty) cache is authoritative; persistDirectly() sends the
+      // multimap-wide deletion on the next commit because `cleared` is set.
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);

Review Comment:
   Fixed and added unit test.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Identity of this state cell: namespace/address are the key used for cache.put(), and
+    // stateKey (derived from them in the constructor) is the tag sent to Windmill.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // keyCoder encodes multimap keys and also yields their structural values; valueCoder encodes
+    // the values.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    /** Cached knowledge about a single multimap key, plus its uncommitted local mutations. */
+    private class KeyState {
+      /** The key as supplied by the user; re-encoded with keyCoder when committing. */
+      final K originalKey;
+
+      /** Whether the key is known to exist, known to be absent, or not known. */
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+
+      /**
+       * True only when every value of this key (persisted {@code values} plus
+       * {@code localAdditions}) is held in memory. Starts out true when the multimap is complete.
+       */
+      boolean valuesCached = complete;
+
+      /**
+       * Values already in Windmill. New values go to {@code localAdditions} first and are moved
+       * here only after they have been committed.
+       */
+      ConcatIterables<V> values = new ConcatIterables<>();
+
+      /** How many elements of {@code values} are valid to read. */
+      int valuesSize = 0;
+
+      /**
+       * Values added during processing that still have to be written to Windmill. When
+       * {@code removedLocally} is also set, they are written after the key's existing Windmill
+       * values have been deleted.
+       */
+      List<V> localAdditions = Lists.newArrayList();
+
+      /** Set when the key was removed locally and that deletion has not been committed yet. */
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+      }
+    }
+
+    // What the cache knows about whether a key is present in the multimap.
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent (this includes keys whose local removal has not been
+      // committed yet)
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    // Per-key cache, indexed by each key's structural value so that distinct Java objects with the
+    // same encoded content are treated as the same multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+
+    // When true, every key present in the multimap is cached in keyStateMap as KNOWN_EXIST, so a
+    // key that is not cached that way can be treated as absent.
+    private boolean allKeysKnown = false;
+    // When true, the cached state is the entire multimap and getFuture() skips the Windmill fetch.
+    private boolean complete = false;
+
+    // Local changes that still need to be propagated to Windmill by persistDirectly().
+    private boolean cleared = false;
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      // isNewShardingKey presumably means nothing is persisted yet (TODO confirm), so the empty
+      // cache starts out complete with every key known.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Starts a Windmill read of every entry (keys only when omitValues is set). When the cache is
+    // already complete no read is issued and an empty result is returned immediately.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      return complete
+          ? Futures.immediateFuture(Collections.emptyList())
+          : reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+    }
+
+    // Starts a Windmill read of the values stored under one key; encoding failures are rethrown
+    // unchecked.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKey = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKey, Context.OUTER);
+        ByteString keyBytes = encodedKey.toByteString();
+        return reader.multimapFetchSingleEntryFuture(keyBytes, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    // Returns a lazy view of the values stored under `key`. Cached values and uncommitted local
+    // additions are used when possible; otherwise the key is fetched from Windmill and, if the
+    // result is Weighted, cached.
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          // May insert a placeholder cache entry for the key as a side effect.
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            // Every existing key is cached as KNOWN_EXIST, so this key is absent; drop the
+            // placeholder entry created just above.
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          // Capped at the current size so values appended by later put() calls are not visible
+          // through the returned iterable.
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // This key has been removed locally but the removal hasn't been sent to Windmill yet,
+            // so values in Windmill (if any) are obsolete; only local values matter.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            // Everything is in memory; no Windmill read needed.
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            // Only a Weighted result (a known amount of data) is cached; otherwise the values
+            // are returned without being cached.
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              // Preserve the interrupt status for callers further up the stack.
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Issues the Windmill read early. read() calls getFutureForKey again; presumably the
+        // reader returns the already-pending read (NOTE(review): confirm).
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    // Converts the local changes made since the last commit into a Windmill multimap update,
+    // folds committed additions into the cached values, resets the pending-change flags and puts
+    // this state object back into the cache.
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        // Nothing to write; still re-cache this object and return an empty request.
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        // clear() was called since the last commit: ask Windmill to delete the whole multimap.
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        // Both streams are reused for every key/value via toByteStringAndReset().
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            // No pending change for this key; cached "known absent" entries are unloaded here.
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          // A pending remove() becomes a per-key deleteAll; values added after that removal are
+          // appended to the same entry below.
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If the key's values are not fully cached and it is not
+          // KNOWN_EXIST, remove it from the cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    // Removes every value stored under `key`. The deletion is recorded locally and sent to
+    // Windmill by persistDirectly(), unless the cache proves that Windmill holds nothing for the
+    // key and no earlier deletion is still pending.
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.get(structuralKey);
+      if (keyState == null) {
+        if (allKeysKnown) {
+          // Every existing key is cached, so an uncached key is absent. Do not insert a
+          // placeholder entry: it would linger in the cache indefinitely.
+          return;
+        }
+        keyState = new KeyState(key);
+        keyStateMap.put(structuralKey, keyState);
+      }
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+        return;
+      }
+      if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+        // A stale placeholder for an absent key; unload it.
+        keyStateMap.remove(structuralKey);
+        return;
+      }
+      if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) {
+        // Windmill holds no values for this key and no deletion is pending, so dropping the cache
+        // entry (with any uncommitted additions) is sufficient.
+        keyStateMap.remove(structuralKey);
+      } else {
+        // Windmill may hold values, or an earlier remove() is still uncommitted (e.g. remove,
+        // put, remove): record a deletion so it is sent on the next commit instead of being lost.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      }
+      // Discard uncommitted values. Replace the list rather than clearing it, because iterables
+      // previously returned by get() still reference the old list.
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      // After a clear the (now empty) cache is authoritative; persistDirectly() sends the
+      // multimap-wide deletion on the next commit because `cleared` is set.
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    // Returns a lazy view of the distinct keys in the multimap. Cached knowledge (including
+    // uncommitted local changes) is merged with the keys stored in Windmill, unless every key is
+    // already known locally.
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        // Structural key -> original key, for every cached key that is known to exist.
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          // Fetch keys only (omitValues = true).
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      // Keys already decided locally (e.g. removed, pending commit) keep their
+                      // cached existence even though Windmill still lists them.
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              // Every existing key is now cached as KNOWN_EXIST, so other entries can be unloaded.
+              // Entries with a pending local removal must be kept: dropping them would lose the
+              // per-key deletion that persistDirectly() still has to send to Windmill.
+              keyStateMap
+                  .values()
+                  .removeIf(
+                      keyState ->
+                          keyState.existence != KeyExistence.KNOWN_EXIST
+                              && !keyState.removedLocally);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              // The result is not Weighted, so it is not cached: snapshot the cached existence
+              // information and stream the remaining keys lazily from Windmill.
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded. Keys whose existence the cache already
+              // decides are skipped, so cached keys are not reported twice and removed keys are
+              // not reported at all.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) {
+                                return null;
+                              }
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              // Preserve the interrupt status for callers further up the stack.
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        /**
+         * Returns every (key, value) pair in the multimap.
+         *
+         * <p>If the multimap is complete, the answer is built from the local cache alone. Otherwise
+         * all entries (including values) are fetched from windmill:
+         *
+         * <ul>
+         *   <li>an empty result means the cache is authoritative, so the multimap is marked
+         *       complete;
+         *   <li>a non-Weighted result is handed to nonWeightedEntries (defined elsewhere) and is not
+         *       cached;
+         *   <li>a Weighted result is merged into keyStateMap, the multimap is marked complete and
+         *       the answer is built from the cache.
+         * </ul>
+         */
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              // Nothing is persisted in windmill: whatever is cached locally is the full content.
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that is known to be absent (e.g. pending a local
+                    // deletion) or whose values are already fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Otherwise cache the contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // The same key may appear in more than one fetched entry (paginated values), so
+                    // values are appended. valuesCached must not be set here: the check above would
+                    // then skip this key's later entries. mergedCachedEntries sets it below, once
+                    // complete is true.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collects the locally cached entries into a map keyed by structural key. Each value is a
+        // Triple of (original key, whether all of the key's values are cached, the cached values:
+        // localAdditions first, then the persisted values if valuesCached). A key is included only
+        // if it has local additions or its values are fully cached (possibly with no values).
+        // If knownNonexistentKeys is non-null, every KNOWN_NONEXISTENT key is also added to it.
+        //
+        // Not side-effect-free:
+        // - if complete, every KNOWN_EXIST key gets valuesCached = true. entries() relies on this
+        //   after caching a Weighted fetch, to avoid iterating keyStateMap one more time;
+        // - keys that are UNKNOWN_EXISTENCE, or KNOWN_NONEXISTENT without a pending local removal,
+        //   are dropped from keyStateMap.
+        private Map<Object, Triple<K, Boolean, ConcatIterables<V>>> mergedCachedEntries(
+            Map<Object, K> knownNonexistentKeys) {
+          Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries = Maps.newHashMap();
+          keyStateMap
+              .entrySet()
+              .removeIf(
+                  (entry -> {
+                    Object structuralKey = entry.getKey();
+                    KeyState keyState = entry.getValue();
+                    if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) {
+                      keyState.valuesCached = true;
+                    }
+                    // Views are bounded by the current sizes so later appends are not observed.
+                    ConcatIterables<V> it = null;
+                    if (!keyState.localAdditions.isEmpty()) {
+                      it = new ConcatIterables<>();
+                      it.extendWith(
+                          Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()));
+                    }
+                    if (keyState.valuesCached) {
+                      if (it == null) it = new ConcatIterables<>();
+                      it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize));
+                    }
+                    if (it != null)
+                      cachedEntries.put(
+                          structuralKey,
+                          Triple.of(keyState.originalKey, keyState.valuesCached, it));
+                    if (knownNonexistentKeys != null
+                        && keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+                      knownNonexistentKeys.put(structuralKey, keyState.originalKey);
+                    // Keep only entries that carry information: existing keys and pending removals.
+                    return (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                            && !keyState.removedLocally)
+                        || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE;
+                  }));
+          return cachedEntries;
+        }
+
+        private Iterable<Entry<K, V>> unnestCachedEntries(
+            Iterable<Entry<Object, Triple<K, Boolean, ConcatIterables<V>>>> cachedEntries) {
+          return Iterables.unmodifiableIterable(
+              () ->

Review Comment:
   Removed; it’s indeed not necessary.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Location of this state: used to derive stateKey and to store this object in the state cache.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag identifying this multimap in windmill reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys sent to windmill; its structural values also index keyStateMap.
+    private final Coder<K> keyCoder;
+    // Encodes values written to windmill; also passed to the reader (presumably to decode values).
+    private final Coder<V> valueCoder;
+
+    // Per-key cache entry: what is known about one multimap key, both locally and in windmill.
+    private class KeyState {
+      // The key as supplied by the user; keyStateMap itself is indexed by its structural value.
+      final K originalKey;
+      // What is known about whether this key is present in the multimap.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // When true, every value of this key (both values and localAdditions) is held in memory, so
+      // reads need not consult windmill. A new state starts with the multimap-wide `complete` flag.
+      boolean valuesCached;
+      // Values already persisted in windmill. Values added during processing stay in
+      // localAdditions until they are persisted; then they are appended here (if valuesCached).
+      ConcatIterables<V> values = new ConcatIterables<>();
+      // Number of elements of `values` that readers may see; reads are bounded by it.
+      int valuesSize = 0;
+
+      // Values added during processing that still have to be written to windmill.
+      List<V> localAdditions = Lists.newArrayList();
+      // True when the key was removed during processing and the deletion has not been sent to
+      // windmill yet. Any localAdditions are then written after the old windmill values are deleted.
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        this.valuesCached = complete;
+      }
+    }
+
+    // Tri-state knowledge about whether a given key is present in this multimap.
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent (possibly because of a local removal that has not been
+      // persisted yet)
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); makes the next persist request a delete-all of this multimap.
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, every existing key is cached in keyStateMap with existence == KNOWN_EXIST, so a key
+    // that is absent from keyStateMap (or UNKNOWN_EXISTENCE) does not exist.
+    private boolean allKeysKnown = false;
+
+    // If true, all keys and all of their values are cached locally; no windmill fetch is needed.
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that need to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // Identity of this state: where it lives in the cache and which windmill tag it maps to.
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+      // Coders for keys and values exchanged with windmill.
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      // NOTE(review): isNewShardingKey presumably means nothing is persisted yet for this sharding
+      // key; the multimap then starts out complete with all keys known, so nothing is fetched.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Starts a backend read of every entry (optionally without values). When the multimap is
+    // already complete there is nothing to fetch, and an already-resolved empty result is returned.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a backend read of the values stored under a single key.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKeyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKeyStream, Context.OUTER);
+        ByteString encodedKey = encodedKeyStream.toByteString();
+        return reader.multimapFetchSingleEntryFuture(
+            encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    /**
+     * Returns a lazily evaluated view of the values stored under {@code key}.
+     *
+     * <p>{@code read()} answers from the local cache when it can: the key is known to be absent,
+     * has a pending local removal, or has fully cached values. Otherwise it fetches the key's
+     * values from windmill, followed by any values added locally but not yet persisted.
+     */
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          // Every existing key is cached when allKeysKnown, so an undetermined key does not exist;
+          // drop the placeholder created above.
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          // Bounded by the current size, so the view does not see values put() later.
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          // Not fully cached: fetch this key's persisted values from windmill.
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() only inspects the first element.
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            // A Weighted result is a known amount of data and is cached in full; any other result
+            // is returned without being cached.
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Starts the single-key fetch early without waiting for it.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Turns pending local changes into a windmill commit request and stores this state in
+     * {@code cache}.
+     *
+     * <p>With no pending changes (no clear, additions or removals), an empty partial request is
+     * returned. Otherwise the request carries a delete-all if the multimap was cleared. It also
+     * carries one update per key that has a pending removal and/or local additions. Local
+     * bookkeeping is then reset so the same changes are not committed twice.
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        // Reused for every key/value; each encoding is taken with toByteStringAndReset().
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          // Nothing to send for this key. A KNOWN_NONEXISTENT key in this state only records an
+          // absence, so it is also dropped from the cache.
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          // A pending local removal deletes the key's existing windmill values; it is now part of
+          // this request and therefore no longer pending.
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    /**
+     * Removes {@code key} and all of its values.
+     *
+     * <p>Sometimes windmill may still hold values for the key, or an earlier removal of the key has
+     * not been persisted yet. In those cases the key stays in {@code keyStateMap} as a pending
+     * deletion ({@code removedLocally}, {@code KNOWN_NONEXISTENT}), so the next persist deletes it
+     * in windmill. Otherwise dropping the cache entry is enough. Values added locally but not yet
+     * persisted are discarded either way.
+     */
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      // When all keys are known, an uncached key cannot exist: the mapping function returns null,
+      // so computeIfAbsent inserts nothing and no placeholder entry is left behind.
+      KeyState keyState =
+          keyStateMap.computeIfAbsent(structuralKey, k -> allKeysKnown ? null : new KeyState(key));
+      if (keyState == null || keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+        return;
+      }
+      if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+        // A leftover placeholder for a key that cannot exist; it carries no pending change.
+        keyStateMap.remove(structuralKey);
+        return;
+      }
+      if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) {
+        // Nothing for this key is persisted in windmill and no earlier removal is waiting to be
+        // sent, so deleting from the local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      } else {
+        // Windmill may hold values for this key, or an earlier removal (e.g. remove, put, remove)
+        // has not been persisted yet. Keep a pending deletion; dropping the entry here would lose
+        // it and let the old windmill values reappear.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      }
+      // Replace rather than clear: reads may still hold Iterables.limit views over the old list.
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    /** Discards all contents; the next persist requests a delete-all of the whole multimap. */
+    @Override
+    public void clear() {
+      // After clearing, the (empty) local view is authoritative: every key is known and nothing
+      // needs to be fetched from windmill.
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a view of the keys currently present in the multimap.
+     *
+     * <p>When every key is known locally, the answer comes from the cache alone. Otherwise the keys
+     * (values omitted) are fetched from windmill. A Weighted result is merged into keyStateMap and
+     * marks all keys as known. Any other result is combined lazily with the cached keys, skipping
+     * keys whose existence has already been decided locally.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        // Cached keys known to exist, keyed by their structural value.
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              // Drop entries that turned out not to exist, but keep keys with a pending local
+              // removal: removedLocally must survive until persistDirectly() sends the deletion,
+              // otherwise the removed values would reappear from windmill.
+              keyStateMap
+                  .values()
+                  .removeIf(
+                      keyState ->
+                          keyState.existence != KeyExistence.KNOWN_EXIST
+                              && !keyState.removedLocally);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              // Keys whose existence is decided locally override whatever windmill reports.
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        // Starts the keys-only fetch early without waiting for it.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more

Review Comment:
   That doesn’t seem to be the case. In [`consumeBag()`](https://github.com/apache/beam/blob/ddae966f3346fbe247486324cbf8a8a532895316/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java#L781), [`this.bagPageValues()`](https://github.com/apache/beam/blob/ddae966f3346fbe247486324cbf8a8a532895316/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java#L720) always returns a weighted list, regardless of `shouldRemove`. So a paginated response can also be weighted.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Location of this state: used to derive stateKey and to store this object in the state cache.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag identifying this multimap in windmill reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys sent to windmill; its structural values also index keyStateMap.
+    private final Coder<K> keyCoder;
+    // Encodes values written to windmill; also passed to the reader (presumably to decode values).
+    private final Coder<V> valueCoder;
+
+    // Per-key cache entry: what is known about one multimap key, both locally and in windmill.
+    private class KeyState {
+      // The key as supplied by the user; keyStateMap itself is indexed by its structural value.
+      final K originalKey;
+      // What is known about whether this key is present in the multimap.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // When true, every value of this key (both values and localAdditions) is held in memory, so
+      // reads need not consult windmill. A new state starts with the multimap-wide `complete` flag.
+      boolean valuesCached;
+      // Values already persisted in windmill. Values added during processing stay in
+      // localAdditions until they are persisted; then they are appended here (if valuesCached).
+      ConcatIterables<V> values = new ConcatIterables<>();
+      // Number of elements of `values` that readers may see; reads are bounded by it.
+      int valuesSize = 0;
+
+      // Values added during processing that still have to be written to windmill.
+      List<V> localAdditions = Lists.newArrayList();
+      // True when the key was removed during processing and the deletion has not been sent to
+      // windmill yet. Any localAdditions are then written after the old windmill values are deleted.
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        this.valuesCached = complete;
+      }
+    }
+
+    // Tri-state knowledge about whether a given key is present in this multimap.
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent (possibly because of a local removal that has not been
+      // persisted yet)
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); makes the next persist request a delete-all of this multimap.
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, every existing key is cached in keyStateMap with existence == KNOWN_EXIST, so a key
+    // that is absent from keyStateMap (or UNKNOWN_EXISTENCE) does not exist.
+    private boolean allKeysKnown = false;
+
+    // If true, all keys and all of their values are cached locally; no windmill fetch is needed.
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that need to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // Identity of this state: where it lives in the cache and which windmill tag it maps to.
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+      // Coders for keys and values exchanged with windmill.
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      // NOTE(review): isNewShardingKey presumably means nothing is persisted yet for this sharding
+      // key; the multimap then starts out complete with all keys known, so nothing is fetched.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Starts a backend read of every entry (optionally without values). When the multimap is
+    // already complete there is nothing to fetch, and an already-resolved empty result is returned.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a backend read of the values stored under a single key.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKeyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKeyStream, Context.OUTER);
+        ByteString encodedKey = encodedKeyStream.toByteString();
+        return reader.multimapFetchSingleEntryFuture(
+            encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {

Review Comment:
   Done for the second comment.
   
   Regarding `a key that has UNKNOWN_EXISTENCE but a pending delete`: `remove()` sets such a key to `KNOWN_NONEXISTENT`, so it never stays `UNKNOWN_EXISTENCE`. `testMultimapRemoveAndPersist` covers this.
   
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Location of this state: used to derive stateKey and to store this object in the state cache.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag identifying this multimap in windmill reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys sent to windmill; its structural values also index keyStateMap.
+    private final Coder<K> keyCoder;
+    // Encodes values written to windmill; also passed to the reader (presumably to decode values).
+    private final Coder<V> valueCoder;
+
+    // Per-key cache entry: what is known about one multimap key, both locally and in windmill.
+    private class KeyState {
+      // The key as supplied by the user; keyStateMap itself is indexed by its structural value.
+      final K originalKey;
+      // What is known about whether this key is present in the multimap.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // When true, every value of this key (both values and localAdditions) is held in memory, so
+      // reads need not consult windmill. A new state starts with the multimap-wide `complete` flag.
+      boolean valuesCached;
+      // Values already persisted in windmill. Values added during processing stay in
+      // localAdditions until they are persisted; then they are appended here (if valuesCached).
+      ConcatIterables<V> values = new ConcatIterables<>();
+      // Number of elements of `values` that readers may see; reads are bounded by it.
+      int valuesSize = 0;
+
+      // Values added during processing that still have to be written to windmill.
+      List<V> localAdditions = Lists.newArrayList();
+      // True when the key was removed during processing and the deletion has not been sent to
+      // windmill yet. Any localAdditions are then written after the old windmill values are deleted.
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        this.valuesCached = complete;
+      }
+    }
+
+    // Tri-state knowledge about whether a given key is present in this multimap.
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent (possibly because of a local removal that has not been
+      // persisted yet)
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); makes the next persist request a delete-all of this multimap.
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, every existing key is cached in keyStateMap with existence == KNOWN_EXIST, so a key
+    // that is absent from keyStateMap (or UNKNOWN_EXISTENCE) does not exist.
+    private boolean allKeysKnown = false;
+
+    // If true, all keys and all of their values are cached locally; no windmill fetch is needed.
+    private boolean complete = false;

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);

Review Comment:
   `removedLocally` can be false here when this key has only local additions and no removal.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more
+                    // paginated values that should not be filtered out in above if statement.
+                    // keyState.valuesCached will be set to true in later call of
+                    // mergedCachedEntries.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to
+        // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it
+        // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if
+        // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries()
+        // depends on this behavior when the fetched result is weighted to iterate the whole
+        // keyStateMap one less time.
+        private Map<Object, Triple<K, Boolean, ConcatIterables<V>>> mergedCachedEntries(
+            Map<Object, K> knownNonexistentKeys) {
+          Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries = Maps.newHashMap();
+          keyStateMap
+              .entrySet()
+              .removeIf(
+                  (entry -> {
+                    Object structuralKey = entry.getKey();
+                    KeyState keyState = entry.getValue();
+                    if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) {
+                      keyState.valuesCached = true;
+                    }
+                    ConcatIterables<V> it = null;
+                    if (!keyState.localAdditions.isEmpty()) {
+                      it = new ConcatIterables<>();
+                      it.extendWith(
+                          Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()));
+                    }
+                    if (keyState.valuesCached) {
+                      if (it == null) it = new ConcatIterables<>();
+                      it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize));
+                    }
+                    if (it != null)
+                      cachedEntries.put(
+                          structuralKey,
+                          Triple.of(keyState.originalKey, keyState.valuesCached, it));
+                    if (knownNonexistentKeys != null
+                        && keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+                      knownNonexistentKeys.put(structuralKey, keyState.originalKey);
+                    return (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                            && !keyState.removedLocally)
+                        || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE;
+                  }));
+          return cachedEntries;
+        }
+
+        private Iterable<Entry<K, V>> unnestCachedEntries(
+            Iterable<Entry<Object, Triple<K, Boolean, ConcatIterables<V>>>> cachedEntries) {
+          return Iterables.unmodifiableIterable(
+              () ->
+                  Iterators.concat(
+                      Iterables.transform(
+                              cachedEntries,
+                              entry ->
+                                  Iterables.transform(
+                                          entry.getValue().getRight(),
+                                          v ->
+                                              new AbstractMap.SimpleEntry<>(
+                                                  entry.getValue().getLeft(), v))
+                                      .iterator())
+                          .iterator()));
+        }
+
+        private Iterable<Entry<K, V>> nonWeightedEntries(
+            Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries) {
+          class ResultIterable implements Iterable<Entry<K, V>> {
+            private final Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries;
+            private final Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries;
+            private final Map<Object, K> knownNonexistentKeys;
+
+            ResultIterable(
+                Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries,
+                Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries,
+                Map<Object, K> knownNonexistentKeys) {
+              this.cachedEntries = cachedEntries;
+              this.lazyWindmillEntries = lazyWindmillEntries;
+              this.knownNonexistentKeys = knownNonexistentKeys;
+            }
+
+            @Override
+            public Iterator<Entry<K, V>> iterator() {
+              // Each time when the Iterable returned by entries() is iterated, a new Iterator is
+              // created. Every iterator must keep its own copy of seenCachedKeys so that if a key
+              // is paginated into multiple iterables from windmill, the cached values of this key
+              // will only be returned once.
+              Set<Object> seenCachedKeys = Sets.newHashSet();
+              // notFullyCachedEntries returns all entries from windmill that are not fully cached
+              // and combines them with localAdditions. If a key is fully cached, contents of this
+              // key from windmill are ignored.
+              Iterable<Triple<Object, K, Iterable<V>>> notFullyCachedEntries =
+                  Iterables.filter(
+                      Iterables.transform(
+                          lazyWindmillEntries,
+                          entry -> {
+                            try {
+                              final K key =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              final Object structuralKey = keyCoder.structuralValue(key);
+                              // key is deleted in cache thus fully cached.
+                              if (knownNonexistentKeys.containsKey(structuralKey)) return null;
+                              Triple<K, Boolean, ConcatIterables<V>> triple =
+                                  cachedEntries.get(structuralKey);
+                              // no record of key in cache, return content in windmill.
+                              if (triple == null) {
+                                return Triple.of(structuralKey, key, entry.getValue());
+                              }
+                              // key is fully cached in cache.
+                              if (triple.getMiddle()) return null;
+
+                              // key is not fully cached, combine the content with local additions

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Identity of this state cell: the namespace and tag it was created for, plus the encoded
+    // Windmill tag derived from them.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    // Windmill state family the tag belongs to.
+    private final String stateFamily;
+    // Coders used to encode multimap keys and values for Windmill.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      // The key exactly as supplied by the user; it is re-encoded with keyCoder when committing.
+      final K originalKey;
+      // What is currently known about whether this key is present in the multimap.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // When true, values (plus localAdditions) hold every value of this key, so no Windmill read
+      // is needed. It is seeded from the multimap-wide `complete` flag and is also set after a
+      // local removal, when the key is known to have no values at all.
+      boolean valuesCached = complete;
+      // Values believed to be persisted in Windmill. Values put during user processing go to
+      // localAdditions first and are appended here only after they have been committed. Only the
+      // first valuesSize elements are considered part of the view.
+      ConcatIterables<V> values = new ConcatIterables<>();
+      int valuesSize = 0;
+      // Uncommitted changes. localAdditions holds values put since the last commit; removedLocally
+      // records that the key was removed and its persisted values must be deleted in Windmill. If
+      // both are present, the deletion is applied first and then the additions are written.
+      List<V> localAdditions = Lists.newArrayList();
+      boolean removedLocally = false;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+      }
+    }
+
+    /** Locally cached knowledge about whether a key is present in the multimap. */
+    private enum KeyExistence {
+      /** The key is known to be present. */
+      KNOWN_EXIST,
+      /** The key is known to be absent. */
+      KNOWN_NONEXISTENT,
+      /**
+       * Presence is not known; the entry only records the mapping from the structural key to the
+       * original key.
+       */
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); makes the next commit delete every persisted entry of this multimap.
+    private boolean cleared = false;
+    // Keyed by the structural value of each key, so that distinct Java objects with the same
+    // content are treated as the same multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // When true, every existing key has a KNOWN_EXIST entry in keyStateMap, so a key without such
+    // an entry is known to be absent.
+    private boolean allKeysKnown = false;
+    // When true, all keys and all their values are held locally and Windmill is not read.
+    private boolean complete = false;
+    // Whether there are local puts / removals that still need to be propagated to Windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // Presumably a new sharding key has no persisted state, so the (empty) local view is
+      // treated as complete with every key known — confirm against the caller.
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Returns a future over all persisted entries. When the local view is already complete,
+    // Windmill is not consulted and an already-completed future over an empty list is returned.
+    // omitValues is forwarded to the reader (keys() passes true, entries() passes false).
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a Windmill read of the values stored under a single key. The key is encoded with
+    // keyCoder in the OUTER context; encoding failures are rethrown unchecked.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encoded = new ByteStringOutputStream();
+        keyCoder.encode(key, encoded, Context.OUTER);
+        ByteString encodedKey = encoded.toByteString();
+        return reader.multimapFetchSingleEntryFuture(encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    // Returns a lazy view of the values stored under `key`. Only the structural key is computed
+    // here; read() consults the local cache and, if needed, Windmill. readLater() only starts the
+    // single-key Windmill fetch.
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          // Look up the cached state of this key, creating an UNKNOWN_EXISTENCE placeholder.
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          // Every existing key is cached as KNOWN_EXIST, so an UNKNOWN_EXISTENCE entry means the
+          // key is absent; drop the placeholder.
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          // Limit to the current size so values put after this call are not exposed through the
+          // returned iterable.
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          // All values are available locally: no Windmill read.
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          // Otherwise fetch the persisted values of this key and append the local additions.
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            // A Weighted result is cached as the complete set of persisted values for this key
+            // (assumes the reader returns Weighted only for fully materialized results — confirm).
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            // Restore the interrupt flag before wrapping.
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Starts the single-key fetch early; the returned future is intentionally ignored here.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    // Builds the Windmill commit for changes made since the last commit and re-registers this
+    // state object in the cache. Afterwards the cleared / hasLocalAdditions / hasLocalRemovals
+    // flags are reset and committed additions are folded into fully cached values.
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      // No pending changes: only refresh the cache entry and return a request with nothing set.
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        // NOTE(review): the cache weight is always 1 regardless of how many values are cached —
+        // confirm this is intended.
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      // clear() was called: delete every persisted entry of this multimap.
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        // The two streams are reused for every key/value via toByteStringAndReset().
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          // No pending change for this key: nothing to send. Keys known to be absent are also
+          // evicted from the cache.
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          // A pending local removal becomes a per-key deleteAll; the additions below are sent in
+          // the same entry update.
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values, as they are now also
+          // persisted in Windmill. This is only done when the key's values are fully cached;
+          // otherwise keyState.values is not a complete view and is left untouched.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          // Evict keys whose values are not fully cached, unless they are known to exist.
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    // Removes `key` and all of its values. If Windmill may hold values for the key, or an earlier
+    // removal of it has not been committed yet, the key is kept as KNOWN_NONEXISTENT with
+    // removedLocally set so that the next commit sends a per-key deleteAll.
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      // When every key is known, an absent key needs no placeholder: a null result from the
+      // mapping function leaves keyStateMap unchanged.
+      KeyState keyState =
+          keyStateMap.computeIfAbsent(structuralKey, k -> allKeysKnown ? null : new KeyState(key));
+      if (keyState == null
+          || keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        // Already known to be absent: nothing to do.
+        return;
+      }
+      if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) {
+        // Nothing persisted in Windmill and no removal pending: forgetting the key locally is
+        // sufficient.
+        keyStateMap.remove(structuralKey);
+      } else {
+        // Windmill may hold values for this key, or a previous removal is still pending (e.g.
+        // remove -> put -> remove before a commit). Keep the entry so that persistDirectly
+        // still emits the deleteAll; dropping it would let the old Windmill values reappear.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      }
+      // Discard uncommitted puts. A fresh list is used so iterables previously returned by
+      // get().read() (which view the old list) are unaffected.
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    // Forgets every locally known key and treats the multimap as fully known and empty; the next
+    // commit deletes all persisted entries.
+    @Override
+    public void clear() {
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    // Returns a lazy view of the distinct keys of this multimap, combining the locally cached
+    // keys with the keys persisted in Windmill.
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        // Locally cached keys that are known to exist, indexed by structural value.
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            // Every existing key is already cached; no Windmill read is needed.
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      // Keys already marked KNOWN_EXIST or KNOWN_NONEXISTENT locally keep that
+                      // state (e.g. a removal that has not been committed yet).
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              // Evict entries not known to exist, but keep keys whose removal has not been
+              // committed yet: evicting them would drop the per-key deleteAll that
+              // persistDirectly sends, leaving the removed values in Windmill.
+              keyStateMap
+                  .values()
+                  .removeIf(
+                      keyState ->
+                          keyState.existence != KeyExistence.KNOWN_EXIST
+                              && !keyState.removedLocally);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              // Not a Weighted result: snapshot the local verdicts, then consume Windmill's keys
+              // lazily, skipping any key that already has a local verdict.
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        // Starts the keys-only fetch early; the returned future is intentionally ignored here.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more
+                    // paginated values that should not be filtered out in above if statement.
+                    // keyState.valuesCached will be set to true in later call of
+                    // mergedCachedEntries.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to
+        // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it
+        // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if
+        // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries()
+        // depends on this behavior when the fetched result is weighted to iterate the whole
+        // keyStateMap one less time.
+        private Map<Object, Triple<K, Boolean, ConcatIterables<V>>> mergedCachedEntries(
+            Map<Object, K> knownNonexistentKeys) {
+          Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries = Maps.newHashMap();
+          keyStateMap
+              .entrySet()
+              .removeIf(
+                  (entry -> {
+                    Object structuralKey = entry.getKey();
+                    KeyState keyState = entry.getValue();
+                    if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) {
+                      keyState.valuesCached = true;
+                    }
+                    ConcatIterables<V> it = null;
+                    if (!keyState.localAdditions.isEmpty()) {
+                      it = new ConcatIterables<>();
+                      it.extendWith(
+                          Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()));
+                    }
+                    if (keyState.valuesCached) {
+                      if (it == null) it = new ConcatIterables<>();
+                      it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize));
+                    }
+                    if (it != null)
+                      cachedEntries.put(
+                          structuralKey,
+                          Triple.of(keyState.originalKey, keyState.valuesCached, it));
+                    if (knownNonexistentKeys != null
+                        && keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+                      knownNonexistentKeys.put(structuralKey, keyState.originalKey);
+                    return (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                            && !keyState.removedLocally)
+                        || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE;
+                  }));
+          return cachedEntries;
+        }
+
+        private Iterable<Entry<K, V>> unnestCachedEntries(
+            Iterable<Entry<Object, Triple<K, Boolean, ConcatIterables<V>>>> cachedEntries) {
+          return Iterables.unmodifiableIterable(
+              () ->
+                  Iterators.concat(
+                      Iterables.transform(
+                              cachedEntries,
+                              entry ->
+                                  Iterables.transform(
+                                          entry.getValue().getRight(),
+                                          v ->
+                                              new AbstractMap.SimpleEntry<>(
+                                                  entry.getValue().getLeft(), v))
+                                      .iterator())
+                          .iterator()));
+        }
+
+        private Iterable<Entry<K, V>> nonWeightedEntries(
+            Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries) {
+          class ResultIterable implements Iterable<Entry<K, V>> {
+            private final Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries;
+            private final Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries;
+            private final Map<Object, K> knownNonexistentKeys;
+
+            ResultIterable(
+                Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries,
+                Iterable<Entry<ByteString, Iterable<V>>> lazyWindmillEntries,
+                Map<Object, K> knownNonexistentKeys) {
+              this.cachedEntries = cachedEntries;
+              this.lazyWindmillEntries = lazyWindmillEntries;
+              this.knownNonexistentKeys = knownNonexistentKeys;
+            }
+
+            @Override
+            public Iterator<Entry<K, V>> iterator() {
+              // Each time when the Iterable returned by entries() is iterated, a new Iterator is
+              // created. Every iterator must keep its own copy of seenCachedKeys so that if a key
+              // is paginated into multiple iterables from windmill, the cached values of this key
+              // will only be returned once.
+              Set<Object> seenCachedKeys = Sets.newHashSet();
+              // notFullyCachedEntries returns all entries from windmill that are not fully cached
+              // and combines them with localAdditions. If a key is fully cached, contents of this
+              // key from windmill are ignored.
+              Iterable<Triple<Object, K, Iterable<V>>> notFullyCachedEntries =
+                  Iterables.filter(
+                      Iterables.transform(
+                          lazyWindmillEntries,
+                          entry -> {
+                            try {
+                              final K key =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              final Object structuralKey = keyCoder.structuralValue(key);
+                              // key is deleted in cache thus fully cached.
+                              if (knownNonexistentKeys.containsKey(structuralKey)) return null;
+                              Triple<K, Boolean, ConcatIterables<V>> triple =
+                                  cachedEntries.get(structuralKey);
+                              // no record of key in cache, return content in windmill.
+                              if (triple == null) {
+                                return Triple.of(structuralKey, key, entry.getValue());
+                              }
+                              // key is fully cached in cache.
+                              if (triple.getMiddle()) return null;
+
+                              // key is not fully cached, combine the content with local additions
+                              if (seenCachedKeys.contains(structuralKey)) {
+                                return Triple.of(structuralKey, key, entry.getValue());
+                              } else {
+                                seenCachedKeys.add(structuralKey);

Review Comment:
   Done for the first comment. `seenCachedKeys` can't be removed: `cachedEntries` belongs to the returned Iterable, but `seenCachedKeys` has to be tracked per Iterator. Each time the returned Iterable is iterated, a new Iterator is created, so every Iterator needs its own `seenCachedKeys` to track its progress. We also can't modify `cachedEntries` itself, because later iterations still need its contents.
   
   ```java
   Iterable<> entries = multimap.entries().read();
   for (entry : entries) {
     // if we modify cachedEntries here to clear the local additions
   }
   for (entry : entries) {
     // we can't see local additions here.
   }
   ```



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist

Review Comment:
   Added comment. It acts like a cache of `map.containsKey()`:
   
   * KNOWN_EXIST: this key has at least 1 value across local state and Windmill combined.
   * KNOWN_NONEXISTENT: this key has exactly 0 values across local state and Windmill combined.
   * UNKNOWN_EXISTENCE: this key has exactly 0 values locally, but may have any number of values (including 0) in Windmill.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.

Review Comment:
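   With the if condition simplified to `if (!keyState.valuesCached || keyState.valuesSize > 0)`, the else branch is reached exactly when `keyState.valuesCached && keyState.valuesSize <= 0`. In other words, the else branch covers the case where the key's values are cached and known to be empty in Windmill, so deleting the key from the local cache is sufficient.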



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more
+                    // paginated values that should not be filtered out in above if statement.
+                    // keyState.valuesCached will be set to true in later call of
+                    // mergedCachedEntries.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to
+        // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it
+        // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if
+        // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries()
+        // depends on this behavior when the fetched result is weighted to iterate the whole
+        // keyStateMap one less time.
+        private Map<Object, Triple<K, Boolean, ConcatIterables<V>>> mergedCachedEntries(
+            Map<Object, K> knownNonexistentKeys) {
+          Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries = Maps.newHashMap();
+          keyStateMap
+              .entrySet()
+              .removeIf(
+                  (entry -> {
+                    Object structuralKey = entry.getKey();
+                    KeyState keyState = entry.getValue();
+                    if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) {
+                      keyState.valuesCached = true;
+                    }
+                    ConcatIterables<V> it = null;
+                    if (!keyState.localAdditions.isEmpty()) {
+                      it = new ConcatIterables<>();
+                      it.extendWith(
+                          Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()));
+                    }
+                    if (keyState.valuesCached) {
+                      if (it == null) it = new ConcatIterables<>();
+                      it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize));
+                    }
+                    if (it != null)
+                      cachedEntries.put(
+                          structuralKey,
+                          Triple.of(keyState.originalKey, keyState.valuesCached, it));
+                    if (knownNonexistentKeys != null
+                        && keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+                      knownNonexistentKeys.put(structuralKey, keyState.originalKey);

Review Comment:
   Yes, done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) return null;
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more
+                    // paginated values that should not be filtered out in above if statement.
+                    // keyState.valuesCached will be set to true in later call of
+                    // mergedCachedEntries.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to
+        // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it
+        // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if
+        // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries()
+        // depends on this behavior when the fetched result is weighted to iterate the whole
+        // keyStateMap one less time.

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1598,648 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached (both values and localAdditions).
+      boolean valuesCached;
+      // Represents the values in windmill. When new values are added during user processing, they
+      // are added to localAdditions but not values. Those new values will be added to values only
+      // after they are persisted into windmill and removed from localAdditions
+      ConcatIterables<V> values;
+      int valuesSize;
+
+      // When new values are added during user processing, they are added to localAdditions, so that
+      // we can later try to persist them in windmill. When a key is removed during user processing,
+      // we mark removedLocally to be true so that we can later try to delete it from windmill. If
+      // localAdditions is not empty and removedLocally is true, values in localAdditions will be
+      // added to windmill after old values in windmill are removed.
+      List<V> localAdditions;
+      boolean removedLocally;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = complete;
+        values = new ConcatIterables<>();
+        valuesSize = 0;
+        localAdditions = Lists.newArrayList();
+        removedLocally = false;
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be
+    // propagated to windmill.
+    private boolean hasLocalAdditions = false;
+    private boolean hasLocalRemovals = false;
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      hasLocalAdditions = true;
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            v.localAdditions.add(value);
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream, Context.OUTER);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+          Iterable<V> localNewValues =
+              Iterables.limit(keyState.localAdditions, keyState.localAdditions.size());
+          if (keyState.removedLocally) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localNewValues);
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    Iterables.limit(keyState.values, keyState.valuesSize), localNewValues));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            // Iterables.isEmpty() is O(1).
+            if (Iterables.isEmpty(persistedValues)) {
+              if (keyState.localAdditions.isEmpty()) {
+                // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT.
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(localNewValues);
+            }
+            keyState.existence = KeyExistence.KNOWN_EXIST;
+            if (persistedValues instanceof Weighted) {
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+              keyState.valuesSize = Iterables.size(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localNewValues));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && !hasLocalAdditions && !hasLocalRemovals) {
+        cache.put(namespace, address, this, 1);
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder();
+      builder.setTag(stateKey).setStateFamily(stateFamily);
+
+      if (cleared) {
+        builder.setDeleteAll(true);
+      }
+      if (hasLocalRemovals || hasLocalAdditions) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        ByteStringOutputStream valueStream = new ByteStringOutputStream();
+        Iterator<Entry<Object, KeyState>> iterator = keyStateMap.entrySet().iterator();
+        while (iterator.hasNext()) {
+          KeyState keyState = iterator.next().getValue();
+          if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) {
+            if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove();
+            continue;
+          }
+          keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER);
+          ByteString encodedKey = keyStream.toByteStringAndReset();
+          Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+          entryBuilder.setEntryName(encodedKey);
+          entryBuilder.setDeleteAll(keyState.removedLocally);
+          keyState.removedLocally = false;
+          for (V value : keyState.localAdditions) {
+            valueCoder.encode(value, valueStream, Context.OUTER);
+            ByteString encodedValue = valueStream.toByteStringAndReset();
+            entryBuilder.addValues(encodedValue);
+          }
+          // Move newly added values from localAdditions to keyState.values as those new values now
+          // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+          // remove it from cache.
+          if (keyState.valuesCached) {
+            keyState.values.extendWith(keyState.localAdditions);
+            keyState.valuesSize += keyState.localAdditions.size();
+          }
+          // Create a new localAdditions so that the cached values are unaffected.
+          keyState.localAdditions = Lists.newArrayList();
+          if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) {
+            iterator.remove();
+          }
+        }
+      }
+
+      hasLocalAdditions = false;
+      hasLocalRemovals = false;
+      cleared = false;
+
+      cache.put(namespace, address, this, 1);
+      return commitBuilder.buildPartial();
+    }
+
+    /**
+     * Removes {@code key} and all of its values from this multimap.
+     *
+     * <p>If Windmill may still hold values for the key, the key is kept in {@code keyStateMap} as
+     * {@code KNOWN_NONEXISTENT} with {@code removedLocally} set, so the removal can be committed
+     * later. If the key's values are fully cached and none are persisted, the key is simply dropped
+     * from the local cache. Pending local additions for the key are discarded in both cases.
+     */
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      // When all keys are known, a key absent from the cache cannot exist in Windmill, so do not
+      // insert a throwaway KeyState for it: computeIfAbsent records no mapping when the mapping
+      // function returns null.
+      KeyState keyState =
+          keyStateMap.computeIfAbsent(structuralKey, k -> allKeysKnown ? null : new KeyState(key));
+      if (keyState == null
+          || keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (!keyState.valuesCached || keyState.valuesSize > 0) {
+        // there may be data in windmill that need to be removed.
+        hasLocalRemovals = true;
+        keyState.removedLocally = true;
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesSize = 0;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      if (!keyState.localAdditions.isEmpty()) {
+        keyState.localAdditions = Lists.newArrayList();
+      }
+      keyState.valuesCached = true;
+    }
+
+    /**
+     * Drops every cached key and marks the multimap as cleared.
+     *
+     * <p>The now-empty local cache is flagged as both complete and knowing all keys, so subsequent
+     * {@code keys()} and {@code entries()} reads are answered from the cache alone.
+     */
+    @Override
+    public void clear() {
+      allKeysKnown = true;
+      complete = true;
+      cleared = true;
+      keyStateMap = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a {@link ReadableState} over the keys of this multimap that currently have values.
+     *
+     * <p>When all keys are already known, the result is served from the local cache. Otherwise the
+     * persisted keys are fetched from Windmill and merged with local state: keys cached as {@code
+     * KNOWN_EXIST} are always returned and keys cached as {@code KNOWN_NONEXISTENT} are hidden.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+
+        // Snapshot of the cached keys known to exist, keyed by structural key.
+        private Map<Object, K> cachedExistKeys() {
+          return keyStateMap.entrySet().stream()
+              .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST)
+              .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey));
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistKeys().values());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entries.forEach(
+                  entry -> {
+                    try {
+                      K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                      KeyState keyState =
+                          keyStateMap.computeIfAbsent(
+                              keyCoder.structuralValue(originalKey),
+                              stk -> new KeyState(originalKey));
+                      // Only upgrade keys we know nothing about; local knowledge (e.g. a local
+                      // removal) takes precedence over what Windmill reports.
+                      if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+                        keyState.existence = KeyExistence.KNOWN_EXIST;
+                      }
+                    } catch (IOException e) {
+                      throw new RuntimeException(e);
+                    }
+                  });
+              allKeysKnown = true;
+              // Unload every key not known to exist, but keep keys pending deletion
+              // (removedLocally), as mergedCachedEntries() does: evicting them would discard a
+              // removal that still has to be persisted to Windmill.
+              keyStateMap
+                  .values()
+                  .removeIf(
+                      keyState ->
+                          keyState.existence != KeyExistence.KNOWN_EXIST
+                              && !keyState.removedLocally);
+              return Iterables.unmodifiableIterable(cachedExistKeys().values());
+            } else {
+              // The result may not be fully loaded: combine what the cache knows with a lazy view
+              // of the keys Windmill reports.
+              Map<Object, K> cachedExistKeys = Maps.newHashMap();
+              Set<Object> cachedNonExistKeys = Sets.newHashSet();
+              keyStateMap.forEach(
+                  (structuralKey, keyState) -> {
+                    switch (keyState.existence) {
+                      case KNOWN_EXIST:
+                        cachedExistKeys.put(structuralKey, keyState.originalKey);
+                        break;
+                      case KNOWN_NONEXISTENT:
+                        cachedNonExistKeys.add(structuralKey);
+                        break;
+                      default:
+                        break;
+                    }
+                  });
+              // keysOnlyInWindmill is lazily loaded; keys the cache already decided on (present
+              // or absent) are filtered out so they are neither duplicated nor resurrected.
+              Iterable<K> keysOnlyInWindmill =
+                  Iterables.filter(
+                      Iterables.transform(
+                          entries,
+                          entry -> {
+                            try {
+                              K originalKey =
+                                  keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                              Object structuralKey = keyCoder.structuralValue(originalKey);
+                              if (cachedExistKeys.containsKey(structuralKey)
+                                  || cachedNonExistKeys.contains(structuralKey)) {
+                                return null;
+                              }
+                              return originalKey;
+                            } catch (IOException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }),
+                      Objects::nonNull);
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          // Start the fetch now; read() picks up the same future.
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            if (Iterables.isEmpty(entries)) {
+              complete = true;
+              allKeysKnown = true;
+              return Iterables.unmodifiableIterable(
+                  unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+            }
+            if (!(entries instanceof Weighted)) {
+              return nonWeightedEntries(entries);
+            }
+            // This is a known amount of data, cache them all.
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER);
+                    final Object structuralKey = keyCoder.structuralValue(originalKey);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey));
+                    // Ignore any key from windmill that has been marked pending deletion or is
+                    // fully cached.
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+                        || (keyState.existence == KeyExistence.KNOWN_EXIST
+                            && keyState.valuesCached)) return;
+                    // Or else cache contents from windmill.
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                    keyState.values.extendWith(entry.getValue());
+                    keyState.valuesSize += Iterables.size(entry.getValue());
+                    // We can't set keyState.valuesCached to true here, because there may be more
+                    // paginated values that should not be filtered out in above if statement.
+                    // keyState.valuesCached will be set to true in later call of
+                    // mergedCachedEntries.
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            allKeysKnown = true;
+            complete = true;
+            return Iterables.unmodifiableIterable(
+                unnestCachedEntries(mergedCachedEntries(null).entrySet()));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+
+        // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to
+        // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it
+        // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if
+        // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries()
+        // depends on this behavior when the fetched result is weighted to iterate the whole
+        // keyStateMap one less time.
+        private Map<Object, Triple<K, Boolean, ConcatIterables<V>>> mergedCachedEntries(
+            Map<Object, K> knownNonexistentKeys) {
+          Map<Object, Triple<K, Boolean, ConcatIterables<V>>> cachedEntries = Maps.newHashMap();
+          keyStateMap
+              .entrySet()
+              .removeIf(
+                  (entry -> {
+                    Object structuralKey = entry.getKey();
+                    KeyState keyState = entry.getValue();
+                    if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) {
+                      keyState.valuesCached = true;
+                    }
+                    ConcatIterables<V> it = null;
+                    if (!keyState.localAdditions.isEmpty()) {
+                      it = new ConcatIterables<>();
+                      it.extendWith(
+                          Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()));
+                    }
+                    if (keyState.valuesCached) {
+                      if (it == null) it = new ConcatIterables<>();
+                      it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize));
+                    }
+                    if (it != null)

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org