You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2023/01/03 07:20:48 UTC

[GitHub] [beam] zhengbuqian commented on a diff in pull request #23492: Add Windmill support for MultimapState

zhengbuqian commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1059892472


##########
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java:
##########
@@ -2422,12 +2422,6 @@ static boolean useStreamingEngine(DataflowPipelineOptions options) {
 
   static void verifyDoFnSupported(
       DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
-    if (DoFnSignatures.usesMultimapState(fn)) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -129,9 +140,10 @@ enum Kind {
     abstract String getStateFamily();
 
     /**
-     * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX} kinds: A previous
-     * 'continuation_position' returned by Windmill to signal the resulting bag was incomplete.
-     * Sending that position will request the next page of values. Null for first request.
+     * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX, KIND#MULTIMAP_SINGLE_ENTRY,
+     * KIND#MULTIMAP_ALL} kinds: A previous 'continuation_position' returned by Windmill to signal
+     * the resulting bag was incomplete.Sending that position will request the next page of values.

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -142,6 +154,14 @@ enum Kind {
     @Nullable
     abstract Range<Long> getSortedListRange();
 
+    /** For {@link Kind#MULTIMAP_SINGLE_ENTRY} kinds: the key in the multimap to fetch or delete. */
+    @Nullable
+    abstract ByteString getMultimapKey();
+
+    /** For {@link Kind#MULTIMAP_ALL} kinds: will return keys only if true. */

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();

Review Comment:
   Great suggestion! Updated.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -679,6 +834,49 @@ private void consumeResponse(Windmill.KeyedGetDataResponse response, Set<StateTa
       consumeSortedList(sorted_list, stateTag);
     }
 
+    for (Windmill.TagMultimapFetchResponse tagMultimap : response.getTagMultimapsList()) {
+      // First check if it's keys()/entries()

Review Comment:
   I think that is tricky, and we'd best not rely on it. What do you think?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());

Review Comment:
   Isn't that what the current code already does?
   
   BTW, I updated the code to avoid an unnecessary map key lookup.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        if (!nonexistentKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      if (!structuralKeyMapping.containsKey(structuralKey)) {
+        structuralKeyMapping.put(structuralKey, key);
+      }
+      if (nonexistentKeyCache.contains(structuralKey)
+          || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+        return;
+      }
+      if (cachedEntries.containsKey(structuralKey) || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        cachedEntries.remove(structuralKey);
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      existKeyCache.remove(structuralKey);
+      nonexistentKeyCache.add(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      cachedEntries = Maps.newHashMap();
+      existKeyCache = Sets.newHashSet();
+      nonexistentKeyCache = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      structuralKeyMapping = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys, key -> !nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, structuralKeyMapping::get));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.filter(
+                          keys, e -> !existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());
+        map.get(key).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.keySet(),
+                    k ->
+                        Iterables.transform(map.get(k), v -> new AbstractMap.SimpleEntry<>(k, v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    K key = keyCoder.decode(entry.getKey().newInput());
+                    Object structuralKey = keyCoder.structuralValue(key);
+                    structuralKeyMapping.put(structuralKey, key);
+                    if (nonexistentKeyCache.contains(structuralKey)) return;
+                    if (entryMap.containsKey(structuralKey)) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, thus we need to delay
+    // unbatchable multimap requests of the same stateFamily and tag into later batches. There's no
+    // priority between get()/entries()/keys(), they will be fetched based on the order they occur
+    // in pendingLookups, so that all requests can eventually be fetched and none starves.
+
+    Map<String, Map<ByteString, List<StateTag<?>>>> groupedTags =
+        multimapTags.stream()
+            .collect(
+                Collectors.groupingBy(

Review Comment:
   Do you mean something like `Collectors.groupingBy(makePair(StateTag::getStateFamily, StateTag::getTag))`, so that the grouped result is a `Map<Pair<String, ByteString>, List<StateTag<?>>>`? That doesn't seem much better than the current nested grouping.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -334,6 +358,28 @@ public <T> Future<Iterable<TimestampedValue<T>>> orderedListFuture(
         valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)));
   }
 
+  public <T> Future<Iterable<Map.Entry<ByteString, Iterable<T>>>> multimapFetchAllFuture(
+      boolean omitValues, ByteString encodedTag, String stateFamily, Coder<T> elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_ALL, encodedTag, stateFamily)
+            .toBuilder()
+            .setOmitValues(omitValues)
+            .build();
+    return Preconditions.checkNotNull(

Review Comment:
   Removed the Preconditions check.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, thus we need to delay

Review Comment:
   Yes, I meant at most one fetch per tag/state_family; I've updated the comment.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -334,6 +358,28 @@ public <T> Future<Iterable<TimestampedValue<T>>> orderedListFuture(
         valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)));
   }
 
+  public <T> Future<Iterable<Map.Entry<ByteString, Iterable<T>>>> multimapFetchAllFuture(
+      boolean omitValues, ByteString encodedTag, String stateFamily, Coder<T> elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_ALL, encodedTag, stateFamily)
+            .toBuilder()
+            .setOmitValues(omitValues)
+            .build();
+    return Preconditions.checkNotNull(
+        valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)));
+  }
+
+  public <T> Future<Iterable<T>> multimapFetchSingleEntryFuture(
+      ByteString encodedKey, ByteString encodedTag, String stateFamily, Coder<T> elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_SINGLE_ENTRY, encodedTag, stateFamily)
+            .toBuilder()
+            .setMultimapKey(encodedKey)
+            .build();
+    return Preconditions.checkNotNull(

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, thus we need to delay
+    // unbatchable multimap requests of the same stateFamily and tag into later batches. There's no
+    // priority between get()/entries()/keys(), they will be fetched based on the order they occur
+    // in pendingLookups, so that all requests can eventually be fetched and none starves.
+
+    Map<String, Map<ByteString, List<StateTag<?>>>> groupedTags =
+        multimapTags.stream()
+            .collect(
+                Collectors.groupingBy(
+                    StateTag::getStateFamily, Collectors.groupingBy(StateTag::getTag)));
+
+    for (Map<ByteString, List<StateTag<?>>> familyTags : groupedTags.values()) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in backing store are cached in
+     * cachedEntries, newly added values not yet written to backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. Any key in

Review Comment:
   We use structural keys so that different Java objects with the same content are treated as the same key.
   
   Example:
   
   byte[] a = ..., b = ...; // initialize a and b to have the same content
   multimap.put(a, "a");
   multimap.get(b); // should return "a"



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag derived from (namespace, address) via encodeKey(); identifies this multimap in Windmill
+    // reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys for Windmill and supplies structural values, so that keys with equal content
+    // (e.g. two distinct byte[] instances) are treated as the same multimap key.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    // Set by clear(); the next persistDirectly() emits a delete-all for the whole multimap and
+    // resets it.
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in the backing store are cached in
+     * cachedEntries, and newly added values not yet written to the backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries}, then we don't know
+     * whether Windmill contains additional values that also map to this key, and we'll need to
+     * read them if the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key present in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key present in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    // If true, Windmill is treated as holding nothing beyond what cachedEntries already contains:
+    // getFuture() skips the backend read and get() answers from cachedEntries and localAdditions
+    // alone. Set for a new sharding key and by clear().
+    private boolean complete = false;
+    // Values put() since the last persistDirectly(), keyed by structural key.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions,
+    // the new values in localAdditions will be added after the old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structural keys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    /**
+     * Creates the multimap state for one (namespace, tag) pair.
+     *
+     * @param isNewShardingKey when true, Windmill is treated as holding no persisted data for this
+     *     state, so every key is known locally and no backend read is needed.
+     */
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // For a brand-new sharding key the local view is authoritative from the start.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    /** Buffers {@code value} under {@code key}; it is written to Windmill on the next commit. */
+    @Override
+    public void put(K key, V value) {
+      Object structural = keyCoder.structuralValue(key);
+      // Remember the caller's key object so it can be encoded and returned later.
+      structuralKeyMapping.put(structural, key);
+      // The key now certainly exists, whatever Windmill holds.
+      nonexistentKeyCache.remove(structural);
+      existKeyCache.add(structural);
+      localAdditions.put(structural, value);
+    }
+
+    // Starts the backend read of every entry in this multimap, unless the local view is already
+    // complete, in which case an already-finished empty result is returned instead.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a backend read of the values stored under the single multimap key {@code key}.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKey = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKey);
+        ByteString keyBytes = encodedKey.toByteString();
+        return reader.multimapFetchSingleEntryFuture(keyBytes, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        /**
+         * Returns the values of {@code key}: answered from the local caches when they are
+         * authoritative, otherwise the persisted values fetched from Windmill followed by any
+         * local additions.
+         */
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // This key has been removed locally but the removal hasn't been sent to Windmill,
+            // so values in Windmill (if any) are obsolete and only local values matter.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            // Probe emptiness only: Iterables.size() would walk (and, for a paginated result,
+            // fetch) every persisted value just to compare the count with zero.
+            if (Iterables.isEmpty(persistedValues)) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                // Every key in nonexistentKeyCache must also be in structuralKeyMapping.
+                structuralKeyMapping.put(structuralKey, key);
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              // A Weighted result is treated as the complete persisted content of this key, so
+              // cache it. Every key in cachedEntries must also be in structuralKeyMapping, or
+              // mergedCachedEntries() would map these values to a null key.
+              ConcatIterables<V> cachedValues = new ConcatIterables<>();
+              cachedValues.extendWith(persistedValues);
+              cachedEntries.put(structuralKey, cachedValues);
+              structuralKeyMapping.put(structuralKey, key);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          // Kicks off the single-key fetch. NOTE(review): this relies on the reader reusing the
+          // pending request when read() asks for the same key again.
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();

Review Comment:
   Can you clarify? Is this comment on the wrong line?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag derived from (namespace, address) via encodeKey(); identifies this multimap in Windmill
+    // reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys for Windmill and supplies structural values, so that keys with equal content
+    // (e.g. two distinct byte[] instances) are treated as the same multimap key.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    // Set by clear(); the next persistDirectly() emits a delete-all for the whole multimap and
+    // resets it.
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in the backing store are cached in
+     * cachedEntries, and newly added values not yet written to the backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries}, then we don't know
+     * whether Windmill contains additional values that also map to this key, and we'll need to
+     * read them if the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key present in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key present in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    // If true, Windmill is treated as holding nothing beyond what cachedEntries already contains:
+    // getFuture() skips the backend read and get() answers from cachedEntries and localAdditions
+    // alone. Set for a new sharding key and by clear().
+    private boolean complete = false;
+    // Values put() since the last persistDirectly(), keyed by structural key.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions,
+    // the new values in localAdditions will be added after the old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structural keys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    /**
+     * Creates the multimap state for one (namespace, tag) pair.
+     *
+     * @param isNewShardingKey when true, Windmill is treated as holding no persisted data for this
+     *     state, so every key is known locally and no backend read is needed.
+     */
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // For a brand-new sharding key the local view is authoritative from the start.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    /** Buffers {@code value} under {@code key}; it is written to Windmill on the next commit. */
+    @Override
+    public void put(K key, V value) {
+      Object structural = keyCoder.structuralValue(key);
+      // Remember the caller's key object so it can be encoded and returned later.
+      structuralKeyMapping.put(structural, key);
+      // The key now certainly exists, whatever Windmill holds.
+      nonexistentKeyCache.remove(structural);
+      existKeyCache.add(structural);
+      localAdditions.put(structural, value);
+    }
+
+    // Starts the backend read of every entry in this multimap, unless the local view is already
+    // complete, in which case an already-finished empty result is returned instead.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a backend read of the values stored under the single multimap key {@code key}.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKey = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKey);
+        ByteString keyBytes = encodedKey.toByteString();
+        return reader.multimapFetchSingleEntryFuture(keyBytes, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        /**
+         * Returns the values of {@code key}: answered from the local caches when they are
+         * authoritative, otherwise the persisted values fetched from Windmill followed by any
+         * local additions.
+         */
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // This key has been removed locally but the removal hasn't been sent to Windmill,
+            // so values in Windmill (if any) are obsolete and only local values matter.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            // Probe emptiness only: Iterables.size() would walk (and, for a paginated result,
+            // fetch) every persisted value just to compare the count with zero.
+            if (Iterables.isEmpty(persistedValues)) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                // Every key in nonexistentKeyCache must also be in structuralKeyMapping.
+                structuralKeyMapping.put(structuralKey, key);
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              // A Weighted result is treated as the complete persisted content of this key, so
+              // cache it. Every key in cachedEntries must also be in structuralKeyMapping, or
+              // mergedCachedEntries() would map these values to a null key.
+              ConcatIterables<V> cachedValues = new ConcatIterables<>();
+              cachedValues.extendWith(persistedValues);
+              cachedEntries.put(structuralKey, cachedValues);
+              structuralKeyMapping.put(structuralKey, key);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          // Kicks off the single-key fetch. NOTE(review): this relies on the reader reusing the
+          // pending request when read() asks for the same key again.
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds the Windmill commit for all changes since the previous commit (clear, per-key
+     * removals and additions), folds committed additions into the local cache, and re-registers
+     * this state in {@code cache}.
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        // No pending changes: nothing to commit.
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        // clear() was called: drop everything persisted for this multimap first.
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        boolean removed = localRemovals.contains(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        // A pending removal deletes the persisted values before this commit's additions apply.
+        entryBuilder.setDeleteAll(removed);
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if ((complete || removed) && localAdditions.containsKey(structuralKey)) {
+          // After this commit Windmill holds exactly this key's cached values (none if the key
+          // was just removed or never persisted) followed by the additions, so its full content
+          // is known. Start a cache entry if none exists; otherwise the additions would be
+          // dropped from the local view below while `complete` stops get() from re-reading them.
+          cachedEntries.putIfAbsent(structuralKey, new ConcatIterables<>());
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        // Drop the key mapping only when no cache still refers to the key. A key that was
+        // removed and then put() again is in existKeyCache and must keep its mapping.
+        if (!nonexistentKeyCache.contains(removedKey) && !existKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      if (!structuralKeyMapping.containsKey(structuralKey)) {
+        structuralKeyMapping.put(structuralKey, key);
+      }
+      if (nonexistentKeyCache.contains(structuralKey)
+          || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+        return;
+      }
+      if (cachedEntries.containsKey(structuralKey) || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        cachedEntries.remove(structuralKey);
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      existKeyCache.remove(structuralKey);
+      nonexistentKeyCache.add(structuralKey);
+    }
+
+    /** Empties the multimap; the delete-all reaches Windmill on the next commit. */
+    @Override
+    public void clear() {
+      // From now on the local view is authoritative: every key is known and nothing needs to be
+      // read from the backend.
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+      structuralKeyMapping = Maps.newHashMap();
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      nonexistentKeyCache = Sets.newHashSet();
+      existKeyCache = Sets.newHashSet();
+      cachedEntries = Maps.newHashMap();
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        /**
+         * Returns the keys of the multimap. When every key is known locally the answer comes from
+         * existKeyCache; otherwise the keys are fetched from Windmill and merged with the locally
+         * known ones.
+         */
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            // NOTE(review): this is a live view over existKeyCache; mutating the multimap while
+            // the result is being iterated may throw ConcurrentModificationException — confirm
+            // callers don't do that.
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          // omitValues = true: only entry keys are used below.
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            // Lazy view: keys are decoded each time the iterable is walked.
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            // Hide keys removed locally (remove() records them in nonexistentKeyCache) whose
+            // deletion may not have reached Windmill yet.
+            keys =
+                Iterables.filter(
+                    keys, key -> !nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              // With allKeysKnown set, any key outside existKeyCache is treated as nonexistent,
+              // so the explicit negative cache is no longer needed.
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, structuralKeyMapping::get));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.filter(
+                          keys, e -> !existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          // Starts the key-only fetch (or does nothing when the local view is complete).
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds a key-to-values view of everything held locally: for each key, the cached persisted
+     * values followed by the local additions, which is the order get(key).read() returns.
+     * NOTE(review): assumes MultimapIterables.extendWith appends to the key's existing values.
+     */
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      // Persisted (cached) values first, matching get().read() and the way persistDirectly()
+      // appends committed additions after the cached values.
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());

Review Comment:
   Done



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    // Tag derived from (namespace, address) via encodeKey(); identifies this multimap in Windmill
+    // reads and commits.
+    private final ByteString stateKey;
+    private final String stateFamily;
+    // Encodes keys for Windmill and supplies structural values, so that keys with equal content
+    // (e.g. two distinct byte[] instances) are treated as the same multimap key.
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    // Set by clear(); the next persistDirectly() emits a delete-all for the whole multimap and
+    // resets it.
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then the complete content of
+     * this key is cached: persisted values of this key in the backing store are cached in
+     * cachedEntries, and newly added values not yet written to the backing store are cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries}, then we don't know
+     * whether Windmill contains additional values that also map to this key, and we'll need to
+     * read them if the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key present in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key present in nonexistentKeyCache is known to be nonexistent in the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    // If true, Windmill is treated as holding nothing beyond what cachedEntries already contains:
+    // getFuture() skips the backend read and get() answers from cachedEntries and localAdditions
+    // alone. Set for a new sharding key and by clear().
+    private boolean complete = false;
+    // Values put() since the last persistDirectly(), keyed by structural key.
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exists in both localRemovals and localAdditions,
+    // the new values in localAdditions will be added after the old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structural keys to the actual keys. Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    /**
+     * Creates the multimap state for one (namespace, tag) pair.
+     *
+     * @param isNewShardingKey when true, Windmill is treated as holding no persisted data for this
+     *     state, so every key is known locally and no backend read is needed.
+     */
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      // For a brand-new sharding key the local view is authoritative from the start.
+      this.allKeysKnown = isNewShardingKey;
+      this.complete = isNewShardingKey;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.namespace = namespace;
+      this.address = address;
+      this.stateFamily = stateFamily;
+      this.stateKey = encodeKey(namespace, address);
+    }
+
+    /** Buffers {@code value} under {@code key}; it is written to Windmill on the next commit. */
+    @Override
+    public void put(K key, V value) {
+      Object structural = keyCoder.structuralValue(key);
+      // Remember the caller's key object so it can be encoded and returned later.
+      structuralKeyMapping.put(structural, key);
+      // The key now certainly exists, whatever Windmill holds.
+      nonexistentKeyCache.remove(structural);
+      existKeyCache.add(structural);
+      localAdditions.put(structural, value);
+    }
+
+    // Starts the backend read of every entry in this multimap, unless the local view is already
+    // complete, in which case an already-finished empty result is returned instead.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> getFuture(boolean omitValues) {
+      if (!complete) {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder);
+      }
+      return Futures.immediateFuture(Collections.emptyList());
+    }
+
+    // Starts a backend read of the values stored under the single multimap key {@code key}.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream encodedKey = new ByteStringOutputStream();
+        keyCoder.encode(key, encodedKey);
+        ByteString keyBytes = encodedKey.toByteString();
+        return reader.multimapFetchSingleEntryFuture(keyBytes, stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        /**
+         * Returns the values of {@code key}: answered from the local caches when they are
+         * authoritative, otherwise the persisted values fetched from Windmill followed by any
+         * local additions.
+         */
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // This key has been removed locally but the removal hasn't been sent to Windmill,
+            // so values in Windmill (if any) are obsolete and only local values matter.
+            if (localAdditions.containsKey(structuralKey)) {
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            // Probe emptiness only: Iterables.size() would walk (and, for a paginated result,
+            // fetch) every persisted value just to compare the count with zero.
+            if (Iterables.isEmpty(persistedValues)) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                // Every key in nonexistentKeyCache must also be in structuralKeyMapping.
+                structuralKeyMapping.put(structuralKey, key);
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill nor local additions.");
+              }
+              return Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              // A Weighted result is treated as the complete persisted content of this key, so
+              // cache it. Every key in cachedEntries must also be in structuralKeyMapping, or
+              // mergedCachedEntries() would map these values to a null key.
+              ConcatIterables<V> cachedValues = new ConcatIterables<>();
+              cachedValues.extendWith(persistedValues);
+              cachedEntries.put(structuralKey, cachedValues);
+              structuralKeyMapping.put(structuralKey, key);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          // Kicks off the single-key fetch. NOTE(review): this relies on the reader reusing the
+          // pending request when read() asks for the same key again.
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    /**
+     * Builds the Windmill commit for all changes since the previous commit (clear, per-key
+     * removals and additions), folds committed additions into the local cache, and re-registers
+     * this state in {@code cache}.
+     */
+    @Override
+    protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        // No pending changes: nothing to commit.
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        // clear() was called: drop everything persisted for this multimap first.
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        boolean removed = localRemovals.contains(structuralKey);
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        // A pending removal deletes the persisted values before this commit's additions apply.
+        entryBuilder.setDeleteAll(removed);
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if ((complete || removed) && localAdditions.containsKey(structuralKey)) {
+          // After this commit Windmill holds exactly this key's cached values (none if the key
+          // was just removed or never persisted) followed by the additions, so its full content
+          // is known. Start a cache entry if none exists; otherwise the additions would be
+          // dropped from the local view below while `complete` stops get() from re-reading them.
+          cachedEntries.putIfAbsent(structuralKey, new ConcatIterables<>());
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        // Drop the key mapping only when no cache still refers to the key. A key that was
+        // removed and then put() again is in existKeyCache and must keep its mapping.
+        if (!nonexistentKeyCache.contains(removedKey) && !existKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    /**
+     * Removes every value associated with {@code key}.
+     *
+     * <p>The key is marked as nonexistent locally. It is added to {@code localRemovals} only when
+     * Windmill might still hold values for it, i.e. it has cached persisted values or the local
+     * cache is not complete.
+     */
+    @Override
+    public void remove(K key) {
+      final Object sKey = keyCoder.structuralValue(key);
+      // Remember the original key so the structural key can be mapped back to a K later.
+      if (!structuralKeyMapping.containsKey(sKey)) {
+        structuralKeyMapping.put(sKey, key);
+      }
+
+      final boolean alreadyAbsent =
+          nonexistentKeyCache.contains(sKey) || (allKeysKnown && !existKeyCache.contains(sKey));
+      if (alreadyAbsent) {
+        // The key is already known not to exist; there is nothing to remove.
+        return;
+      }
+
+      final boolean mayExistInWindmill = cachedEntries.containsKey(sKey) || !complete;
+      if (mayExistInWindmill) {
+        // There may be data in windmill that needs to be removed.
+        localRemovals.add(sKey);
+        cachedEntries.remove(sKey);
+      }
+      // In every remaining case: drop pending local additions and mark the key as nonexistent.
+      localAdditions.removeAll(sKey);
+      existKeyCache.remove(sKey);
+      nonexistentKeyCache.add(sKey);
+    }
+
+    /**
+     * Empties the multimap by discarding every cached and pending local structure.
+     *
+     * <p>Because the state is now known to be empty, {@code complete} and {@code allKeysKnown} are
+     * set alongside {@code cleared}.
+     */
+    @Override
+    public void clear() {
+      // State flags: the multimap is now known to be empty.
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+
+      // Pending (uncommitted) local mutations.
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+
+      // Cached knowledge about keys and values.
+      cachedEntries = Maps.newHashMap();
+      existKeyCache = Sets.newHashSet();
+      nonexistentKeyCache = Sets.newHashSet();
+      structuralKeyMapping = Maps.newHashMap();
+    }
+
+    /**
+     * Returns a {@link ReadableState} over the keys of this multimap.
+     *
+     * <p>{@code read()} answers directly from {@code existKeyCache} when {@code allKeysKnown} is
+     * set. Otherwise it fetches the persisted entries, decodes their keys, drops keys recorded in
+     * {@code nonexistentKeyCache}, and combines the result with the locally known keys.
+     * {@code readLater()} only starts the fetch via {@code getFuture(true)}.
+     *
+     * <p>NOTE(review): the iterables returned by {@code read()} are lazy Guava views over
+     * {@code existKeyCache} and the persisted result, so mutating this state while a caller is
+     * still iterating them may be observed (or fail); confirm callers do not need a snapshot.
+     */
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          // Fast path: the full key set is held locally, so no Windmill read is needed.
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = persistedData.get();
+            // Lazily decode each encoded persisted key back into a K. Coder IOExceptions are
+            // rethrown unchecked because the transform function cannot throw checked exceptions.
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            // Hide keys this state already knows no longer exist (e.g. removed via remove()).
+            keys =
+                Iterables.filter(
+                    keys, key -> !nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              // Every surviving persisted key is recorded in existKeyCache; once all keys are
+              // known, nonexistentKeyCache is redundant and is reset.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, structuralKeyMapping::get));
+            } else {
+              // Size of the persisted result is not known up front, so it is not cached; the
+              // locally known keys are served first, followed by persisted keys not among them.
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill that are not cached.
+                      Iterables.filter(
+                          keys, e -> !existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            // Preserve the interrupt status before surfacing the failure unchecked.
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        /** Starts the Windmill fetch for the persisted entries without waiting for it. */
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());
+        map.get(key).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.keySet(),

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org