You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2023/01/11 14:50:49 UTC

[GitHub] [beam] scwhittle commented on a diff in pull request #23492: Add Windmill support for MultimapState

scwhittle commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1067051289


##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to

Review Comment:
   nit: "When new values are added during user processing, they ..."
   This makes it a little clearer, separating them from values added by a persistent-state read.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).

Review Comment:
   nit: add a space after "cached", i.e. "cached (both KeyState#values and localAdditions)".



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();

Review Comment:
   It would be more efficient to reuse the same stream for every key and value, using toByteStringAndReset.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as those new values now are
+        // also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> !keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {

Review Comment:
   You could merge localAdditions, localRemovals and keyStateMap, since they are all keyed by the same key. That would simplify some of this merging and could reduce lookups in places where you add to localAdditions and set the keyState to known.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Coders used to encode multimap keys and values for Windmill.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+    // Location of this state: its namespace and tag, the Windmill tag encoded from those two, and
+    // the state family it belongs to.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+
+    /**
+     * Cached knowledge about one multimap key, stored in keyStateMap under the key's structural
+     * value.
+     */
+    private class KeyState {
+      // The user-supplied key object; kept so the key can be re-encoded for Windmill.
+      final K originalKey;
+      // Whether the key is known to exist, known to be absent, or not yet determined.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // Can be true only if existence == KNOWN_EXIST and every value of this key is cached (split
+      // between KeyState#values and localAdditions).
+      boolean valuesCached = false;
+      // The values stored in Windmill. Values added during user processing go to localAdditions,
+      // not here; they are appended here only after they have been persisted to Windmill and
+      // removed from localAdditions.
+      ConcatIterables<V> values = new ConcatIterables<>();
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+      }
+    }
+
+    /** What is currently known about whether a key is present in this multimap. */
+    private enum KeyExistence {
+      /** The key is known to exist. */
+      KNOWN_EXIST,
+      /** The key is known not to exist. */
+      KNOWN_NONEXISTENT,
+      /**
+       * Presence is undetermined; the entry only records the mapping from the structural key back
+       * to the original key.
+       */
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); the next commit then asks Windmill to delete everything for this state.
+    private boolean cleared = false;
+    // True when no backend reads are needed because nothing in Windmill is missing from the local
+    // view (a new sharding key, or after clear()).
+    private boolean complete = false;
+    // When true, every key that exists has a KNOWN_EXIST entry in keyStateMap; any key without
+    // such an entry does not exist.
+    private boolean allKeysKnown = false;
+
+    // Keyed by the structural value of each key, so that different Java objects with the same
+    // content are treated as the same multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // Per-key values added locally that have not been written to Windmill yet.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // Keys whose deletion has not been written to Windmill yet. For a key in both this set and
+    // localAdditions, the old values are removed first and the new ones added afterwards.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // For a new sharding key the local view starts out complete with all keys known
+      // (presumably because nothing has been persisted for it yet).
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    /**
+     * Starts a backend read of every entry of this multimap, or returns an already-completed empty
+     * result when the local view is complete.
+     *
+     * @param omitValues forwarded to the reader; keys() passes true (presumably to fetch keys
+     *     without their values)
+     */
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      return complete
+          ? Futures.immediateFuture(Collections.emptyList())
+          : reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+    }
+
+    /**
+     * Encodes {@code key} and starts a backend read of the values stored under it.
+     *
+     * @throws RuntimeException wrapping the IOException if the key cannot be encoded
+     */
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encoded = new ByteStringOutputStream();
+        keyCoder.encode(key, encoded);
+        ByteString encodedKey = encoded.toByteString();
+        return reader.multimapFetchSingleEntryFuture(
+            encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    /**
+     * Returns a ReadableState whose read() yields every value of {@code key}: the values stored in
+     * Windmill (ignored if the key was removed locally) followed by values added locally but not
+     * yet committed. read() may update the key's cached KeyState as a side effect.
+     */
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          // Creates a placeholder entry (UNKNOWN_EXISTENCE) if the key is not cached yet.
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          // With all keys known, a key whose existence is still unknown does not exist; drop the
+          // placeholder created above.
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          // Everything needed is already local; no backend read.
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            // Only a Weighted result is cached. Otherwise the entry keeps UNKNOWN_EXISTENCE and
+            // the next read fetches again. NOTE(review): presumably non-Weighted results are
+            // paginated or too large to keep — confirm.
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Starts the single-key backend read now, so a later read() can use its result.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds the Windmill commit for all pending local changes. It then folds those changes into
+     * the cached view and stores this state in {@code cache}.
+     *
+     * <p>A pending clear() becomes a delete-all on the update request. Each key in localRemovals
+     * and/or localAdditions becomes one entry update. That update deletes the key's stored values
+     * (if the key was removed) and then appends its newly added values.
+     *
+     * @throws IOException if a key or value cannot be encoded
+     * @throws IllegalStateException if a key with pending updates has no cached KeyState
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      // Past the early return there is always something to write, so exactly one update request
+      // is needed. It must always be addressed to this state's tag and family, including when it
+      // carries the delete-all from clear(). Previously the tag was never set in that case.
+      Windmill.TagMultimapUpdateRequest.Builder builder =
+          commitBuilder.addMultimapUpdatesBuilder().setTag(stateKey).setStateFamily(stateFamily);
+      if (cleared) {
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        if (keyState == null) {
+          // The KeyState holds the original key object, which is required to encode the entry.
+          throw new IllegalStateException(
+              "Multimap key with pending updates is missing from the key state cache");
+        }
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(keyStream.toByteString());
+        // For a removed key, its stored values are deleted before the new ones are added.
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          entryBuilder.addValues(valueStream.toByteString());
+        }
+        // The newly added values are now persisted as well. If this key's values are cached,
+        // append them to KeyState#values. Otherwise drop the entry unless the key is known to
+        // exist.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else if (keyState.existence != KeyExistence.KNOWN_EXIST) {
+          keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    /**
+     * Discards all local state and schedules deletion of everything persisted for this multimap.
+     * Afterwards the local view is complete, so no backend reads are needed.
+     */
+    @Override
+    public void clear() {
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      // Fresh collections are assigned rather than cleared in place, so views handed out earlier
+      // by read() keep their old contents.
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      keyStateMap = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a ReadableState over the keys that currently exist. These are keys known locally to
+     * exist, plus keys stored in Windmill that have not been removed locally.
+     *
+     * <p>If the Windmill result is {@link Weighted} (a known amount of data), every returned key
+     * is cached as KNOWN_EXIST and allKeysKnown is set, so later reads need no backend fetch.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            // (structural key, original key) for every key stored in Windmill.
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              // Once all keys are known, entries of unknown existence are known to be absent and
+              // can be dropped. KNOWN_NONEXISTENT entries must be kept: they belong to keys pending
+              // removal in localRemovals, and persistDirectly needs their originalKey to encode
+              // the delete.
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence == KeyExistence.UNKNOWN_EXISTENCE);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(
+                      Iterables.filter(
+                          keyStateMap.values(),
+                          keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                      keyState -> keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // Keys cached as known to exist.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // Keys from Windmill not returned above: either not cached at all, or cached
+                      // with unknown existence (e.g. left behind by a get() whose result was not
+                      // cached). Such keys used to be dropped from both halves.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keys,
+                              e -> {
+                                KeyState keyState = keyStateMap.get(e.getKey());
+                                return keyState == null
+                                    || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE;
+                              }),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Collects every locally known entry, grouped by original key. For each key, pending
+     * additions come first, followed by its cached persisted values (if those are cached).
+     */
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> merged = new MultimapIterables<>();
+      localAdditions
+          .asMap()
+          .forEach(
+              (structuralKey, added) ->
+                  merged.extendWith(keyStateMap.get(structuralKey).originalKey, added));
+      keyStateMap.forEach(
+          (structuralKey, keyState) -> {
+            if (keyState.valuesCached) {
+              merged.extendWith(keyState.originalKey, keyState.values);
+            }
+          });
+      return merged;
+    }
+
+    /**
+     * An iterable of (key, value) entries, built by appending value iterables per key. Iteration
+     * walks the groups and their values lazily.
+     */
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map = new HashMap<>();
+
+      public MultimapIterables() {}
+
+      /** Appends {@code iterable} to the values already collected for {@code key}. */
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.computeIfAbsent(key, ignored -> new ConcatIterables<>()).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(Iterators.transform(map.entrySet().iterator(), this::entriesOf));
+      }
+
+      /** Pairs each value of one group with that group's key. */
+      private Iterator<Entry<K, V>> entriesOf(Entry<K, ConcatIterables<V>> group) {
+        K key = group.getKey();
+        return Iterators.transform(
+            group.getValue().iterator(), value -> new AbstractMap.SimpleEntry<>(key, value));
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();

Review Comment:
   Do we expect persistent state to return the same structural key more than once (which this compute() would then merge)?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    // Coders used to encode multimap keys and values for Windmill.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+    // Location of this state: its namespace and tag, the Windmill tag encoded from those two, and
+    // the state family it belongs to.
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+
+    /**
+     * Cached knowledge about one multimap key, stored in keyStateMap under the key's structural
+     * value.
+     */
+    private class KeyState {
+      // The user-supplied key object; kept so the key can be re-encoded for Windmill.
+      final K originalKey;
+      // Whether the key is known to exist, known to be absent, or not yet determined.
+      KeyExistence existence = KeyExistence.UNKNOWN_EXISTENCE;
+      // Can be true only if existence == KNOWN_EXIST and every value of this key is cached (split
+      // between KeyState#values and localAdditions).
+      boolean valuesCached = false;
+      // The values stored in Windmill. Values added during user processing go to localAdditions,
+      // not here; they are appended here only after they have been persisted to Windmill and
+      // removed from localAdditions.
+      ConcatIterables<V> values = new ConcatIterables<>();
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+      }
+    }
+
+    /** What is currently known about whether a key is present in this multimap. */
+    private enum KeyExistence {
+      /** The key is known to exist. */
+      KNOWN_EXIST,
+      /** The key is known not to exist. */
+      KNOWN_NONEXISTENT,
+      /**
+       * Presence is undetermined; the entry only records the mapping from the structural key back
+       * to the original key.
+       */
+      UNKNOWN_EXISTENCE
+    }
+
+    // Set by clear(); the next commit then asks Windmill to delete everything for this state.
+    private boolean cleared = false;
+    // True when no backend reads are needed because nothing in Windmill is missing from the local
+    // view (a new sharding key, or after clear()).
+    private boolean complete = false;
+    // When true, every key that exists has a KNOWN_EXIST entry in keyStateMap; any key without
+    // such an entry does not exist.
+    private boolean allKeysKnown = false;
+
+    // Keyed by the structural value of each key, so that different Java objects with the same
+    // content are treated as the same multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // Per-key values added locally that have not been written to Windmill yet.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // Keys whose deletion has not been written to Windmill yet. For a key in both this set and
+    // localAdditions, the old values are removed first and the new ones added afterwards.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // For a new sharding key the local view starts out complete with all keys known
+      // (presumably because nothing has been persisted for it yet).
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    /**
+     * Starts a backend read of every entry of this multimap, or returns an already-completed empty
+     * result when the local view is complete.
+     *
+     * @param omitValues forwarded to the reader; keys() passes true (presumably to fetch keys
+     *     without their values)
+     */
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      return complete
+          ? Futures.immediateFuture(Collections.emptyList())
+          : reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+    }
+
+    /**
+     * Encodes {@code key} and starts a backend read of the values stored under it.
+     *
+     * @throws RuntimeException wrapping the IOException if the key cannot be encoded
+     */
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encoded = new ByteStringOutputStream();
+        keyCoder.encode(key, encoded);
+        ByteString encodedKey = encoded.toByteString();
+        return reader.multimapFetchSingleEntryFuture(
+            encodedKey, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    /**
+     * Returns a ReadableState whose read() yields every value of {@code key}: the values stored in
+     * Windmill (ignored if the key was removed locally) followed by values added locally but not
+     * yet committed. read() may update the key's cached KeyState as a side effect.
+     */
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          // Creates a placeholder entry (UNKNOWN_EXISTENCE) if the key is not cached yet.
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          // With all keys known, a key whose existence is still unknown does not exist; drop the
+          // placeholder created above.
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          // Everything needed is already local; no backend read.
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            // Only a Weighted result is cached. Otherwise the entry keeps UNKNOWN_EXISTENCE and
+            // the next read fetches again. NOTE(review): presumably non-Weighted results are
+            // paginated or too large to keep — confirm.
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        // Starts the single-key backend read now, so a later read() can use its result.
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds the Windmill commit for all pending local changes. It then folds those changes into
+     * the cached view and stores this state in {@code cache}.
+     *
+     * <p>A pending clear() becomes a delete-all on the update request. Each key in localRemovals
+     * and/or localAdditions becomes one entry update. That update deletes the key's stored values
+     * (if the key was removed) and then appends its newly added values.
+     *
+     * @throws IOException if a key or value cannot be encoded
+     * @throws IllegalStateException if a key with pending updates has no cached KeyState
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      // Past the early return there is always something to write, so exactly one update request
+      // is needed. It must always be addressed to this state's tag and family, including when it
+      // carries the delete-all from clear(). Previously the tag was never set in that case.
+      Windmill.TagMultimapUpdateRequest.Builder builder =
+          commitBuilder.addMultimapUpdatesBuilder().setTag(stateKey).setStateFamily(stateFamily);
+      if (cleared) {
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        if (keyState == null) {
+          // The KeyState holds the original key object, which is required to encode the entry.
+          throw new IllegalStateException(
+              "Multimap key with pending updates is missing from the key state cache");
+        }
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(keyStream.toByteString());
+        // For a removed key, its stored values are deleted before the new ones are added.
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          entryBuilder.addValues(valueStream.toByteString());
+        }
+        // The newly added values are now persisted as well. If this key's values are cached,
+        // append them to KeyState#values. Otherwise drop the entry unless the key is known to
+        // exist.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else if (keyState.existence != KeyExistence.KNOWN_EXIST) {
+          keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    /**
+     * Discards all local state and schedules deletion of everything persisted for this multimap.
+     * Afterwards the local view is complete, so no backend reads are needed.
+     */
+    @Override
+    public void clear() {
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      // Fresh collections are assigned rather than cleared in place, so views handed out earlier
+      // by read() keep their old contents.
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      keyStateMap = Maps.newHashMap();
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.

Review Comment:
   Just checking that we don't need to provide ordering guarantees here; otherwise we should interleave the cached and fetched results.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as those new values now are
+        // also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> !keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {
+        K key = keyStateMap.get(entry.getKey()).originalKey;
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
+        if (entry.getValue().valuesCached) {
+          result.extendWith(entry.getValue().originalKey, entry.getValue().values);
+        }
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.compute(
+            key,
+            (k, v) -> {
+              if (v == null) v = new ConcatIterables<>();
+              v.extendWith(iterable);
+              return v;
+            });
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.entrySet(),
+                    entry ->
+                        Iterables.transform(
+                                entry.getValue(),
+                                v -> new AbstractMap.SimpleEntry<>(entry.getKey(), v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();
+                          v.extendWith(entry.getValue());
+                          keyState.existence = KeyExistence.KNOWN_EXIST;
+                          return v;
+                        });
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      keyState.values.extendWith(values);
+                      keyState.valuesCached = true;
+                    }
+                  });
+              allKeysKnown = true;
+              complete = true;
+              keyStateMap
+                  .entrySet()
+                  .removeIf(
+                      entry ->
+                          entry.getValue().existence == KeyExistence.KNOWN_NONEXISTENT
+                              && !localRemovals.contains(entry.getKey()));
+              return Iterables.unmodifiableIterable(mergedCachedEntries());
+            } else {
+              MultimapIterables<K, V> local = mergedCachedEntries();
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      local.extendWith(keyState.originalKey, values);
+                    }
+                  });
+              return Iterables.unmodifiableIterable(local);
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> containsKey(K key) {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<V>> values = null;
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Boolean read() {
+          KeyState keyState = keyStateMap.getOrDefault(structuralKey, null);
+          if (keyState != null && keyState.existence != KeyExistence.UNKNOWN_EXISTENCE) {
+            return keyState.existence == KeyExistence.KNOWN_EXIST;
+          }
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          return !Iterables.isEmpty(values.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          values.readLater();
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<K>> keys = null;
+
+        @Override
+        public Boolean read() {
+          for (KeyState keyState : keyStateMap.values()) {
+            if (keyState.existence == KeyExistence.KNOWN_EXIST) return false;
+          }
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();

Review Comment:
   This seems more expensive than it needs to be just to determine isEmpty.
   
   Could add a comment that it could potentially be optimized, or that it might be beneficial if isEmpty is often followed by iterating over the keys. Though it's also likely that all values are iterated over after isEmpty.
   



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();

Review Comment:
   We're returning early here, so we're not calling cache.put below.
   
   If this is empty and complete, it could be worth caching if it prevents future reads.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -447,6 +493,12 @@ public Iterable<ResultT> apply(
           contStateTag =

Review Comment:
   nit: contStateTag could be kept as a builder, to avoid going back and forth between the builder and the built message.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as those new values now are
+        // also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.getOrDefault(entry.getKey(), null);
+                      // this is a key that exists in windmill but is not cached.
+                      if (keyState == null) return true;
+                      // if the key is marked as deleted in cache, ignore it.
+                      return keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              keyStateMap
+                  .values()
+                  .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST);
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(keyStateMap.values(), keyState -> keyState.originalKey));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keyStateMap.values(),
+                              keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                          keyState -> keyState.originalKey),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.transform(
+                          Iterables.filter(keys, e -> !keyStateMap.containsKey(e.getKey())),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {
+        K key = keyStateMap.get(entry.getKey()).originalKey;
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, KeyState> entry : keyStateMap.entrySet()) {
+        if (entry.getValue().valuesCached) {
+          result.extendWith(entry.getValue().originalKey, entry.getValue().values);
+        }
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.compute(
+            key,
+            (k, v) -> {
+              if (v == null) v = new ConcatIterables<>();
+              v.extendWith(iterable);
+              return v;
+            });
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.entrySet(),
+                    entry ->
+                        Iterables.transform(
+                                entry.getValue(),
+                                v -> new AbstractMap.SimpleEntry<>(entry.getKey(), v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();
+                          v.extendWith(entry.getValue());
+                          keyState.existence = KeyExistence.KNOWN_EXIST;
+                          return v;
+                        });
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      keyState.values.extendWith(values);
+                      keyState.valuesCached = true;
+                    }
+                  });
+              allKeysKnown = true;
+              complete = true;
+              keyStateMap
+                  .entrySet()
+                  .removeIf(
+                      entry ->
+                          entry.getValue().existence == KeyExistence.KNOWN_NONEXISTENT
+                              && !localRemovals.contains(entry.getKey()));
+              return Iterables.unmodifiableIterable(mergedCachedEntries());
+            } else {
+              MultimapIterables<K, V> local = mergedCachedEntries();
+              entryMap.forEach(
+                  (structuralKey, values) -> {
+                    KeyState keyState = keyStateMap.get(structuralKey);
+                    if (!keyState.valuesCached) {
+                      local.extendWith(keyState.originalKey, values);
+                    }
+                  });
+              return Iterables.unmodifiableIterable(local);
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<Entry<K, V>>> readLater() {
+          WindmillMultimap.this.getFuture(false);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> containsKey(K key) {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<V>> values = null;
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Boolean read() {
+          KeyState keyState = keyStateMap.getOrDefault(structuralKey, null);
+          if (keyState != null && keyState.existence != KeyExistence.UNKNOWN_EXISTENCE) {
+            return keyState.existence == KeyExistence.KNOWN_EXIST;
+          }
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          return !Iterables.isEmpty(values.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (values == null) {
+            values = WindmillMultimap.this.get(key);
+          }
+          values.readLater();
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        ReadableState<Iterable<K>> keys = null;
+
+        @Override
+        public Boolean read() {
+          for (KeyState keyState : keyStateMap.values()) {
+            if (keyState.existence == KeyExistence.KNOWN_EXIST) return false;
+          }
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();
+          }
+          return Iterables.isEmpty(keys.read());
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          if (keys == null) {
+            keys = WindmillMultimap.this.keys();
+          }
+          keys.readLater();

Review Comment:
   Should this be avoided if we already know of any existing keys in keyStateMap?
   readLater batches a fetch that is issued on the next blocking read() of anything, so we should skip it when it isn't needed. keys() will issue a fetch whenever the state is not complete, but we don't care about completeness here — a single known key is enough to answer isEmpty().



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,522 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is
+      // cached(both KeyState#values and localAdditions).
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) {
+            return Collections.emptyList();
+          }
+          if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) {
+            keyStateMap.remove(structuralKey);
+            return Collections.emptyList();
+          }
+
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                // empty in both cache and windmill, remove key from cache.
+                keyStateMap.remove(structuralKey);
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as those new values now are
+        // also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST,
+        // remove it from cache.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        } else {
+          if (keyState.existence != KeyExistence.KNOWN_EXIST) keyStateMap.remove(structuralKey);
+        }
+      }
+
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    /**
+     * Removes every value associated with {@code key}.
+     *
+     * <p>If the key may still have values stored in Windmill (its values are cached, or the local
+     * view is not {@code complete}), it is recorded in {@code localRemovals} and cached as
+     * {@code KNOWN_NONEXISTENT}. Otherwise dropping its cache entry is enough. Pending local
+     * additions for the key are discarded in both cases.
+     */
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.get(structuralKey);
+      final boolean notCachedYet = keyState == null;
+      if (notCachedYet) {
+        keyState = new KeyState(key);
+      }
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        // Nothing to remove. Do not register a KeyState for a key that was only looked at;
+        // inserting it before this check left a stray placeholder entry in keyStateMap.
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed. Keep the KeyState cached so
+        // the commit code can recover the original key from keyStateMap.
+        if (notCachedYet) {
+          keyStateMap.put(structuralKey, keyState);
+        }
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+        keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+      } else {
+        // no data in windmill, deleting from local cache is sufficient.
+        keyStateMap.remove(structuralKey);
+      }
+      localAdditions.removeAll(structuralKey);
+    }
+
+    /**
+     * Discards all cached and pending state for this multimap.
+     *
+     * <p>The state is marked {@code cleared}, {@code complete} and {@code allKeysKnown}. As a
+     * result, {@code keys()} and {@code entries()} answer from the (now empty) local cache
+     * instead of reading Windmill. New collection instances are assigned rather than emptied
+     * in place.
+     */
+    @Override
+    public void clear() {
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      keyStateMap = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a {@link ReadableState} over the keys currently known to exist.
+     *
+     * <p>When {@code allKeysKnown} is set, the answer comes from {@code keyStateMap} alone.
+     * Otherwise the keys persisted in Windmill are fetched and merged with the cache:
+     * <ul>
+     *   <li>keys marked {@code KNOWN_NONEXISTENT} locally are hidden;</li>
+     *   <li>keys cached as {@code KNOWN_EXIST} are included.</li>
+     * </ul>
+     * If Windmill returns a {@link Weighted} result (a known amount of data), every key is
+     * cached and {@code allKeysKnown} is set.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        /** Lazy view of the original keys cached in keyStateMap as KNOWN_EXIST. */
+        private Iterable<K> cachedExistingKeys() {
+          return Iterables.transform(
+              Iterables.filter(
+                  keyStateMap.values(),
+                  keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+              keyState -> keyState.originalKey);
+        }
+
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(cachedExistingKeys());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            // (structural key, original key) for every key returned from Windmill.
+            Iterable<Entry<Object, K>> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        K originalKey = keyCoder.decode(entry.getKey().newInput());
+                        return new AbstractMap.SimpleEntry<>(
+                            keyCoder.structuralValue(originalKey), originalKey);
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    entry -> {
+                      KeyState keyState = keyStateMap.get(entry.getKey());
+                      // A key that exists in windmill but is not cached is kept; a key marked
+                      // as deleted in cache is ignored.
+                      return keyState == null
+                          || keyState.existence != KeyExistence.KNOWN_NONEXISTENT;
+                    });
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  entry -> {
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(
+                            entry.getKey(), stk -> new KeyState(entry.getValue()));
+                    keyState.existence = KeyExistence.KNOWN_EXIST;
+                  });
+              allKeysKnown = true;
+              // Drop cache entries of keys that do not exist, EXCEPT keys with a pending local
+              // removal. The commit code reads keyStateMap.get(structuralKey).originalKey for
+              // the keys it updates (which, judging by its use of localRemovals, includes
+              // removed keys). Dropping those entries here would make that lookup return null.
+              keyStateMap
+                  .entrySet()
+                  .removeIf(
+                      mapEntry ->
+                          mapEntry.getValue().existence != KeyExistence.KNOWN_EXIST
+                              && !localRemovals.contains(mapEntry.getKey()));
+              return Iterables.unmodifiableIterable(cachedExistingKeys());
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached as existing.
+                      cachedExistingKeys(),
+                      // The Windmill keys not already covered above. A key cached with
+                      // UNKNOWN_EXISTENCE must be taken from here; excluding every cached key
+                      // would drop it from both halves. KNOWN_NONEXISTENT keys were already
+                      // filtered out of `keys`.
+                      Iterables.transform(
+                          Iterables.filter(
+                              keys,
+                              entry -> {
+                                KeyState keyState = keyStateMap.get(entry.getKey());
+                                return keyState == null
+                                    || keyState.existence != KeyExistence.KNOWN_EXIST;
+                              }),
+                          entry -> entry.getValue())));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds a multimap view from local state only.
+     *
+     * <p>For each key, its pending {@code localAdditions} come first, followed by its cached
+     * {@code values} when {@code valuesCached} is set. Windmill is not consulted.
+     * NOTE(review): assumes every key in localAdditions has a KeyState in keyStateMap; a
+     * missing one would throw NullPointerException here.
+     */
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> merged = new MultimapIterables<>();
+      localAdditions
+          .asMap()
+          .forEach(
+              (structuralKey, added) ->
+                  merged.extendWith(keyStateMap.get(structuralKey).originalKey, added));
+      for (KeyState keyState : keyStateMap.values()) {
+        if (!keyState.valuesCached) {
+          continue;
+        }
+        merged.extendWith(keyState.originalKey, keyState.values);
+      }
+      return merged;
+    }
+
+    /**
+     * Groups value iterables by key and exposes them as a flat, lazily-built sequence of
+     * (key, value) entries.
+     *
+     * <p>Iterables added for the same key are concatenated in insertion order. Keys live in a
+     * {@link HashMap}, so key grouping uses {@code K#equals}/{@code hashCode}.
+     */
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        map = new HashMap<>();
+      }
+
+      /** Appends {@code iterable} to the values already recorded for {@code key}. */
+      public void extendWith(K key, Iterable<V> iterable) {
+        map.computeIfAbsent(key, unused -> new ConcatIterables<>()).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        Iterator<Iterator<Entry<K, V>>> perKey =
+            Iterators.transform(map.entrySet().iterator(), this::entriesOf);
+        return Iterators.concat(perKey);
+      }
+
+      /** Lazily pairs every value of {@code keyed} with its key. */
+      private Iterator<Entry<K, V>> entriesOf(Entry<K, ConcatIterables<V>> keyed) {
+        final K key = keyed.getKey();
+        return Iterators.transform(
+            keyed.getValue().iterator(), value -> new AbstractMap.SimpleEntry<>(key, value));
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    final K key = keyCoder.decode(entry.getKey().newInput());
+                    final Object structuralKey = keyCoder.structuralValue(key);
+                    KeyState keyState =
+                        keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+                    if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) return;
+                    entryMap.compute(
+                        structuralKey,
+                        (k, v) -> {
+                          if (v == null) v = new ConcatIterables<>();

Review Comment:
   should it be an error?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org