Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/04/09 05:23:35 UTC

[GitHub] [beam] lukecwik opened a new pull request, #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik opened a new pull request, #17327:
URL: https://github.com/apache/beam/pull/17327

   
   ------------------------
   
   Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:
   
    - [ ] [**Choose reviewer(s)**](https://beam.apache.org/contribute/#make-your-change) and mention them in a comment (`R: @username`).
    - [ ] Format the pull request title like `[BEAM-XXX] Fixes bug in ApproximateQuantiles`, where you replace `BEAM-XXX` with the appropriate JIRA issue, if applicable. This will automatically link the pull request to the issue.
    - [ ] Update `CHANGES.md` with noteworthy changes.
    - [ ] If this contribution is large, please file an Apache [Individual Contributor License Agreement](https://www.apache.org/licenses/icla.pdf).
   
   See the [Contributor Guide](https://beam.apache.org/contribute) for more tips on [how to make review process smoother](https://beam.apache.org/contribute/#make-reviewers-job-easier).
   
   To check the build health, please visit [https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md](https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md)
   
   GitHub Actions Tests Status (on master branch)
   ------------------------------------------------------------------------------------------------
   [![Build python source distribution and wheels](https://github.com/apache/beam/workflows/Build%20python%20source%20distribution%20and%20wheels/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Build+python+source+distribution+and+wheels%22+branch%3Amaster+event%3Aschedule)
   [![Python tests](https://github.com/apache/beam/workflows/Python%20tests/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Python+Tests%22+branch%3Amaster+event%3Aschedule)
   [![Java tests](https://github.com/apache/beam/workflows/Java%20Tests/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Java+Tests%22+branch%3Amaster+event%3Aschedule)
   
   See [CI.md](https://github.com/apache/beam/blob/master/CI.md) for more information about GitHub Actions CI.
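The title asks for an LRU eviction policy. As a minimal illustration of the ordering that makes this possible, and not of the PR's own code, `java.util.LinkedHashMap` built with `accessOrder = true` (the same constructor arguments as the `new LinkedHashMap<>(16, 0.75f, true)` quoted later in this thread) iterates from least- to most-recently accessed entry. The head of its iterator is therefore the eviction candidate. The class name `AccessOrderDemo` is hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical demo of LinkedHashMap access ordering, the basis of an LRU table. */
public class AccessOrderDemo {
  static Map<String, Long> demo() {
    // The third argument (accessOrder = true) orders iteration by last access, not insertion.
    Map<String, Long> lru = new LinkedHashMap<>(16, 0.75f, true);
    lru.put("A", 1L);
    lru.put("B", 2L);
    lru.put("C", 3L);
    lru.get("A"); // reading A makes it the most recently used entry
    return lru;
  }

  public static void main(String[] args) {
    // prints [B, C, A]: B is now the least recently used entry and would be evicted first
    System.out.println(demo().keySet());
  }
}
```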
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1112411820

   CC: @Abacn 




[GitHub] [beam] y1chi commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

y1chi commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872882294


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -223,27 +323,26 @@ public void put(
       throws Exception {
     // Ignore timestamp for grouping purposes.
     // The Pre-combine output will inherit the timestamp of one of its inputs.
-    WindowedValue<Object> groupingKey =
-        WindowedValue.of(
-            keyCoder.structuralValue(value.getValue().getKey()),
-            IGNORED,
-            value.getWindows(),
-            value.getPane());
-
-    GroupingTableEntry entry =
-        lruMap.compute(
-            groupingKey,
-            (key, tableEntry) -> {
-              if (tableEntry == null) {
-                tableEntry =
-                    new GroupingTableEntry(
-                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
-              } else {
-                tableEntry.add(value.getValue().getValue());
-              }
-              return tableEntry;
-            });
-    weight += entry.getWeight();
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();

Review Comment:
   remove this?





[GitHub] [beam] robertwb commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

robertwb commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1136578666

   I happened to do some benchmarking for a separate change (#17641) and noticed that this PR seems to reduce performance significantly. Before this change (https://github.com/robertwb/incubator-beam/tree/java-combine-key-old) I was getting these stats:
   
   ```
     33,102 ±(99.9%) 1,173 ops/s [Average]
     (min, avg, max) = (32,761, 33,102, 33,492), stdev = 0,305
     CI (99.9%): [31,929, 34,275] (assumes normal distribution)
   
     24,809 ±(99.9%) 0,861 ops/s [Average]
     (min, avg, max) = (24,521, 24,809, 25,083), stdev = 0,224
     CI (99.9%): [23,948, 25,670] (assumes normal distribution)
   ```
   
   (two benchmarks here: globally windowed and not globally windowed), but after merging this change I'm seeing:
   
   ```
   Result "org.apache.beam.fn.harness.jmh.CombinerTableBenchmark.uniformDistribution":
     4,949 ±(99.9%) 0,349 ops/s [Average]
     (min, avg, max) = (4,832, 4,949, 5,059), stdev = 0,091
     CI (99.9%): [4,601, 5,298] (assumes normal distribution)
   
   Result "org.apache.beam.fn.harness.jmh.CombinerTableBenchmark.uniformDistribution":
     3,855 ±(99.9%) 0,304 ops/s [Average]
     (min, avg, max) = (3,735, 3,855, 3,930), stdev = 0,079
     CI (99.9%): [3,551, 4,159] (assumes normal distribution)
   ```
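Reading the comma in the JMH output as a decimal separator (for example `stdev = 0,305`), and assuming both runs list the globally windowed benchmark first, the drop in throughput works out to roughly 6.4x to 6.7x:

```math
\frac{33.102}{4.949} \approx 6.69 \quad \text{(globally windowed)}, \qquad \frac{24.809}{3.855} \approx 6.44 \quad \text{(not globally windowed)}
```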
   




[GitHub] [beam] aaltay commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

aaltay commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1125447394

   @Abacn - could you please review this change?




[GitHub] [beam] lukecwik commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872889905


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -223,27 +323,26 @@ public void put(
       throws Exception {
     // Ignore timestamp for grouping purposes.
     // The Pre-combine output will inherit the timestamp of one of its inputs.
-    WindowedValue<Object> groupingKey =
-        WindowedValue.of(
-            keyCoder.structuralValue(value.getValue().getKey()),
-            IGNORED,
-            value.getWindows(),
-            value.getPane());
-
-    GroupingTableEntry entry =
-        lruMap.compute(
-            groupingKey,
-            (key, tableEntry) -> {
-              if (tableEntry == null) {
-                tableEntry =
-                    new GroupingTableEntry(
-                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
-              } else {
-                tableEntry.add(value.getValue().getValue());
-              }
-              return tableEntry;
-            });
-    weight += entry.getWeight();
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();

Review Comment:
   There are two cases. 
   * key == structural key, then:
     * GroupingTableKey weight = key weight + windows weight + pane info weight
     * GroupingTableEntry weight = reference weight + accumulator weight
   * key != structural key, then:
     * GroupingTableKey weight = structural key weight + windows weight + pane info weight
     * GroupingTableEntry weight = key weight + accumulator weight
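The two cases above can be sketched as plain arithmetic. This is a minimal sketch under stated assumptions: `REFERENCE_SIZE = 8` is a placeholder rather than the value of Beam's `Caches.REFERENCE_SIZE`, the component weights are made-up inputs instead of `SizeEstimator` results, and the class and method names are hypothetical. The point it shows is that when the user key is its own structural key (for example, a key whose coder is consistent with equals, so `Coder.structuralValue` returns the key itself), the key bytes are charged once in the table key and the entry only pays for a reference.

```java
/** Hypothetical sketch of the two weight-accounting cases for a grouping table entry. */
public class GroupingWeightSketch {
  // Assumed per-reference overhead in bytes; a placeholder, not Beam's actual constant.
  static final long REFERENCE_SIZE = 8;

  /** Weight charged for the table key: (structural) key + windows + pane info. */
  static long keyTableWeight(long structuralKeyWeight, long windowsWeight, long paneWeight) {
    return structuralKeyWeight + windowsWeight + paneWeight;
  }

  /**
   * Weight charged for the entry. If the user key is the structural key, its bytes are already
   * counted in the table key, so the entry only pays for a reference to it.
   */
  static long entryWeight(boolean keyIsStructuralKey, long userKeyWeight, long accumulatorWeight) {
    long keyPart = keyIsStructuralKey ? REFERENCE_SIZE : userKeyWeight;
    return keyPart + accumulatorWeight;
  }

  public static void main(String[] args) {
    // Case 1: key == structural key, so the 40-byte key is counted once.
    long same = keyTableWeight(40, 16, 8) + entryWeight(true, 40, 24);
    // Case 2: key != structural key, so a 32-byte structural form plus the 40-byte user key.
    long different = keyTableWeight(32, 16, 8) + entryWeight(false, 40, 24);
    System.out.println(same + " " + different); // prints "96 120"
  }
}
```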





[GitHub] [beam] lukecwik commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872872640


##########
sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java:
##########
@@ -46,79 +60,207 @@
 @RunWith(JUnit4.class)
 public class PrecombineGroupingTableTest {
 
-  private static class TestOutputReceiver implements Receiver {
-    final List<Object> outputElems = new ArrayList<>();
+  @Rule
+  public TestExecutorService executorService = TestExecutors.from(Executors.newCachedThreadPool());
+
+  private static class TestOutputReceiver<T> implements FnDataReceiver<T> {
+    final List<T> outputElems = new ArrayList<>();
 
     @Override
-    public void process(Object elem) {
+    public void accept(T elem) {
       outputElems.add(elem);
     }
   }
 
-  @Test
-  public void testCombiningGroupingTable() throws Exception {
-    Combiner<Object, Integer, Long, Long> summingCombineFn =
-        new Combiner<Object, Integer, Long, Long>() {
+  private static final CombineFn<Integer, Long, Long> COMBINE_FN =
+      new CombineFn<Integer, Long, Long>() {
 
-          @Override
-          public Long createAccumulator(Object key) {
-            return 0L;
-          }
+        @Override
+        public Long createAccumulator() {
+          return 0L;
+        }
 
-          @Override
-          public Long add(Object key, Long accumulator, Integer value) {
-            return accumulator + value;
-          }
+        @Override
+        public Long addInput(Long accumulator, Integer value) {
+          return accumulator + value;
+        }
 
-          @Override
-          public Long merge(Object key, Iterable<Long> accumulators) {
-            long sum = 0;
-            for (Long part : accumulators) {
-              sum += part;
-            }
-            return sum;
+        @Override
+        public Long mergeAccumulators(Iterable<Long> accumulators) {
+          long sum = 0;
+          for (Long part : accumulators) {
+            sum += part;
           }
+          return sum;
+        }
 
-          @Override
-          public Long compact(Object key, Long accumulator) {
-            return accumulator;
+        @Override
+        public Long compact(Long accumulator) {
+          if (accumulator % 2 == 0) {
+            return accumulator / 4;
           }
+          return accumulator;
+        }
 
-          @Override
-          public Long extract(Object key, Long accumulator) {
-            return accumulator;
-          }
-        };
+        @Override
+        public Long extractOutput(Long accumulator) {
+          return accumulator;
+        }
+      };
 
+  @Test
+  public void testCombiningGroupingTableEvictsAllOnLargeEntry() throws Exception {
     PrecombineGroupingTable<String, Integer, Long> table =
         new PrecombineGroupingTable<>(
-            100_000_000L,
-            new IdentityGroupingKeyCreator(),
-            new KvPairInfo(),
-            summingCombineFn,
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
             new StringPowerSizeEstimator(),
             new IdentitySizeEstimator());
-    table.setMaxSize(1000);
 
-    TestOutputReceiver receiver = new TestOutputReceiver();
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
 
-    table.put("A", 1, receiver);
-    table.put("B", 2, receiver);
-    table.put("B", 3, receiver);
-    table.put("C", 4, receiver);
+    table.put(valueInGlobalWindow(KV.of("A", 1)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 3)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 6)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 7)), receiver);
     assertThat(receiver.outputElems, empty());
 
-    table.put("C", 5000, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("C", 5004L)));
+    // Add beyond the size which causes compaction which still leads to evicting all since the
+    // largest is most recent.
+    table.put(valueInGlobalWindow(KV.of("C", 9999)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 9L)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 3L + 6)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTableCompactionSaves() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
+
+    // Insert three compactable values which shouldn't lead to eviction even though we are over
+    // the maximum size.
+    table.put(valueInGlobalWindow(KV.of("A", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 1004)), receiver);
+    assertThat(receiver.outputElems, empty());
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1004L / 4)),
+            valueInGlobalWindow(KV.of("B", 1004L / 4)),
+            valueInGlobalWindow(KV.of("C", 1004L / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTablePartialEviction() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
 
-    table.put("DDDD", 6, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("DDDD", 6L)));
+    // Insert three values which even with compaction isn't enough so we evict D & E to get

Review Comment:
   Done
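The test comments in this diff imply an order of operations once the table goes over budget: compact the accumulators first, and evict least-recently-used entries only if compaction did not free enough weight. Below is a hypothetical sketch of that policy, not the PR's implementation. It assumes `long` sums as accumulators, a caller-supplied `compact` function, and a toy weight in which each accumulator weighs its own value (the real table also charges keys, windows, and pane info). With these assumptions, replaying the inputs of `testCombiningGroupingTableEvictsAllOnLargeEntry` with the test's compaction rule evicts `A=1`, `B=9`, and `C=2501`, the values that test expects.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongUnaryOperator;

/** Hypothetical sketch: compact first, then evict least-recently-used entries if still over budget. */
public class CompactThenEvict {
  private final long maxWeight;
  private final LongUnaryOperator compact;
  // Access-ordered, so iteration starts at the least recently used key.
  private final LinkedHashMap<String, Long> table = new LinkedHashMap<>(16, 0.75f, true);

  public CompactThenEvict(long maxWeight, LongUnaryOperator compact) {
    this.maxWeight = maxWeight;
    this.compact = compact;
  }

  // Toy weight: every accumulator weighs its own value (a stand-in for real size estimation).
  private long weight() {
    long total = 0;
    for (long acc : table.values()) {
      total += acc;
    }
    return total;
  }

  /** Adds input to the key's running sum and returns any entries evicted as output. */
  public List<Map.Entry<String, Long>> put(String key, long input) {
    table.merge(key, input, Long::sum); // merging into an existing key also marks it recently used
    List<Map.Entry<String, Long>> evicted = new ArrayList<>();
    if (weight() <= maxWeight) {
      return evicted;
    }
    // Step 1: compact every accumulator in place; this does not change the LRU order.
    table.replaceAll((k, acc) -> compact.applyAsLong(acc));
    // Step 2: if compaction was not enough, evict from the least recently used end.
    Iterator<Map.Entry<String, Long>> it = table.entrySet().iterator();
    while (weight() > maxWeight && it.hasNext()) {
      Map.Entry<String, Long> eldest = it.next();
      evicted.add(new AbstractMap.SimpleImmutableEntry<>(eldest.getKey(), eldest.getValue()));
      it.remove();
    }
    return evicted;
  }

  public static void main(String[] args) {
    // Same compaction rule as the test's COMBINE_FN: even accumulators are divided by 4.
    CompactThenEvict table = new CompactThenEvict(2500, acc -> acc % 2 == 0 ? acc / 4 : acc);
    table.put("A", 1);
    table.put("B", 3);
    table.put("B", 6);
    table.put("C", 7);
    System.out.println(table.put("C", 9999)); // prints [A=1, B=9, C=2501]
  }
}
```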



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe

Review Comment:
   Documented that `put` and `flush` must be called from the bundle processing thread. `shrink` can be called from any thread. 
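One way to read that contract, sketched under assumptions: state touched only by `put`/`flush` can stay a plain field because it is confined to the bundle processing thread, while the budget that `shrink` lowers lives in an `AtomicLong` so any thread can update it safely. The halving via `updateAndGet(w -> w >> 1)` and the floor of 100 mirror the `shrink()` shown in the diff below. That method returns `null` rather than `false` to signal that shrinking should stop. The class `ShrinkableBudget` and its methods are hypothetical.

```java
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical sketch of the threading contract: put on one thread, shrink from any thread. */
public class ShrinkableBudget {
  // Shared with other threads through shrink(), so it must be atomic.
  private final AtomicLong maxWeight;
  // Only touched from the bundle processing thread, so a plain field is sufficient.
  private long weight = 0;

  public ShrinkableBudget(long initialMaxWeight) {
    this.maxWeight = new AtomicLong(initialMaxWeight);
  }

  /** May be called from any thread: halves the budget and returns false once it is too small. */
  public boolean shrink() {
    long current = maxWeight.updateAndGet(w -> w >> 1);
    return current > 100L;
  }

  /** Must only be called from the bundle processing thread; returns true if over budget. */
  public boolean put(long entryWeight) {
    weight += entryWeight;
    return weight > maxWeight.get();
  }

  public long getMaxWeight() {
    return maxWeight.get();
  }

  public static void main(String[] args) throws InterruptedException {
    ShrinkableBudget budget = new ShrinkableBudget(1600);
    // Simulates a cache on another thread asking for memory back.
    Thread shrinker = new Thread(budget::shrink);
    shrinker.start();
    shrinker.join(); // join() makes the shrunk budget visible to this thread
    System.out.println(budget.getMaxWeight() + " " + budget.put(900)); // prints "800 true"
  }
}
```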



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();
+    // Increase the maximum only if we require it
+    maxWeight.accumulateAndGet(weight, (current, update) -> current < update ? update : current);
+
+    // Update the cache to ensure that LRU is handled appropriately and for the cache to have an
+    // opportunity to shrink the maxWeight if necessary.
+    cache.put(Key.INSTANCE, this);
+
+    // Get the updated weight now that the cache may have been shrunk and respect it
+    long currentMax = maxWeight.get();
+    if (weight > currentMax) {

Review Comment:
   Because we want to make sure that output is only produced from the bundle processing thread, and not from whichever arbitrary thread caused the shrinking to happen. Added a comment to reflect this.
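   The reasoning above can be illustrated with a minimal, hypothetical sketch. It is not the PR's code: `DeferredFlushTable`, its one-unit-per-entry weight, and the running-sum "accumulator" are assumptions made for illustration. The cache may call `shrink()` from any thread, but `shrink()` only lowers an `AtomicLong` budget. Entries are evicted in least-recently-used order from an access-ordered `LinkedHashMap`, and they are handed to the receiver only inside `put()`, which runs on the thread that owns the table.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;

// Hypothetical sketch, not Beam's PrecombineGroupingTable: each entry weighs one unit and the
// "accumulator" is a running sum of longs.
class DeferredFlushTable {
  // accessOrder = true: iteration goes from least- to most-recently accessed entry.
  private final LinkedHashMap<String, Long> lruMap = new LinkedHashMap<>(16, 0.75f, true);
  // The only state touched by other threads, so it is the only field that needs to be atomic.
  private final AtomicLong maxWeight;

  DeferredFlushTable(long initialMaxWeight) {
    this.maxWeight = new AtomicLong(initialMaxWeight);
  }

  // May be called by the cache from any thread: halves the budget but never emits output.
  void shrink() {
    maxWeight.updateAndGet(w -> w >> 1);
  }

  // Called only from the owning (bundle processing) thread, so output is produced only here.
  void put(String key, long value, BiConsumer<String, Long> receiver) {
    // Combine into the existing entry; updating an existing key also marks it most recently used.
    lruMap.merge(key, value, Long::sum);
    // Re-read the budget, since another thread may have shrunk it since the last put.
    long budget = maxWeight.get();
    Iterator<Map.Entry<String, Long>> eldestFirst = lruMap.entrySet().iterator();
    while (lruMap.size() > budget && eldestFirst.hasNext()) {
      Map.Entry<String, Long> eldest = eldestFirst.next();
      receiver.accept(eldest.getKey(), eldest.getValue());
      eldestFirst.remove();
    }
  }

  int size() {
    return lruMap.size();
  }
}
```

   Because `shrink()` in this sketch never touches the map or the receiver, the budget is the only cross-thread state, which is why an `AtomicLong` suffices while the map itself stays unsynchronized (in the same spirit as the `@NotThreadSafe` annotation in the diff).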



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();

Review Comment:
   Fixed, and updated the tests, since it turned out we weren't accounting for the grouping table key.
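   Accounting for the key in each entry's weight can be sketched as follows. This is a hypothetical, simplified illustration rather than the PR's code: `WeightedTable`, the assumed 8-byte `REFERENCE_SIZE` and `VALUE_SIZE`, and the list-backed accumulator are all assumptions. It applies the key-size rule visible in the diff's `GroupingTableEntry` constructor (two references when the structural key is the user key object itself, otherwise one reference plus the estimated key size) and keeps a running total by removing an entry's previous weight before adding its new one.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToLongFunction;

// Hypothetical sketch of grouping-table weight accounting; the sizes below are assumptions.
class WeightedTable {
  static final long REFERENCE_SIZE = 8; // assumed size of one object reference, in bytes
  static final long VALUE_SIZE = 8;     // assumed size of one buffered value, in bytes

  private static final class Entry {
    final long keySize;
    // Stand-in for an accumulator that grows with its input (e.g. a "to list" combine).
    final List<Long> accumulator = new ArrayList<>();

    Entry(long keySize) {
      this.keySize = keySize;
    }

    long weight() {
      return keySize + VALUE_SIZE * accumulator.size();
    }
  }

  private final Map<Object, Entry> table = new HashMap<>();
  private final ToLongFunction<String> keySizer;
  private long weight;

  WeightedTable(ToLongFunction<String> keySizer) {
    this.keySizer = keySizer;
  }

  void put(Object structuralKey, String userKey, long value) {
    Entry entry = table.get(structuralKey);
    if (entry == null) {
      // Same rule as the diff: if the structural key is the user key object itself, count two
      // references; otherwise count one reference plus the user key's estimated size.
      long keySize =
          structuralKey == userKey
              ? 2 * REFERENCE_SIZE
              : REFERENCE_SIZE + keySizer.applyAsLong(userKey);
      entry = new Entry(keySize);
      table.put(structuralKey, entry);
    } else {
      // Drop the entry's previous contribution so the running total is not double counted.
      weight -= entry.weight();
    }
    entry.accumulator.add(value);
    weight += entry.weight();
  }

  long weight() {
    return weight;
  }
}
```

   For example, with `String::length` as the key sizer, two puts for the same key object leave the total at 32 bytes (16 for the key, 16 for two values), whereas simply adding each entry's full weight after every put would report 56.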



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

Posted by GitBox <gi...@apache.org>.
lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1098315009

   Run Java PreCommit



[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1110348325

   Ping @youngoli 




[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1127031951

   Run Java PreCommit




[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1093705714

   R: @youngoli 




[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1097359892

   Run Python_PVR_Flink PreCommit




[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1097359751

   Run Java PreCommit




[GitHub] [beam] robertwb commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

robertwb commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1137863948

   I should note that before either change I was getting on the order of 15k ops/sec. 




[GitHub] [beam] lukecwik commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

lukecwik commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872882433


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -223,27 +323,26 @@ public void put(
       throws Exception {
     // Ignore timestamp for grouping purposes.
     // The Pre-combine output will inherit the timestamp of one of its inputs.
-    WindowedValue<Object> groupingKey =
-        WindowedValue.of(
-            keyCoder.structuralValue(value.getValue().getKey()),
-            IGNORED,
-            value.getWindows(),
-            value.getPane());
-
-    GroupingTableEntry entry =
-        lruMap.compute(
-            groupingKey,
-            (key, tableEntry) -> {
-              if (tableEntry == null) {
-                tableEntry =
-                    new GroupingTableEntry(
-                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
-              } else {
-                tableEntry.add(value.getValue().getValue());
-              }
-              return tableEntry;
-            });
-    weight += entry.getWeight();
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();

Review Comment:
   This adds the weight of the key only, not the value.
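The accounting described here can be sketched in plain Java: count the key's weight once when an entry is created, and track the accumulator's weight as a delta on every update. The sketch also uses an access-ordered `LinkedHashMap`, like the `new LinkedHashMap<>(16, 0.75f, true)` shown in the diff, to flush least recently used entries. This is a minimal sketch under stated assumptions, not the PR's code. `WeightedLruTable`, its fixed 8-byte accumulator weight, and its one-unit-per-character key weight are hypothetical stand-ins for the `SizeEstimator`s in the diff.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.BiConsumer;

/**
 * Hypothetical sketch, not Beam code: a weight-bounded LRU table that sums long values per
 * String key. Weights are illustrative numbers, not real byte estimates.
 */
final class WeightedLruTable {
  private static final class Entry {
    final long keyWeight;
    long accumulator;
    long accumulatorWeight;

    Entry(long keyWeight) {
      this.keyWeight = keyWeight;
    }
  }

  // accessOrder = true: iteration visits the least recently accessed entry first.
  private final LinkedHashMap<String, Entry> lruMap = new LinkedHashMap<>(16, 0.75f, true);
  private final long maxWeight;
  private long weight;

  WeightedLruTable(long maxWeight) {
    this.maxWeight = maxWeight;
  }

  long getWeight() {
    return weight;
  }

  void put(String key, long value, BiConsumer<String, Long> receiver) {
    Entry entry = lruMap.get(key); // a hit also marks the entry as most recently used
    if (entry == null) {
      // New entry: count the key's weight exactly once.
      entry = new Entry(key.length()); // hypothetical key weight: one unit per char
      lruMap.put(key, entry);
      weight += entry.keyWeight;
    }
    // New or existing entry: swap the old accumulator weight for the new one.
    weight -= entry.accumulatorWeight;
    entry.accumulator += value;
    entry.accumulatorWeight = Long.BYTES; // hypothetical fixed 8-byte accumulator
    weight += entry.accumulatorWeight;

    // Flush least recently used entries until the table is back within budget.
    Iterator<Map.Entry<String, Entry>> it = lruMap.entrySet().iterator();
    while (weight > maxWeight && it.hasNext()) {
      Map.Entry<String, Entry> eldest = it.next();
      Entry evicted = eldest.getValue();
      receiver.accept(eldest.getKey(), evicted.accumulator);
      weight -= evicted.keyWeight + evicted.accumulatorWeight;
      it.remove();
    }
  }
}
```

The newest entry always sits at the tail of the map. If that single entry is heavier than `maxWeight`, the loop flushes everything older and then the new entry itself.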





[GitHub] [beam] y1chi commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

y1chi commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r871975236


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe

Review Comment:
   Can you document why? It also seems to contradict the requirements of Shrinkable?
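In the diff, `shrink()` halves an `AtomicLong maxWeight` and returns `null` once the budget reaches 100 or less, while `weight` and `lruMap` are plain fields. The sketch below isolates just that policy. It assumes, without confirmation from this thread, that the cache may call `shrink()` from a thread other than the one calling `put()`, which is the tension the question raises. `ShrinkableBudget` is a hypothetical name, not a Beam class.

```java
import java.util.concurrent.atomic.AtomicLong;

/**
 * Hypothetical, simplified version of the shrink() policy in the diff: each call halves the
 * weight budget, and once the budget drops to 100 or below the object asks to be evicted.
 */
final class ShrinkableBudget {
  // Atomic so that a shrink() from another thread (an assumption) is safely visible.
  private final AtomicLong maxWeight;

  ShrinkableBudget(long initialMaxWeight) {
    this.maxWeight = new AtomicLong(initialMaxWeight);
  }

  /** Returns this to stay cached with a halved budget, or null to be dropped. */
  ShrinkableBudget shrink() {
    long current = maxWeight.updateAndGet(w -> w >> 1);
    return current <= 100L ? null : this;
  }

  long getWeight() {
    return maxWeight.get();
  }
}
```

Only the budget is atomic here. Any state that `put()` mutates, such as the running weight or the map, would still need its own synchronization if `shrink()` could touch it concurrently.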



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();

Review Comment:
   Is this accurate if the entry is not new?
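The concern can be made concrete. If `weight += entry.getWeight()` runs on every `put`, an entry that is updated several times has its full weight added on each update. The minimal sketch below compares that pattern with delta accounting, where the entry's previous contribution is subtracted before its new one is added. It assumes one entry with a constant key and accumulator weight. `WeightAccountingDemo` is a hypothetical name.

```java
/** Hypothetical demo: running table weight for one entry that is updated repeatedly. */
final class WeightAccountingDemo {
  /** Naive: adds the entry's full current weight on every update. */
  static long naive(long keyWeight, long accumulatorWeight, int updates) {
    long weight = 0;
    for (int i = 0; i < updates; i++) {
      weight += keyWeight + accumulatorWeight; // like entry.getWeight() after each put
    }
    return weight;
  }

  /** Delta-based: replaces the entry's previous contribution with its current one. */
  static long delta(long keyWeight, long accumulatorWeight, int updates) {
    long weight = 0;
    long previous = 0; // what this entry contributed before the update (0 when new)
    for (int i = 0; i < updates; i++) {
      long current = keyWeight + accumulatorWeight;
      weight += current - previous;
      previous = current;
    }
    return weight;
  }
}
```

With a key weight of 10, an accumulator weight of 8, and three updates, `naive` reports 54 while `delta` reports 18, the entry's actual footprint.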



##########
sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java:
##########
@@ -46,79 +60,207 @@
 @RunWith(JUnit4.class)
 public class PrecombineGroupingTableTest {
 
-  private static class TestOutputReceiver implements Receiver {
-    final List<Object> outputElems = new ArrayList<>();
+  @Rule
+  public TestExecutorService executorService = TestExecutors.from(Executors.newCachedThreadPool());
+
+  private static class TestOutputReceiver<T> implements FnDataReceiver<T> {
+    final List<T> outputElems = new ArrayList<>();
 
     @Override
-    public void process(Object elem) {
+    public void accept(T elem) {
       outputElems.add(elem);
     }
   }
 
-  @Test
-  public void testCombiningGroupingTable() throws Exception {
-    Combiner<Object, Integer, Long, Long> summingCombineFn =
-        new Combiner<Object, Integer, Long, Long>() {
+  private static final CombineFn<Integer, Long, Long> COMBINE_FN =
+      new CombineFn<Integer, Long, Long>() {
 
-          @Override
-          public Long createAccumulator(Object key) {
-            return 0L;
-          }
+        @Override
+        public Long createAccumulator() {
+          return 0L;
+        }
 
-          @Override
-          public Long add(Object key, Long accumulator, Integer value) {
-            return accumulator + value;
-          }
+        @Override
+        public Long addInput(Long accumulator, Integer value) {
+          return accumulator + value;
+        }
 
-          @Override
-          public Long merge(Object key, Iterable<Long> accumulators) {
-            long sum = 0;
-            for (Long part : accumulators) {
-              sum += part;
-            }
-            return sum;
+        @Override
+        public Long mergeAccumulators(Iterable<Long> accumulators) {
+          long sum = 0;
+          for (Long part : accumulators) {
+            sum += part;
           }
+          return sum;
+        }
 
-          @Override
-          public Long compact(Object key, Long accumulator) {
-            return accumulator;
+        @Override
+        public Long compact(Long accumulator) {
+          if (accumulator % 2 == 0) {
+            return accumulator / 4;
           }
+          return accumulator;
+        }
 
-          @Override
-          public Long extract(Object key, Long accumulator) {
-            return accumulator;
-          }
-        };
+        @Override
+        public Long extractOutput(Long accumulator) {
+          return accumulator;
+        }
+      };
 
+  @Test
+  public void testCombiningGroupingTableEvictsAllOnLargeEntry() throws Exception {
     PrecombineGroupingTable<String, Integer, Long> table =
         new PrecombineGroupingTable<>(
-            100_000_000L,
-            new IdentityGroupingKeyCreator(),
-            new KvPairInfo(),
-            summingCombineFn,
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
             new StringPowerSizeEstimator(),
             new IdentitySizeEstimator());
-    table.setMaxSize(1000);
 
-    TestOutputReceiver receiver = new TestOutputReceiver();
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
 
-    table.put("A", 1, receiver);
-    table.put("B", 2, receiver);
-    table.put("B", 3, receiver);
-    table.put("C", 4, receiver);
+    table.put(valueInGlobalWindow(KV.of("A", 1)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 3)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 6)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 7)), receiver);
     assertThat(receiver.outputElems, empty());
 
-    table.put("C", 5000, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("C", 5004L)));
+    // Add beyond the size which causes compaction which still leads to evicting all since the
+    // largest is most recent.
+    table.put(valueInGlobalWindow(KV.of("C", 9999)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 9L)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 3L + 6)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTableCompactionSaves() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
+
+    // Insert three compactable values which shouldn't lead to eviction even though we are over
+    // the maximum size.
+    table.put(valueInGlobalWindow(KV.of("A", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 1004)), receiver);
+    assertThat(receiver.outputElems, empty());
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1004L / 4)),
+            valueInGlobalWindow(KV.of("B", 1004L / 4)),
+            valueInGlobalWindow(KV.of("C", 1004L / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTablePartialEviction() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
 
-    table.put("DDDD", 6, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("DDDD", 6L)));
+    // Insert three values which even with compaction isn't enough so we evict D & E to get

Review Comment:
   s/D & E/A & B/
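   The expected values asserted in these tests follow from COMBINE_FN's arithmetic. addInput keeps a running sum, and the test's compact() divides even accumulators by 4 (with integer division) and leaves odd ones alone. That is deliberately lossy, unlike a production compact(), which should preserve the accumulator's logical value. The class below is a hypothetical standalone replica of that arithmetic, not Beam API. It reproduces the numbers used in the assertions:

```java
import java.util.List;

/** Hypothetical replica of the test COMBINE_FN's arithmetic; not part of Beam. */
public class CombineFnArithmetic {
  // Mirrors addInput: the accumulator is a running sum starting from createAccumulator() = 0.
  public static long sum(List<Integer> inputs) {
    long acc = 0L;
    for (int v : inputs) {
      acc += v;
    }
    return acc;
  }

  // Mirrors the test's compact(): even accumulators are divided by 4 (integer division),
  // odd accumulators are returned unchanged.
  public static long compact(long acc) {
    return acc % 2 == 0 ? acc / 4 : acc;
  }

  public static void main(String[] args) {
    // Key "B" in testCombiningGroupingTableEvictsAllOnLargeEntry: 3 + 6 = 9, odd, so unchanged.
    System.out.println(compact(sum(List.of(3, 6))));
    // Key "C": 7 + 9999 = 10006, even, so compaction yields 10006 / 4 = 2501.
    System.out.println(compact(sum(List.of(7, 9999))));
    // testCombiningGroupingTableCompactionSaves: 1004 is even and compacts to 251.
    System.out.println(compact(sum(List.of(1004))));
  }
}
```

   Running main prints 9, 2501 and 251, matching KV.of("B", 9L), KV.of("C", (9999L + 7) / 4) and KV.of("A", 1004L / 4) in the tests.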



##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,283 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/** Static utility methods that provide a grouping table implementation. */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
-  }
-
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
-
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
-
-    private final Coder<K> coder;
-
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
     }
+    return this;
+  }
 
-    @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
-    }
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
   }
 
   /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
   public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+    long estimateSize(T element);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
-      }
-    }
-
-    final Coder<T> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<WindowedValue<Object>, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
-    }
-
-    private WindowedPairInfo() {}
-
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
-    }
-  }
-
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
-
-    OutputT extract(K key, AccumT accumulator);
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
-    }
-
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+  private class GroupingTableEntry implements Weighted {
+    private final WindowedValue<Object> groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(
+        WindowedValue<Object> groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getValue() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
-    @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public WindowedValue<Object> getGroupingKey() {
+      return groupingKey;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
-  }
-
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
 
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
         accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
       }
-    };
-  }
+    }
 
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
   }
 
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    WindowedValue<Object> groupingKey =
+        WindowedValue.of(
+            keyCoder.structuralValue(value.getValue().getKey()),
+            IGNORED,
+            value.getWindows(),
+            value.getPane());
+
+    GroupingTableEntry entry =
+        lruMap.compute(
+            groupingKey,
+            (key, tableEntry) -> {
+              if (tableEntry == null) {
+                tableEntry =
+                    new GroupingTableEntry(
+                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
+              } else {
+                tableEntry.add(value.getValue().getValue());
+              }
+              return tableEntry;
+            });
+    weight += entry.getWeight();
+    // Increase the maximum only if we require it
+    maxWeight.accumulateAndGet(weight, (current, update) -> current < update ? update : current);
+
+    // Update the cache to ensure that LRU is handled appropriately and for the cache to have an
+    // opportunity to shrink the maxWeight if necessary.
+    cache.put(Key.INSTANCE, this);
+
+    // Get the updated weight now that the cache may have been shrunk and respect it
+    long currentMax = maxWeight.get();
+    if (weight > currentMax) {

Review Comment:
   If this is triggered by shrink(), why not do it in shrink() instead of relying on new input?
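   The eviction flow under discussion (an access-ordered LinkedHashMap, a running weight, and a shrink() that halves an AtomicLong budget) can be modeled in isolation. The sketch below is a hypothetical simplification, not Beam's PrecombineGroupingTable: keys are strings, values are summed longs, and every entry is assumed to weigh a fixed 100 units. It keeps the split used in the diff, where shrink() only lowers the budget and the next put() does the evicting. One plausible motivation, going by the later revision of the class Javadoc, is that shrink() may be called from any thread while put() and flush() must run on the bundle processing thread, so emitting output from shrink() would cross threads.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical weight-bounded LRU table of per-key sums; a model, not Beam code. */
public class LruSumTable {
  private static final long ENTRY_WEIGHT = 100L; // assumed fixed cost per entry

  // accessOrder = true: iteration runs from least to most recently accessed.
  private final LinkedHashMap<String, Long> lru = new LinkedHashMap<>(16, 0.75f, true);
  // shrink() may run on another thread, so only the budget is atomic; the map and the
  // running weight are touched solely by the thread that calls put().
  private final AtomicLong maxWeight;
  private long weight;

  public LruSumTable(long maxWeight) {
    this.maxWeight = new AtomicLong(maxWeight);
  }

  /** Halves the budget without evicting; the next put() enforces it. */
  public void shrink() {
    maxWeight.updateAndGet(w -> w >> 1);
  }

  /** Adds value to key's sum, then evicts LRU entries until the weight fits the budget. */
  public List<Map.Entry<String, Long>> put(String key, long value) {
    Long old = lru.get(key); // get() also marks the key as most recently used
    if (old == null) {
      weight += ENTRY_WEIGHT;
      lru.put(key, value);
    } else {
      lru.put(key, old + value);
    }
    List<Map.Entry<String, Long>> evicted = new ArrayList<>();
    long budget = maxWeight.get();
    Iterator<Map.Entry<String, Long>> it = lru.entrySet().iterator();
    while (weight > budget && it.hasNext()) {
      Map.Entry<String, Long> eldest = it.next();
      evicted.add(Map.entry(eldest.getKey(), eldest.getValue()));
      it.remove();
      weight -= ENTRY_WEIGHT;
    }
    return evicted;
  }

  public int size() {
    return lru.size();
  }
}
```

   With a budget of 300, three distinct puts fit without eviction. After shrink() halves the budget to 150, a put("a", 4) moves "a" to the most recently used position and then evicts "b" and "c", returning them in that order.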



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] y1chi commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

Posted by GitBox <gi...@apache.org>.
y1chi commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872883803


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -223,27 +323,26 @@ public void put(
       throws Exception {
     // Ignore timestamp for grouping purposes.
     // The Pre-combine output will inherit the timestamp of one of its inputs.
-    WindowedValue<Object> groupingKey =
-        WindowedValue.of(
-            keyCoder.structuralValue(value.getValue().getKey()),
-            IGNORED,
-            value.getWindows(),
-            value.getPane());
-
-    GroupingTableEntry entry =
-        lruMap.compute(
-            groupingKey,
-            (key, tableEntry) -> {
-              if (tableEntry == null) {
-                tableEntry =
-                    new GroupingTableEntry(
-                        groupingKey, value.getValue().getKey(), value.getValue().getValue());
-              } else {
-                tableEntry.add(value.getValue().getValue());
-              }
-              return tableEntry;
-            });
-    weight += entry.getWeight();
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();

Review Comment:
   isn't entry.getWeight() = key.getWeight() + accumulator.getWeight()?
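   The question turns on incremental weight accounting. If an entry weighs keyWeight + accumulatorWeight, adding the entry's full weight on every put (as `weight += entry.getWeight()` does in the earlier revision) counts the same key again on each update. One way to keep a running total consistent is to count the key once when the entry is created and afterwards apply only the change in accumulator weight. The class below is a hypothetical illustration of that bookkeeping, not Beam code; the weights are supplied by the caller:

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical delta-based weight accounting; weights are caller-supplied, not estimated. */
public class WeightAccounting {
  private static final class Entry {
    final long keyWeight;
    long accumulatorWeight;

    Entry(long keyWeight, long accumulatorWeight) {
      this.keyWeight = keyWeight;
      this.accumulatorWeight = accumulatorWeight;
    }

    long weight() {
      return keyWeight + accumulatorWeight;
    }
  }

  private final Map<String, Entry> entries = new HashMap<>();
  private long totalWeight;

  /** Records that key now holds an accumulator of the given weight. */
  public void update(String key, long keyWeight, long newAccumulatorWeight) {
    Entry e = entries.get(key);
    if (e == null) {
      // New entry: count the key and its accumulator once each.
      e = new Entry(keyWeight, newAccumulatorWeight);
      entries.put(key, e);
      totalWeight += e.weight();
    } else {
      // Existing entry: the key is already counted, so apply only the accumulator's change.
      totalWeight += newAccumulatorWeight - e.accumulatorWeight;
      e.accumulatorWeight = newAccumulatorWeight;
    }
  }

  public long totalWeight() {
    return totalWeight;
  }

  /** Recomputes the total from scratch, for checking the incremental value. */
  public long recomputedWeight() {
    long sum = 0;
    for (Entry e : entries.values()) {
      sum += e.weight();
    }
    return sum;
  }
}
```

   For example, update("k", 40, 8) followed by update("k", 40, 16) leaves a total of 56, the same as recomputing from scratch; adding the full entry weight on each call would instead report 48 + 56 = 104.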





[GitHub] [beam] youngoli commented on a diff in pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

Posted by GitBox <gi...@apache.org>.
youngoli commented on code in PR #17327:
URL: https://github.com/apache/beam/pull/17327#discussion_r872913431


##########
sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java:
##########
@@ -17,434 +17,392 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
+import java.util.Collection;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
+import java.util.Objects;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/**
+ * Static utility methods that provide a grouping table implementation.
+ *
+ * <p>{@link NotThreadSafe} because the caller must use the bundle processing thread when invoking
+ * {@link #put} and {@link #flush}. {@link #shrink} may be called from any thread.
+ */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
+    }
+    return this;
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
+  }
 
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
+  /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
+  public interface SizeEstimator<T> {
+    long estimateSize(T element);
+  }
 
-    private final Coder<K> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<GroupingTableKey, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /** Provides client-specific operations for size estimates. */
-  public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
+  private static class GroupingTableKey implements Weighted {
+    private final Object structuralKey;
+    private final Collection<? extends BoundedWindow> windows;
+    private final PaneInfo paneInfo;
+    private final long weight;
+
+    <K> GroupingTableKey(
+        K key,
+        Collection<? extends BoundedWindow> windows,
+        PaneInfo paneInfo,
+        Coder<K> keyCoder,
+        SizeEstimator<K> keySizer) {
+      this.structuralKey = keyCoder.structuralValue(key);
+      this.windows = windows;
+      this.paneInfo = paneInfo;
+      // We account for the weight of the key using the keySizer if the coder's structural value
+      // is the same as its value.
+      if (structuralKey == key) {
+        weight = keySizer.estimateSize(key) + Caches.weigh(windows) + Caches.weigh(paneInfo);
+      } else {
+        weight = Caches.weigh(this);
       }
     }
 
-    final Coder<T> coder;
-
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
+    public Object getStructuralKey() {
+      return structuralKey;
     }
 
-    @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public Collection<? extends BoundedWindow> getWindows() {
+      return windows;
     }
-  }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
+    public PaneInfo getPaneInfo() {
+      return paneInfo;
     }
 
-    private WindowedPairInfo() {}
+    @Override
+    public long getWeight() {
+      return weight;
+    }
 
     @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof GroupingTableKey)) {
+        return false;
+      }
+      GroupingTableKey that = (GroupingTableKey) o;
+      return Objects.equals(structuralKey, that.structuralKey)
+          && windows.equals(that.windows)
+          && paneInfo.equals(that.paneInfo);
     }
 
     @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
+    public int hashCode() {
+      return Objects.hash(structuralKey, windows, paneInfo);
     }
 
     @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
+    public String toString() {
+      return "GroupingTableKey{"
+          + "structuralKey="
+          + structuralKey
+          + ", windows="
+          + windows
+          + ", paneInfo="
+          + paneInfo
+          + ", weight="
+          + weight
+          + '}';
     }
   }
 
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
+  private class GroupingTableEntry implements Weighted {
+    private final GroupingTableKey groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(GroupingTableKey groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getStructuralKey() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
 
-    OutputT extract(K key, AccumT accumulator);
-  }
+    public GroupingTableKey getGroupingKey() {
+      return groupingKey;
+    }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
+        accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
+      }
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public String toString() {
+      return "GroupingTableEntry{"
+          + "groupingKey="
+          + groupingKey
+          + ", userKey="
+          + userKey
+          + ", keySize="
+          + keySize
+          + ", accumulatorSize="
+          + accumulatorSize
+          + ", accumulator="
+          + accumulator
+          + ", dirty="
+          + dirty
+          + '}';
     }
   }
 
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
-
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
-        accumulatorSize = accumulatorSizer.estimateSize(accumulator);
-      }
-    };
-  }
-
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
-  }
-
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();
+            tableEntry =
+                new GroupingTableEntry(
+                    groupingKey, value.getValue().getKey(), value.getValue().getValue());
+          } else {
+            weight -= tableEntry.getWeight();
+            tableEntry.add(value.getValue().getValue());
+          }
+          weight += tableEntry.getWeight();
+          return tableEntry;
+        });
+
+    // Increase the maximum only if we require it
+    maxWeight.accumulateAndGet(weight, (current, update) -> current < update ? update : current);
+
+    // Update the cache to ensure that LRU is handled appropriately and for the cache to have an
+    // opportunity to shrink the maxWeight if necessary.
+    cache.put(Key.INSTANCE, this);
+
+    // Get the updated weight now that the cache may have been shrunk and respect it
+    long currentMax = maxWeight.get();
+
+    // Only compact and output from the bundle processing thread that is inserting elements into the
+    // grouping table. This ensures that we honor the guarantee that transforms for a single bundle
+    // execute using the same thread.
+    if (weight > currentMax) {
+      // Try to compact as many of the values as possible and only flush values if compaction
+      // wasn't enough.
+      for (GroupingTableEntry valueToCompact : lruMap.values()) {
+        long currentWeight = valueToCompact.getWeight();
+        valueToCompact.compact();
+        weight += valueToCompact.getWeight() - currentWeight;
+      }
+
+      if (weight > currentMax) {
+        Iterator<GroupingTableEntry> iterator = lruMap.values().iterator();
+        while (iterator.hasNext()) {
+          GroupingTableEntry valueToFlush = iterator.next();
+          weight -= valueToFlush.getWeight() + valueToFlush.getGroupingKey().getWeight();

Review Comment:
   I'm having some trouble following all the different weights. My first instinct is that, since valueToFlush contains the GroupingKey, this would count the weight of the grouping key twice (presumably a problem, because it wasn't counted twice when it was originally added to the max weight).
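A minimal model of the weight bookkeeping in `put` may help with the question above. This is a hypothetical sketch: the `WeightLedger` class, its string keys, and the caller-supplied weights are illustrative assumptions, not Beam APIs. It mirrors two details visible in the diff. `GroupingTableEntry.getWeight()` returns only `keySize + accumulatorSize`, and the grouping key's weight is added separately, once, when a new entry is created. Under that reading, the flush loop subtracts the same two amounts that were added, so in this model the running weight returns to zero instead of going negative.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical model of the weight accounting in the diff's put() and flush loop. */
final class WeightLedger {
  /** Stand-in for GroupingTableEntry; its weight deliberately excludes the grouping key. */
  static final class Entry {
    final long groupingKeyWeight; // plays the role of groupingKey.getWeight()
    long entryWeight; // plays the role of keySize + accumulatorSize

    Entry(long groupingKeyWeight, long entryWeight) {
      this.groupingKeyWeight = groupingKeyWeight;
      this.entryWeight = entryWeight;
    }
  }

  // Access-ordered, like lruMap in the diff.
  private final Map<String, Entry> entries = new LinkedHashMap<>(16, 0.75f, true);
  private long weight;

  /** Mirrors the compute() lambda: the key weight is added once, only for a new entry. */
  void put(String key, long groupingKeyWeight, long updatedEntryWeight) {
    Entry entry = entries.get(key);
    if (entry == null) {
      weight += groupingKeyWeight;
      entry = new Entry(groupingKeyWeight, updatedEntryWeight);
      entries.put(key, entry);
    } else {
      weight -= entry.entryWeight; // remove the stale entry weight before updating it
      entry.entryWeight = updatedEntryWeight;
    }
    weight += entry.entryWeight;
  }

  /** Mirrors the flush loop: subtract the entry weight plus the key weight added on insert. */
  void flushAll() {
    for (Entry entry : entries.values()) {
      weight -= entry.entryWeight + entry.groupingKeyWeight;
    }
    entries.clear();
  }

  long weight() {
    return weight;
  }
}
```

For example, inserting key "a" twice and key "b" once, each with a grouping key weight of 40, leaves a running weight of 136; flushing everything brings it back to 0.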
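Two other policies in the diff can be sketched in isolation. `lruMap` is constructed as `new LinkedHashMap<>(16, 0.75f, true)`; the third argument enables access order, so iterating `lruMap.values()` visits the least recently accessed entries first. `shrink()` halves `maxWeight` and returns `null` once the budget falls to 100 or below. The `LruShrinkSketch` class below is a hypothetical illustration of both behaviors using only JDK classes, not the Beam implementation.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical, JDK-only sketch of the LRU ordering and the halving shrink policy. */
final class LruShrinkSketch {
  /** Builds a map the same way as lruMap (accessOrder = true) and reports its iteration order. */
  static List<String> iterationOrderAfterTouchingA() {
    LinkedHashMap<String, Integer> lru = new LinkedHashMap<>(16, 0.75f, true);
    lru.put("a", 1);
    lru.put("b", 2);
    lru.put("c", 3);
    lru.get("a"); // an access moves "a" to the most-recently-used end
    return new ArrayList<>(lru.keySet()); // least recently used first
  }

  /**
   * Mirrors shrink(): halve the budget, then report whether shrinking should continue.
   * Returning false corresponds to shrink() returning null in the diff.
   */
  static boolean shrink(AtomicLong maxWeight) {
    long current = maxWeight.updateAndGet(w -> w >> 1);
    return current > 100L;
  }
}
```

Starting from a budget of 1000, successive calls yield 500, 250 and 125 (still shrinking), then 62, at which point the sketch reports that shrinking should stop. After touching "a", the iteration order is b, c, a, which is the order the diff's flush loop would walk.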



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] lukecwik commented on pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

Posted by GitBox <gi...@apache.org>.
lukecwik commented on PR #17327:
URL: https://github.com/apache/beam/pull/17327#issuecomment-1126551772

   @y1chi PTAL


[GitHub] [beam] lukecwik merged pull request #17327: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy.

Posted by GitBox <gi...@apache.org>.
lukecwik merged PR #17327:
URL: https://github.com/apache/beam/pull/17327

