You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2023/01/06 16:58:52 UTC

[GitHub] [beam] reuvenlax commented on a diff in pull request #23492: Add Windmill support for MultimapState

reuvenlax commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1062946434


##########
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java:
##########
@@ -2437,6 +2437,13 @@ static void verifyDoFnSupported(
 
     boolean streamingEngine = useStreamingEngine(options);
     boolean isUnifiedWorker = useUnifiedWorker(options);
+
+    if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) {
+      throw new UnsupportedOperationException(
+          String.format(
+              "%s does not currently support %s running using streaming on unified worker",
+              DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName()));

Review Comment:
   What are the plans to support this on the unified worker?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to cachedEntries as those new values are also
+        // persisted in Windmill.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {

Review Comment:
   Why iterate over localRemovals again? The preceding loop over keysWithUpdates already visited every key in localRemovals.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {

Review Comment:
   Note that localAdditions.get(structuralKey) should return an empty collection if !containsKey(structuralKey), so the containsKey check and the else branch are unnecessary.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {

Review Comment:
   This is a double lookup in the map (containsKey followed by get). Consider replacing it with a single lookup.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {

Review Comment:
   How is Weighted being used here?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true if only existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();

Review Comment:
   Instead of copying all keys, maybe iterate over Iterables.concat(localRemovals, localAdditions.keySet())



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true only if existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to KeyState#values, as those new values are
+        // also persisted in Windmill.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        KeyState keyState = keyStateMap.get(removedKey);
+        if (keyState.existence != KeyExistence.KNOWN_NONEXISTENT) {
+          keyStateMap.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));

Review Comment:
   I worry about this pattern, because we're leaving an unbounded number of tombstones in place for every key deleted, so memory might grow without bound. Is there any way of bounding this?



##########
runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto:
##########
@@ -146,6 +146,79 @@ message TagBag {
   optional int64 fetch_max_bytes = 6 [default = 0x7fffffffffffffff];
 }
 
+// For a given sharding key and state family, a TagMultimap is a collection of
+// `entry_name, bag of values` pairs, each pair is a TagMultimapEntry.
+//
+// The request_position, continuation_position and fetch_max_bytes fields in
+// TagMultimapEntry are used for the pagination and byte limiting of individual
+// entry fetch requests get(entry_name); while those fields in
+// TagMultimapFetchRequest and TagMultimapFetchResponse are used for full
+// multimap fetch requests entry_names() and entries().
+// Do not set both in a TagMultimapFetchRequest at the same time.
+message TagMultimapEntry {
+  optional bytes entry_name = 1;
+  // In update request: if true all values associated with this entry_name will
+  // be deleted. If new values are present they will be written.
+  optional bool delete_all = 2;
+  // In update request: The given values will be added to the collection and
+  // associated with entry_name.
+  // In fetch response: Values that are associated with this entry_name in the
+  // multimap.
+  repeated bytes values = 3;
+  // In fetch request: A previously returned continuation_position from an
+  // earlier read response. Indicates we wish to fetch the next page of values.
+  // If this is the first request, set to empty.
+  // In fetch response: copied from request.
+  optional int64 request_position = 4;
+  // In fetch response: Set when there are values after those returned above,
+  // but they were suppressed to respect the fetch_max_bytes limit. Subsequent
+  // requests should copy this to request_position to retrieve the next page of
+  // values.
+  optional int64 continuation_position = 5;
+  // In fetch request: Limits the size of the fetched values to this byte limit.
+  // A lower limit may be imposed by the service.
+  optional int64 fetch_max_bytes = 6 [default = 0x7fffffffffffffff];
+}
+
+message TagMultimapFetchRequest {
+  optional bytes tag = 1;
+  optional string state_family = 2;
+  // If true, values will be omitted in the response.
+  optional bool fetch_entry_names_only = 3;
+  // Limits the size of the fetched entries to this byte limit. A lower limit
+  // may be imposed by the service.
+  optional int64 fetch_max_bytes = 4 [default = 0x7fffffffffffffff];
+  // A previously returned continuation_position from an earlier fetch response.
+  // Indicates we wish to fetch the next page of entries. If this is the first
+  // request, set to empty.
+  optional bytes request_position = 5;
+  // Fetch the requested subset of entries only. Will fetch all entries if left
+  // empty. Entries in entries_to_fetch should only have the entry_name,
+  // request_position and fetch_max_bytes set.
+  repeated TagMultimapEntry entries_to_fetch = 6;

Review Comment:
   What happens if fetch_max_bytes is set on both the TagMultimapFetchRequest and the individual entries?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true only if existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to KeyState#values, as those new values are
+        // also persisted in Windmill.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        KeyState keyState = keyStateMap.get(removedKey);
+        if (keyState.existence != KeyExistence.KNOWN_NONEXISTENT) {
+          keyStateMap.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),
+                    keyState -> keyState.originalKey));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys,
+                    key -> {
+                      KeyState keyState =
+                          keyStateMap.getOrDefault(keyCoder.structuralValue(key), null);
+                      if (keyState == null) return true;

Review Comment:
   We keep items for which keyState == null?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true only if existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to keyState.values as those new values are
+        // also persisted in Windmill.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        KeyState keyState = keyStateMap.get(removedKey);
+        if (keyState.existence != KeyExistence.KNOWN_NONEXISTENT) {
+          keyStateMap.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();

Review Comment:
   Why this instead of Collections.emptySet()?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,69 @@ private <ResultT, ContinuationT> Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest per <tag, state_family>
+    // pair, thus we need to delay unbatchable multimap requests of the same stateFamily and tag

Review Comment:
   Why do we have this limit in Windmill?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,518 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private class KeyState {
+      final K originalKey;
+      KeyExistence existence;
+      // valuesCached can be true only if existence == KNOWN_EXIST
+      boolean valuesCached;
+      // represents the values in windmill. When new values are added, they are added to
+      // localAdditions but not KeyState#values. New values will be added to KeyState#values only
+      // after they are persisted into windmill and removed from localAdditions.
+      ConcatIterables<V> values;
+
+      KeyState(K originalKey) {
+        this.originalKey = originalKey;
+        existence = KeyExistence.UNKNOWN_EXISTENCE;
+        valuesCached = false;
+        values = new ConcatIterables<>();
+      }
+    }
+
+    private enum KeyExistence {
+      // this key is known to exist
+      KNOWN_EXIST,
+      // this key is known to be nonexistent
+      KNOWN_NONEXISTENT,
+      // we don't know if this key is in this multimap, this is just to provide a mapping between
+      // the original key and the structural key.
+      UNKNOWN_EXISTENCE
+    }
+
+    private boolean cleared = false;
+    // We use the structural value of the keys as the key in keyStateMap, so that different java
+    // Objects with the same content will be treated as the same Multimap key.
+    private Map<Object, KeyState> keyStateMap = Maps.newHashMap();
+    // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST.
+    private boolean allKeysKnown = false;
+
+    private boolean complete = false;
+    // All keys that have new values pending write to windmill.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      keyStateMap.compute(
+          structuralKey,
+          (k, v) -> {
+            if (v == null) v = new KeyState(key);
+            v.existence = KeyExistence.KNOWN_EXIST;
+            return v;
+          });
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        final Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+          if ((keyState.existence == KeyExistence.KNOWN_NONEXISTENT)
+              || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (keyState.valuesCached || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(keyState.values, localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            final Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.isEmpty(persistedValues)) {
+              Collection<V> local = localAdditions.get(structuralKey);
+              if (local.isEmpty()) {
+                keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+                return Collections.emptyList();
+              }
+              return Iterables.unmodifiableIterable(local);
+            }
+            if (persistedValues instanceof Weighted) {
+              keyState.existence = KeyExistence.KNOWN_EXIST;
+              keyState.valuesCached = true;
+              ConcatIterables<V> it = new ConcatIterables<>();
+              it.extendWith(persistedValues);
+              keyState.values = it;
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        KeyState keyState = keyStateMap.get(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(keyState.originalKey, keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        // Move newly added values from localAdditions to keyState.values as those new values are
+        // also persisted in Windmill.
+        if (keyState.valuesCached) {
+          keyState.values.extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        KeyState keyState = keyStateMap.get(removedKey);
+        if (keyState.existence != KeyExistence.KNOWN_NONEXISTENT) {
+          keyStateMap.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      final Object structuralKey = keyCoder.structuralValue(key);
+      KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key));
+      if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT
+          || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) {
+        return;
+      }
+      if (keyState.valuesCached || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        keyState.values = new ConcatIterables<>();
+        keyState.valuesCached = false;
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      keyState.existence = KeyExistence.KNOWN_NONEXISTENT;
+    }
+
+    @Override
+    public void clear() {
+      keyStateMap = Maps.newHashMap();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(
+                    Iterables.filter(
+                        keyStateMap.values(),
+                        keyState -> keyState.existence == KeyExistence.KNOWN_EXIST),

Review Comment:
   Are there no paths where you might have UNKNOWN_EXISTENCE entries in the map when allKeysKnown is true?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org