You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2021/02/12 06:34:39 UTC

[GitHub] [beam] reuvenlax commented on a change in pull request #13862: [BEAM-11707] Change WindmillStateCache cache invalidation to be based…

reuvenlax commented on a change in pull request #13862:
URL: https://github.com/apache/beam/pull/13862#discussion_r574972621



##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateCache.java
##########
@@ -50,55 +50,68 @@
  * StreamingDataflowWorker} ensures that a single computation * processing key is executing on one
  * thread at a time, so this is safe.
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-})
 public class WindmillStateCache implements StatusDataProvider {
   // Convert Megabytes to bytes
   private static final long MEGABYTES = 1024 * 1024;
   // Estimate of overhead per StateId.
-  private static final int PER_STATE_ID_OVERHEAD = 20;
+  private static final long PER_STATE_ID_OVERHEAD = 28;
   // Initial size of hash tables per entry.
   private static final int INITIAL_HASH_MAP_CAPACITY = 4;
   // Overhead of each hash map entry.
   private static final int HASH_MAP_ENTRY_OVERHEAD = 16;
-  // Overhead of each cache entry.  Three longs, plus a hash table.
+  // Overhead of each StateCacheEntry.  One long, plus a hash table.
   private static final int PER_CACHE_ENTRY_OVERHEAD =
-      24 + HASH_MAP_ENTRY_OVERHEAD * INITIAL_HASH_MAP_CAPACITY;
+      8 + HASH_MAP_ENTRY_OVERHEAD * INITIAL_HASH_MAP_CAPACITY;
 
   private Cache<StateId, StateCacheEntry> stateCache;
-  private HashMultimap<WindmillComputationKey, StateId> keyIndex =
-      HashMultimap.<WindmillComputationKey, StateId>create();
-  private long displayedWeight = 0; // Only used for status pages and unit tests.
+  // Contains the current valid ForKey object. Entries in the cache are keyed by ForKey with pointer
+  // equality so entries may be invalidated by creating a new key object, rendering the previous
+  // entries inaccessible. They will be evicted through normal cache operation.
+  private ConcurrentMap<WindmillComputationKey, ForKey> keyIndex =
+      new MapMaker().weakValues().concurrencyLevel(4).makeMap();
   private long workerCacheBytes; // Copy workerCacheMb and convert to bytes.
 
-  public WindmillStateCache(Integer workerCacheMb) {
+  public WindmillStateCache(long workerCacheMb) {
     final Weigher<Weighted, Weighted> weigher = Weighers.weightedKeysAndValues();
     workerCacheBytes = workerCacheMb * MEGABYTES;
     stateCache =
         CacheBuilder.newBuilder()
             .maximumWeight(workerCacheBytes)
             .recordStats()
             .weigher(weigher)
-            .removalListener(
-                removal -> {
-                  if (removal.getCause() != RemovalCause.REPLACED) {
-                    synchronized (this) {
-                      StateId id = (StateId) removal.getKey();
-                      if (removal.getCause() != RemovalCause.EXPLICIT) {
-                        // When we invalidate a key explicitly, we'll also update the keyIndex, so
-                        // no need to do it here.
-                        keyIndex.remove(id.getWindmillComputationKey(), id);
-                      }
-                      displayedWeight -= weigher.weigh(id, removal.getValue());
-                    }
-                  }
-                })
+            .concurrencyLevel(4)
             .build();
   }
 
+  private static class EntryStats {
+    long entries;
+    long idWeight;
+    long entryWeight;
+    long entryValues;
+    long maxEntryValues;
+  }
+
+  private EntryStats calculateEntryStats() {
+    class CacheConsumer implements BiConsumer<StateId, StateCacheEntry> {
+      public EntryStats stats = new EntryStats();
+
+      @Override
+      public void accept(StateId stateId, StateCacheEntry stateCacheEntry) {
+        stats.entries++;
+        stats.idWeight += stateId.getWeight();
+        stats.entryWeight += stateCacheEntry.getWeight();
+        stats.entryValues += stateCacheEntry.values.size();
+        stats.maxEntryValues = Math.max(stats.maxEntryValues, stateCacheEntry.values.size());
+      }
+    }
+    CacheConsumer consumer = new CacheConsumer();

Review comment:
       Note that you could also do the following, since BiConsumer is a functional interface:
   
   public EntryStats stats = new EntryStats();
   BiConsumer<StateId, StateCacheEntry> consumer = (stateId, stateCacheEntry) -> {
      ...
   };

##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateCache.java
##########
@@ -116,222 +129,198 @@ private ForComputation(String computation) {
 
     /** Invalidate all cache entries for this computation and {@code processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
-      synchronized (this) {
-        WindmillComputationKey key =
-            WindmillComputationKey.create(computation, processingKey, shardingKey);
-        for (StateId id : keyIndex.removeAll(key)) {
-          stateCache.invalidate(id);
-        }
-      }
+      WindmillComputationKey key =
+          WindmillComputationKey.create(computation, processingKey, shardingKey);
+      // By removing the ForKey object, all state for the key is orphaned in the cache and will
+      // be removed by normal cache cleanup.
+      keyIndex.remove(key);
     }
 
-    /** Returns a per-computation, per-key view of the state cache. */
-    public ForKey forKey(
+    /** Returns a per-computation, per-key, per-state-family view of the state cache. */
+    public ForKeyAndFamily forKey(
         WindmillComputationKey computationKey,
         String stateFamily,
         long cacheToken,
         long workToken) {
-      return new ForKey(computationKey, stateFamily, cacheToken, workToken);
+      ForKey forKey = keyIndex.get(computationKey);

Review comment:
       This could be racy if there are multiple accesses to the cache at the same time. I suspect it's ok because access is already synchronized by key (one keyed work item active). However, please add a comment explaining the thread safety.

##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateCache.java
##########
@@ -116,222 +129,198 @@ private ForComputation(String computation) {
 
     /** Invalidate all cache entries for this computation and {@code processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
-      synchronized (this) {
-        WindmillComputationKey key =
-            WindmillComputationKey.create(computation, processingKey, shardingKey);
-        for (StateId id : keyIndex.removeAll(key)) {
-          stateCache.invalidate(id);
-        }
-      }
+      WindmillComputationKey key =
+          WindmillComputationKey.create(computation, processingKey, shardingKey);
+      // By removing the ForKey object, all state for the key is orphaned in the cache and will
+      // be removed by normal cache cleanup.
+      keyIndex.remove(key);
     }
 
-    /** Returns a per-computation, per-key view of the state cache. */
-    public ForKey forKey(
+    /** Returns a per-computation, per-key, per-state-family view of the state cache. */
+    public ForKeyAndFamily forKey(
         WindmillComputationKey computationKey,
         String stateFamily,
         long cacheToken,
         long workToken) {
-      return new ForKey(computationKey, stateFamily, cacheToken, workToken);
+      ForKey forKey = keyIndex.get(computationKey);
+      if (forKey == null || !forKey.updateTokens(cacheToken, workToken)) {
+        forKey = new ForKey(computationKey, cacheToken, workToken);
+        // We prefer this implementation to using compute because that is implemented similarly for
+        // ConcurrentHashMap with the downside of it performing inserts for unchanged existing
+        // values as well.
+        keyIndex.put(computationKey, forKey);
+      }
+      return new ForKeyAndFamily(forKey, stateFamily);
     }
   }
 
   /** Per-computation, per-key view of the state cache. */
-  public class ForKey {
-
+  // Note that we utilize the default equality and hashCode for this class based upon the instance
+  // (instead of the fields) to optimize cache invalidation.
+  private static class ForKey {
     private final WindmillComputationKey computationKey;
-    private final String stateFamily;
     // Cache token must be consistent for the key for the cache to be valid.
     private final long cacheToken;
 
     // The work token for processing must be greater than the last work token.  As work items are
     // increasing for a key, a less-than or equal to work token indicates that the current token is
-    // for stale processing. We don't use the cache so that fetches performed will fail with a no
-    // longer valid work token.
-    private final long workToken;
+    // for stale processing.
+    private long workToken;
 
-    private ForKey(
-        WindmillComputationKey computationKey,
-        String stateFamily,
-        long cacheToken,
-        long workToken) {
+    private ForKey(WindmillComputationKey computationKey, long cacheToken, long workToken) {
       this.computationKey = computationKey;
-      this.stateFamily = stateFamily;
       this.cacheToken = cacheToken;
       this.workToken = workToken;
     }
 
-    public <T extends State> T get(StateNamespace namespace, StateTag<T> address) {
-      return WindmillStateCache.this.get(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address);
-    }
-
-    // Note that once a value has been put for a given workToken, get calls with that same workToken
-    // will fail. This is ok as we only put entries when we are building the commit and will no
-    // longer be performing gets for the same work token.
-    public <T extends State> void put(
-        StateNamespace namespace, StateTag<T> address, T value, long weight) {
-      WindmillStateCache.this.put(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address, value, weight);
+    private boolean updateTokens(long cacheToken, long workToken) {
+      if (this.cacheToken != cacheToken || workToken <= this.workToken) {
+        return false;
+      }
+      this.workToken = workToken;
+      return true;
     }
   }
 
-  /** Returns a per-computation view of the state cache. */
-  public ForComputation forComputation(String computation) {
-    return new ForComputation(computation);
-  }
+  /**
+   * Per-computation, per-key, per-family view of the state cache. Modifications are cached locally
+   * and must be flushed to the cache by calling persist.
+   */
+  public class ForKeyAndFamily {
+    final ForKey forKey;
+    final String stateFamily;
+    private HashMap<StateId, StateCacheEntry> localCache;
 
-  private <T extends State> T get(
-      WindmillComputationKey computationKey,
-      String stateFamily,
-      long cacheToken,
-      long workToken,
-      StateNamespace namespace,
-      StateTag<T> address) {
-    StateId id = new StateId(computationKey, stateFamily, namespace);
-    StateCacheEntry entry = stateCache.getIfPresent(id);
-    if (entry == null) {
-      return null;
+    private ForKeyAndFamily(ForKey forKey, String stateFamily) {
+      this.forKey = forKey;
+      this.stateFamily = stateFamily;
+      localCache = new HashMap<>();
     }
-    if (entry.getCacheToken() != cacheToken) {
-      stateCache.invalidate(id);
-      return null;
+
+    public String getStateFamily() {
+      return stateFamily;
     }
-    if (workToken <= entry.getLastWorkToken()) {
-      // We don't used the cached item but we don't invalidate it.
-      return null;
+
+    public <T extends State> @Nullable T get(StateNamespace namespace, StateTag<T> address) {
+      StateId id = new StateId(forKey, stateFamily, namespace);
+      @SuppressWarnings("nullness") // Unsure how to annotate lambda return allowing null.
+      StateCacheEntry entry = localCache.computeIfAbsent(id, key -> stateCache.getIfPresent(key));
+      return entry == null ? null : entry.get(namespace, address);
     }
-    return entry.get(namespace, address);
-  }
 
-  private <T extends State> void put(
-      WindmillComputationKey computationKey,
-      String stateFamily,
-      long cacheToken,
-      long workToken,
-      StateNamespace namespace,
-      StateTag<T> address,
-      T value,
-      long weight) {
-    StateId id = new StateId(computationKey, stateFamily, namespace);
-    StateCacheEntry entry = stateCache.getIfPresent(id);
-    if (entry == null) {
-      synchronized (this) {
-        keyIndex.put(id.getWindmillComputationKey(), id);
+    public <T extends State> void put(
+        StateNamespace namespace, StateTag<T> address, T value, long weight) {
+      StateId id = new StateId(forKey, stateFamily, namespace);
+      @Nullable StateCacheEntry entry = localCache.get(id);
+      if (entry == null) {
+        entry = stateCache.getIfPresent(id);
+        if (entry == null) {
+          entry = new StateCacheEntry();
+        }
+        boolean hadValue = localCache.putIfAbsent(id, entry) != null;
+        assert (!hadValue);
       }
+      entry.put(namespace, address, value, weight);
     }
-    if (entry == null || entry.getCacheToken() != cacheToken) {
-      entry = new StateCacheEntry(cacheToken);
-      this.displayedWeight += id.getWeight();
-      this.displayedWeight += entry.getWeight();
+
+    public void persist() {
+      localCache.forEach((id, entry) -> stateCache.put(id, entry));
     }
-    entry.setLastWorkToken(workToken);
-    this.displayedWeight += entry.put(namespace, address, value, weight);
-    // Always add back to the cache to update the weight.
-    stateCache.put(id, entry);
   }
 
-  /** Struct identifying a cache entry that contains all data for a key and namespace. */
-  private static class StateId implements Weighted {
+  /** Returns a per-computation view of the state cache. */
+  public ForComputation forComputation(String computation) {
+    return new ForComputation(computation);
+  }
 
-    private final WindmillComputationKey computationKey;
+  /**
+   * Struct identifying a cache entry that contains all data for a ForKey instance and namespace.
+   */
+  private static class StateId implements Weighted {

Review comment:
       Wondering whether it's worth using AutoValue for some of these classes. With AutoValue you can simply annotate the hashCode method with @Memoized to avoid recomputing it.

##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateCache.java
##########
@@ -116,222 +129,198 @@ private ForComputation(String computation) {
 
     /** Invalidate all cache entries for this computation and {@code processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
-      synchronized (this) {
-        WindmillComputationKey key =
-            WindmillComputationKey.create(computation, processingKey, shardingKey);
-        for (StateId id : keyIndex.removeAll(key)) {
-          stateCache.invalidate(id);
-        }
-      }
+      WindmillComputationKey key =
+          WindmillComputationKey.create(computation, processingKey, shardingKey);
+      // By removing the ForKey object, all state for the key is orphaned in the cache and will
+      // be removed by normal cache cleanup.
+      keyIndex.remove(key);
     }
 
-    /** Returns a per-computation, per-key view of the state cache. */
-    public ForKey forKey(
+    /** Returns a per-computation, per-key, per-state-family view of the state cache. */
+    public ForKeyAndFamily forKey(
         WindmillComputationKey computationKey,
         String stateFamily,
         long cacheToken,
         long workToken) {
-      return new ForKey(computationKey, stateFamily, cacheToken, workToken);
+      ForKey forKey = keyIndex.get(computationKey);
+      if (forKey == null || !forKey.updateTokens(cacheToken, workToken)) {
+        forKey = new ForKey(computationKey, cacheToken, workToken);
+        // We prefer this implementation to using compute because that is implemented similarly for
+        // ConcurrentHashMap with the downside of it performing inserts for unchanged existing
+        // values as well.
+        keyIndex.put(computationKey, forKey);
+      }
+      return new ForKeyAndFamily(forKey, stateFamily);
     }
   }
 
   /** Per-computation, per-key view of the state cache. */
-  public class ForKey {
-
+  // Note that we utilize the default equality and hashCode for this class based upon the instance
+  // (instead of the fields) to optimize cache invalidation.
+  private static class ForKey {
     private final WindmillComputationKey computationKey;
-    private final String stateFamily;
     // Cache token must be consistent for the key for the cache to be valid.
     private final long cacheToken;
 
     // The work token for processing must be greater than the last work token.  As work items are
     // increasing for a key, a less-than or equal to work token indicates that the current token is
-    // for stale processing. We don't use the cache so that fetches performed will fail with a no
-    // longer valid work token.
-    private final long workToken;
+    // for stale processing.
+    private long workToken;
 
-    private ForKey(
-        WindmillComputationKey computationKey,
-        String stateFamily,
-        long cacheToken,
-        long workToken) {
+    private ForKey(WindmillComputationKey computationKey, long cacheToken, long workToken) {
       this.computationKey = computationKey;
-      this.stateFamily = stateFamily;
       this.cacheToken = cacheToken;
       this.workToken = workToken;
     }
 
-    public <T extends State> T get(StateNamespace namespace, StateTag<T> address) {
-      return WindmillStateCache.this.get(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address);
-    }
-
-    // Note that once a value has been put for a given workToken, get calls with that same workToken
-    // will fail. This is ok as we only put entries when we are building the commit and will no
-    // longer be performing gets for the same work token.
-    public <T extends State> void put(
-        StateNamespace namespace, StateTag<T> address, T value, long weight) {
-      WindmillStateCache.this.put(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address, value, weight);
+    private boolean updateTokens(long cacheToken, long workToken) {
+      if (this.cacheToken != cacheToken || workToken <= this.workToken) {
+        return false;
+      }
+      this.workToken = workToken;
+      return true;
     }
   }
 
-  /** Returns a per-computation view of the state cache. */
-  public ForComputation forComputation(String computation) {
-    return new ForComputation(computation);
-  }
+  /**
+   * Per-computation, per-key, per-family view of the state cache. Modifications are cached locally
+   * and must be flushed to the cache by calling persist.
+   */
+  public class ForKeyAndFamily {
+    final ForKey forKey;
+    final String stateFamily;
+    private HashMap<StateId, StateCacheEntry> localCache;
 
-  private <T extends State> T get(
-      WindmillComputationKey computationKey,
-      String stateFamily,
-      long cacheToken,
-      long workToken,
-      StateNamespace namespace,
-      StateTag<T> address) {
-    StateId id = new StateId(computationKey, stateFamily, namespace);
-    StateCacheEntry entry = stateCache.getIfPresent(id);
-    if (entry == null) {
-      return null;
+    private ForKeyAndFamily(ForKey forKey, String stateFamily) {
+      this.forKey = forKey;
+      this.stateFamily = stateFamily;
+      localCache = new HashMap<>();
     }
-    if (entry.getCacheToken() != cacheToken) {
-      stateCache.invalidate(id);
-      return null;
+
+    public String getStateFamily() {
+      return stateFamily;
     }
-    if (workToken <= entry.getLastWorkToken()) {
-      // We don't used the cached item but we don't invalidate it.
-      return null;
+
+    public <T extends State> @Nullable T get(StateNamespace namespace, StateTag<T> address) {
+      StateId id = new StateId(forKey, stateFamily, namespace);
+      @SuppressWarnings("nullness") // Unsure how to annotate lambda return allowing null.
+      StateCacheEntry entry = localCache.computeIfAbsent(id, key -> stateCache.getIfPresent(key));

Review comment:
       Can you add a cast?
   
   (Function<StateId, @Nullable StateCacheEntry>) (key -> stateCache.getIfPresent(key))
   

##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateCache.java
##########
@@ -116,222 +129,198 @@ private ForComputation(String computation) {
 
     /** Invalidate all cache entries for this computation and {@code processingKey}. */
     public void invalidate(ByteString processingKey, long shardingKey) {
-      synchronized (this) {
-        WindmillComputationKey key =
-            WindmillComputationKey.create(computation, processingKey, shardingKey);
-        for (StateId id : keyIndex.removeAll(key)) {
-          stateCache.invalidate(id);
-        }
-      }
+      WindmillComputationKey key =
+          WindmillComputationKey.create(computation, processingKey, shardingKey);
+      // By removing the ForKey object, all state for the key is orphaned in the cache and will
+      // be removed by normal cache cleanup.
+      keyIndex.remove(key);
     }
 
-    /** Returns a per-computation, per-key view of the state cache. */
-    public ForKey forKey(
+    /** Returns a per-computation, per-key, per-state-family view of the state cache. */
+    public ForKeyAndFamily forKey(
         WindmillComputationKey computationKey,
         String stateFamily,
         long cacheToken,
         long workToken) {
-      return new ForKey(computationKey, stateFamily, cacheToken, workToken);
+      ForKey forKey = keyIndex.get(computationKey);
+      if (forKey == null || !forKey.updateTokens(cacheToken, workToken)) {
+        forKey = new ForKey(computationKey, cacheToken, workToken);
+        // We prefer this implementation to using compute because that is implemented similarly for
+        // ConcurrentHashMap with the downside of it performing inserts for unchanged existing
+        // values as well.
+        keyIndex.put(computationKey, forKey);
+      }
+      return new ForKeyAndFamily(forKey, stateFamily);
     }
   }
 
   /** Per-computation, per-key view of the state cache. */
-  public class ForKey {
-
+  // Note that we utilize the default equality and hashCode for this class based upon the instance
+  // (instead of the fields) to optimize cache invalidation.
+  private static class ForKey {
     private final WindmillComputationKey computationKey;
-    private final String stateFamily;
     // Cache token must be consistent for the key for the cache to be valid.
     private final long cacheToken;
 
     // The work token for processing must be greater than the last work token.  As work items are
     // increasing for a key, a less-than or equal to work token indicates that the current token is
-    // for stale processing. We don't use the cache so that fetches performed will fail with a no
-    // longer valid work token.
-    private final long workToken;
+    // for stale processing.
+    private long workToken;
 
-    private ForKey(
-        WindmillComputationKey computationKey,
-        String stateFamily,
-        long cacheToken,
-        long workToken) {
+    private ForKey(WindmillComputationKey computationKey, long cacheToken, long workToken) {
       this.computationKey = computationKey;
-      this.stateFamily = stateFamily;
       this.cacheToken = cacheToken;
       this.workToken = workToken;
     }
 
-    public <T extends State> T get(StateNamespace namespace, StateTag<T> address) {
-      return WindmillStateCache.this.get(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address);
-    }
-
-    // Note that once a value has been put for a given workToken, get calls with that same workToken
-    // will fail. This is ok as we only put entries when we are building the commit and will no
-    // longer be performing gets for the same work token.
-    public <T extends State> void put(
-        StateNamespace namespace, StateTag<T> address, T value, long weight) {
-      WindmillStateCache.this.put(
-          computationKey, stateFamily, cacheToken, workToken, namespace, address, value, weight);
+    private boolean updateTokens(long cacheToken, long workToken) {
+      if (this.cacheToken != cacheToken || workToken <= this.workToken) {
+        return false;
+      }
+      this.workToken = workToken;
+      return true;
     }
   }
 
-  /** Returns a per-computation view of the state cache. */
-  public ForComputation forComputation(String computation) {
-    return new ForComputation(computation);
-  }
+  /**
+   * Per-computation, per-key, per-family view of the state cache. Modifications are cached locally
+   * and must be flushed to the cache by calling persist.
+   */
+  public class ForKeyAndFamily {
+    final ForKey forKey;
+    final String stateFamily;
+    private HashMap<StateId, StateCacheEntry> localCache;
 
-  private <T extends State> T get(
-      WindmillComputationKey computationKey,
-      String stateFamily,
-      long cacheToken,
-      long workToken,
-      StateNamespace namespace,
-      StateTag<T> address) {
-    StateId id = new StateId(computationKey, stateFamily, namespace);
-    StateCacheEntry entry = stateCache.getIfPresent(id);
-    if (entry == null) {
-      return null;
+    private ForKeyAndFamily(ForKey forKey, String stateFamily) {
+      this.forKey = forKey;
+      this.stateFamily = stateFamily;
+      localCache = new HashMap<>();
     }
-    if (entry.getCacheToken() != cacheToken) {
-      stateCache.invalidate(id);
-      return null;
+
+    public String getStateFamily() {
+      return stateFamily;
     }
-    if (workToken <= entry.getLastWorkToken()) {
-      // We don't used the cached item but we don't invalidate it.
-      return null;
+
+    public <T extends State> @Nullable T get(StateNamespace namespace, StateTag<T> address) {
+      StateId id = new StateId(forKey, stateFamily, namespace);
+      @SuppressWarnings("nullness") // Unsure how to annotate lambda return allowing null.
+      StateCacheEntry entry = localCache.computeIfAbsent(id, key -> stateCache.getIfPresent(key));
+      return entry == null ? null : entry.get(namespace, address);
     }
-    return entry.get(namespace, address);
-  }
 
-  private <T extends State> void put(
-      WindmillComputationKey computationKey,
-      String stateFamily,
-      long cacheToken,
-      long workToken,
-      StateNamespace namespace,
-      StateTag<T> address,
-      T value,
-      long weight) {
-    StateId id = new StateId(computationKey, stateFamily, namespace);
-    StateCacheEntry entry = stateCache.getIfPresent(id);
-    if (entry == null) {
-      synchronized (this) {
-        keyIndex.put(id.getWindmillComputationKey(), id);
+    public <T extends State> void put(
+        StateNamespace namespace, StateTag<T> address, T value, long weight) {
+      StateId id = new StateId(forKey, stateFamily, namespace);
+      @Nullable StateCacheEntry entry = localCache.get(id);
+      if (entry == null) {
+        entry = stateCache.getIfPresent(id);
+        if (entry == null) {
+          entry = new StateCacheEntry();
+        }
+        boolean hadValue = localCache.putIfAbsent(id, entry) != null;
+        assert (!hadValue);

Review comment:
       Use the Preconditions class instead of a Java assert.

##########
File path: runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java
##########
@@ -997,10 +998,14 @@ public Boolean read() {
     }
 
     @Override
-    public WorkItemCommitRequest persistDirectly(ForKey cache) throws IOException {
+    public WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
       WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder();
       TagSortedListUpdateRequest.Builder updatesBuilder =
-          commitBuilder.addSortedListUpdatesBuilder().setStateFamily(stateFamily).setTag(stateKey);
+          commitBuilder
+              .addSortedListUpdatesBuilder()
+              .setStateFamily(cache.getStateFamily())
+              .setTag(stateKey);

Review comment:
       Make sure to run ./gradlew spotlessApply to get canonical formatting.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org