You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@beam.apache.org by lc...@apache.org on 2022/05/16 14:08:56 UTC

[beam] branch master updated: [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy. (#17327)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b81d140636 [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy. (#17327)
5b81d140636 is described below

commit 5b81d140636e3fa774610aeb8a8896d02696b707
Author: Luke Cwik <lc...@google.com>
AuthorDate: Mon May 16 07:08:44 2022 -0700

    [BEAM-13015] Update the SDK harness grouping table to be memory bounded based upon the amount of assigned cache memory and to use an LRU eviction policy. (#17327)
    
    * [BEAM-13015] Update the grouping table to be memory-bounded based upon the amount of assigned cache memory and also use an LRU policy for evicting entries from the table.
    
    * fixup! checkstyle
    
    * fixup! Address PR comments.
---
 .../beam/runners/core/NullSideInputReader.java     |   5 +-
 .../apache/beam/sdk/options/SdkHarnessOptions.java |  18 +-
 .../java/org/apache/beam/fn/harness/Cache.java     |   6 +
 .../java/org/apache/beam/fn/harness/Caches.java    |   5 +-
 .../org/apache/beam/fn/harness/CombineRunners.java |  44 +-
 .../org/apache/beam/fn/harness/GroupingTable.java  |  34 --
 .../beam/fn/harness/PrecombineGroupingTable.java   | 624 ++++++++++-----------
 .../fn/harness/PrecombineGroupingTableTest.java    | 306 +++++++---
 8 files changed, 576 insertions(+), 466 deletions(-)

diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/NullSideInputReader.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/NullSideInputReader.java
index 0cd0f6fc0bd..186a207ebf0 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/NullSideInputReader.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/NullSideInputReader.java
@@ -29,10 +29,13 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
  */
 public class NullSideInputReader implements SideInputReader {
 
+  /** The default empty instance. */
+  private static final NullSideInputReader EMPTY_INSTANCE = of(Collections.emptySet());
+
   private Set<PCollectionView<?>> views;
 
   public static NullSideInputReader empty() {
-    return new NullSideInputReader(Collections.emptySet());
+    return EMPTY_INSTANCE;
   }
 
   public static NullSideInputReader of(Iterable<? extends PCollectionView<?>> views) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
index 8c63106676e..c701bc38c00 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
@@ -87,14 +87,26 @@ public interface SdkHarnessOptions extends PipelineOptions {
   void setSdkHarnessLogLevelOverrides(SdkHarnessLogLevelOverrides value);
 
   /**
-   * Size (in MB) of each grouping table used to pre-combine elements. If unset, defaults to 100 MB.
+   * Size (in MB) of each grouping table used to pre-combine elements. Larger values may reduce the
+   * amount of data shuffled. If unset, defaults to 100 MB.
    *
    * <p>CAUTION: If set too large, workers may run into OOM conditions more easily, each worker may
    * have many grouping tables in-memory concurrently.
+   *
+   * <p>CAUTION: This option does not apply to portable runners such as Dataflow Prime. See {@link
+   * #setMaxCacheMemoryUsageMb}, {@link #setMaxCacheMemoryUsagePercent}, or {@link
+   * #setMaxCacheMemoryUsageMbClass} to configure memory thresholds that apply to the grouping table
+   * and other cached objects.
    */
   @Description(
-      "The size (in MB) of the grouping tables used to pre-combine elements before "
-          + "shuffling.  Larger values may reduce the amount of data shuffled.")
+      "The size (in MB) of the grouping tables used to pre-combine elements before shuffling. If "
+          + "unset, defaults to 100 MB. Larger values may reduce the amount of data shuffled. "
+          + "CAUTION: If set too large, workers may run into OOM conditions more easily, each "
+          + "worker may have many grouping tables in-memory concurrently. CAUTION: This option "
+          + "does not apply to portable runners such as Dataflow Prime. See "
+          + "--maxCacheMemoryUsageMb, --maxCacheMemoryUsagePercent, or "
+          + "--maxCacheMemoryUsageMbClass to configure memory thresholds that apply to the "
+          + "grouping table and other cached objects.")
   @Default.Integer(100)
   int getGroupingTableMaxSizeMb();
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Cache.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Cache.java
index 3164fb241be..b4f379c5123 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Cache.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Cache.java
@@ -38,7 +38,13 @@ public interface Cache<K, V> {
    *
    * <p>Types should consider implementing {@link org.apache.beam.sdk.util.Weighted} to not invoke
    * the overhead of using the {@link Caches#weigh default weigher} multiple times.
+   *
+   * <p>This interface may be invoked from any other thread that manipulates the cache causing this
+   * value to be shrunk. Implementers must ensure thread safety with respect to any side effects
+   * caused.
    */
+  @ThreadSafe
+  @FunctionalInterface
   interface Shrinkable<V> {
     /**
      * Returns a new object that is smaller than the object being evicted.
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
index 9fa658e7b6a..514b21575b2 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/Caches.java
@@ -56,9 +56,12 @@ public final class Caches {
   private static final MemoryMeter MEMORY_METER =
       MemoryMeter.builder().withGuessing(Guess.BEST).build();
 
+  /** The size of a reference. */
+  public static final long REFERENCE_SIZE = 8;
+
   public static long weigh(Object o) {
     if (o == null) {
-      return 8;
+      return REFERENCE_SIZE;
     }
     try {
       return MEMORY_METER.measureDeep(o);
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java
index bb5ab2f3278..1f8de86f1fa 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java
@@ -20,6 +20,7 @@ package org.apache.beam.fn.harness;
 import com.google.auto.service.AutoService;
 import java.io.IOException;
 import java.util.Map;
+import java.util.function.Supplier;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.CombinePayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
@@ -68,40 +69,46 @@ public class CombineRunners {
   }
 
   private static class PrecombineRunner<KeyT, InputT, AccumT> {
-    private PipelineOptions options;
-    private CombineFn<InputT, AccumT, ?> combineFn;
-    private FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output;
-    private Coder<KeyT> keyCoder;
-    private GroupingTable<WindowedValue<KeyT>, InputT, AccumT> groupingTable;
-    private Coder<AccumT> accumCoder;
+    private final PipelineOptions options;
+    private final String ptransformId;
+    private final Supplier<Cache<?, ?>> bundleCache;
+    private final CombineFn<InputT, AccumT, ?> combineFn;
+    private final FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output;
+    private final Coder<KeyT> keyCoder;
+    private PrecombineGroupingTable<KeyT, InputT, AccumT> groupingTable;
 
     PrecombineRunner(
         PipelineOptions options,
+        String ptransformId,
+        Supplier<Cache<?, ?>> bundleCache,
         CombineFn<InputT, AccumT, ?> combineFn,
         FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> output,
-        Coder<KeyT> keyCoder,
-        Coder<AccumT> accumCoder) {
+        Coder<KeyT> keyCoder) {
       this.options = options;
+      this.ptransformId = ptransformId;
+      this.bundleCache = bundleCache;
       this.combineFn = combineFn;
       this.output = output;
       this.keyCoder = keyCoder;
-      this.accumCoder = accumCoder;
     }
 
     void startBundle() {
       groupingTable =
           PrecombineGroupingTable.combiningAndSampling(
-              options, combineFn, keyCoder, accumCoder, 0.001 /*sizeEstimatorSampleRate*/);
+              options,
+              Caches.subCache(bundleCache.get(), ptransformId),
+              combineFn,
+              keyCoder,
+              0.001 /*sizeEstimatorSampleRate*/);
     }
 
     void processElement(WindowedValue<KV<KeyT, InputT>> elem) throws Exception {
-      groupingTable.put(
-          elem, (Object outputElem) -> output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem));
+      groupingTable.put(elem, output::accept);
     }
 
     void finishBundle() throws Exception {
-      groupingTable.flush(
-          (Object outputElem) -> output.accept((WindowedValue<KV<KeyT, AccumT>>) outputElem));
+      groupingTable.flush(output::accept);
+      groupingTable = null;
     }
   }
 
@@ -144,8 +151,6 @@ public class CombineRunners {
           (CombineFn)
               SerializableUtils.deserializeFromByteArray(
                   combinePayload.getCombineFn().getPayload().toByteArray(), "CombineFn");
-      Coder<AccumT> accumCoder =
-          (Coder<AccumT>) rehydratedComponents.getCoder(combinePayload.getAccumulatorCoderId());
 
       FnDataReceiver<WindowedValue<KV<KeyT, AccumT>>> consumer =
           (FnDataReceiver)
@@ -154,7 +159,12 @@ public class CombineRunners {
 
       PrecombineRunner<KeyT, InputT, AccumT> runner =
           new PrecombineRunner<>(
-              context.getPipelineOptions(), combineFn, consumer, keyCoder, accumCoder);
+              context.getPipelineOptions(),
+              context.getPTransformId(),
+              context.getBundleCacheSupplier(),
+              combineFn,
+              consumer,
+              keyCoder);
 
       // Register the appropriate handlers.
       context.addStartBundleFunction(runner::startBundle);
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/GroupingTable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/GroupingTable.java
deleted file mode 100644
index 0c55655f666..00000000000
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/GroupingTable.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.fn.harness;
-
-/** An interface that groups inputs to an accumulator and flushes the output. */
-public interface GroupingTable<K, InputT, AccumT> {
-
-  /** Abstract interface of things that accept inputs one at a time via process(). */
-  interface Receiver {
-    /** Processes the element. */
-    void process(Object outputElem) throws Exception;
-  }
-
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  void put(Object pair, Receiver receiver) throws Exception;
-
-  /** Flushes all entries in this table to output. */
-  void flush(Receiver output) throws Exception;
-}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java
index dda245b7f3c..d7251bed4fa 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PrecombineGroupingTable.java
@@ -17,392 +17,365 @@
  */
 package org.apache.beam.fn.harness;
 
-import java.util.HashMap;
+import java.util.Collection;
 import java.util.Iterator;
-import java.util.Map;
+import java.util.LinkedHashMap;
+import java.util.Objects;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
+import org.apache.beam.fn.harness.Cache.Shrinkable;
 import org.apache.beam.runners.core.GlobalCombineFnRunner;
 import org.apache.beam.runners.core.GlobalCombineFnRunners;
 import org.apache.beam.runners.core.NullSideInputReader;
-import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.options.SdkHarnessOptions;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.CountingOutputStream;
 import org.joda.time.Instant;
 
-/** Static utility methods that provide {@link GroupingTable} implementations. */
+/**
+ * Static utility methods that provide a grouping table implementation.
+ *
+ * <p>{@link NotThreadSafe} because the caller must use the bundle processing thread when invoking
+ * {@link #put} and {@link #flush}. {@link #shrink} may be called from any thread.
+ */
 @SuppressWarnings({
   "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
 })
+@NotThreadSafe
 public class PrecombineGroupingTable<K, InputT, AccumT>
-    implements GroupingTable<K, InputT, AccumT> {
-  private static long getGroupingTableSizeBytes(PipelineOptions options) {
-    return options.as(SdkHarnessOptions.class).getGroupingTableMaxSizeMb() * 1024L * 1024L;
-  }
+    implements Shrinkable<PrecombineGroupingTable<K, InputT, AccumT>>, Weighted {
+
+  private static final Instant IGNORED = BoundedWindow.TIMESTAMP_MIN_VALUE;
 
-  /** Returns a {@link GroupingTable} that combines inputs into a accumulator. */
-  public static <K, InputT, AccumT> GroupingTable<WindowedValue<K>, InputT, AccumT> combining(
+  /**
+   * Returns a grouping table that combines inputs into an accumulator. The grouping table uses the
+   * cache to defer flushing output until the cache evicts the table.
+   */
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combining(
       PipelineOptions options,
+      Cache<Object, Object> cache,
       CombineFn<InputT, AccumT, ?> combineFn,
-      Coder<K> keyCoder,
-      Coder<? super AccumT> accumulatorCoder) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+      Coder<K> keyCoder) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-        new CoderSizeEstimator<>(accumulatorCoder));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        Caches::weigh,
+        Caches::weigh);
   }
 
   /**
-   * Returns a {@link GroupingTable} that combines inputs into a accumulator with sampling {@link
-   * SizeEstimator SizeEstimators}.
+   * Returns a grouping table that combines inputs into an accumulator with sampling {@link
+   * SizeEstimator SizeEstimators}. The grouping table uses the cache to defer flushing output until
+   * the cache evicts the table.
    */
-  public static <K, InputT, AccumT>
-      GroupingTable<WindowedValue<K>, InputT, AccumT> combiningAndSampling(
-          PipelineOptions options,
-          CombineFn<InputT, AccumT, ?> combineFn,
-          Coder<K> keyCoder,
-          Coder<? super AccumT> accumulatorCoder,
-          double sizeEstimatorSampleRate) {
-    Combiner<WindowedValue<K>, InputT, AccumT, ?> valueCombiner =
-        new ValueCombiner<>(
-            GlobalCombineFnRunners.create(combineFn), NullSideInputReader.empty(), options);
+  public static <K, InputT, AccumT> PrecombineGroupingTable<K, InputT, AccumT> combiningAndSampling(
+      PipelineOptions options,
+      Cache<Object, Object> cache,
+      CombineFn<InputT, AccumT, ?> combineFn,
+      Coder<K> keyCoder,
+      double sizeEstimatorSampleRate) {
     return new PrecombineGroupingTable<>(
-        getGroupingTableSizeBytes(options),
-        new WindowingCoderGroupingKeyCreator<>(keyCoder),
-        WindowedPairInfo.create(),
-        valueCombiner,
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(WindowedValue.getValueOnlyCoder(keyCoder)),
-            sizeEstimatorSampleRate,
-            1.0),
-        new SamplingSizeEstimator<>(
-            new CoderSizeEstimator<>(accumulatorCoder), sizeEstimatorSampleRate, 1.0));
+        options,
+        cache,
+        keyCoder,
+        GlobalCombineFnRunners.create(combineFn),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0),
+        new SamplingSizeEstimator<>(Caches::weigh, sizeEstimatorSampleRate, 1.0));
   }
 
-  /** Provides client-specific operations for grouping keys. */
-  public interface GroupingKeyCreator<K> {
-    Object createGroupingKey(K key) throws Exception;
+  @Nullable
+  @Override
+  public PrecombineGroupingTable<K, InputT, AccumT> shrink() {
+    long currentWeight = maxWeight.updateAndGet(operand -> operand >> 1);
+    // It is possible that we are shrunk multiple times until the requested max weight is too small.
+    // In this case we want to effectively stop shrinking since we can't effectively cache much
+    // at this time and the next insertion will likely evict all records.
+    if (currentWeight <= 100L) {
+      return null;
+    }
+    return this;
   }
 
-  /** Implements Precombine GroupingKeyCreator via Coder. */
-  public static class WindowingCoderGroupingKeyCreator<K>
-      implements GroupingKeyCreator<WindowedValue<K>> {
+  @Override
+  public long getWeight() {
+    return maxWeight.get();
+  }
 
-    private static final Instant ignored = BoundedWindow.TIMESTAMP_MIN_VALUE;
+  /** Provides client-specific operations for size estimates. */
+  @FunctionalInterface
+  public interface SizeEstimator<T> {
+    long estimateSize(T element);
+  }
 
-    private final Coder<K> coder;
+  private final Coder<K> keyCoder;
+  private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFn;
+  private final PipelineOptions options;
+  private final SizeEstimator<K> keySizer;
+  private final SizeEstimator<AccumT> accumulatorSizer;
+  private final Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>> cache;
+  private final LinkedHashMap<GroupingTableKey, GroupingTableEntry> lruMap;
+  private final AtomicLong maxWeight;
+  private long weight;
 
-    WindowingCoderGroupingKeyCreator(Coder<K> coder) {
-      this.coder = coder;
-    }
+  private static final class Key implements Weighted {
+    private static final Key INSTANCE = new Key();
 
     @Override
-    public Object createGroupingKey(WindowedValue<K> key) {
-      // Ignore timestamp for grouping purposes.
-      // The Precombine output will inherit the timestamp of one of its inputs.
-      return WindowedValue.of(
-          coder.structuralValue(key.getValue()), ignored, key.getWindows(), key.getPane());
+    public long getWeight() {
+      // Ignore the actual size of this singleton because it is trivial and because
+      // the weight reported here will be counted many times as it is present in
+      // many different state subcaches.
+      return 0;
     }
   }
 
-  /** Provides client-specific operations for size estimates. */
-  public interface SizeEstimator<T> {
-    long estimateSize(T element) throws Exception;
+  PrecombineGroupingTable(
+      PipelineOptions options,
+      Cache<?, ?> cache,
+      Coder<K> keyCoder,
+      GlobalCombineFnRunner<InputT, AccumT, ?> combineFn,
+      SizeEstimator<K> keySizer,
+      SizeEstimator<AccumT> accumulatorSizer) {
+    this.options = options;
+    this.cache = (Cache<Key, PrecombineGroupingTable<K, InputT, AccumT>>) cache;
+    this.keyCoder = keyCoder;
+    this.combineFn = combineFn;
+    this.keySizer = keySizer;
+    this.accumulatorSizer = accumulatorSizer;
+    this.lruMap = new LinkedHashMap<>(16, 0.75f, true);
+    this.maxWeight = new AtomicLong();
+    this.weight = 0L;
+    this.cache.put(Key.INSTANCE, this);
   }
 
-  /** Implements SizeEstimator via Coder. */
-  public static class CoderSizeEstimator<T> implements SizeEstimator<T> {
-    /** Basic implementation of {@link ElementByteSizeObserver} for use in size estimation. */
-    private static class Observer extends ElementByteSizeObserver {
-      private long observedSize = 0;
-
-      @Override
-      protected void reportElementSize(long elementSize) {
-        observedSize += elementSize;
+  private static class GroupingTableKey implements Weighted {
+    private final Object structuralKey;
+    private final Collection<? extends BoundedWindow> windows;
+    private final PaneInfo paneInfo;
+    private final long weight;
+
+    <K> GroupingTableKey(
+        K key,
+        Collection<? extends BoundedWindow> windows,
+        PaneInfo paneInfo,
+        Coder<K> keyCoder,
+        SizeEstimator<K> keySizer) {
+      this.structuralKey = keyCoder.structuralValue(key);
+      this.windows = windows;
+      this.paneInfo = paneInfo;
+      // We account for the weight of the key using the keySizer if the coder's structural value
+      // is the same as its value.
+      if (structuralKey == key) {
+        weight = keySizer.estimateSize(key) + Caches.weigh(windows) + Caches.weigh(paneInfo);
+      } else {
+        weight = Caches.weigh(this);
       }
     }
 
-    final Coder<T> coder;
-
-    CoderSizeEstimator(Coder<T> coder) {
-      this.coder = coder;
+    public Object getStructuralKey() {
+      return structuralKey;
     }
 
-    @Override
-    public long estimateSize(T value) throws Exception {
-      // First try using byte size observer
-      CoderSizeEstimator.Observer observer = new CoderSizeEstimator.Observer();
-      coder.registerByteSizeObserver(value, observer);
-
-      if (!observer.getIsLazy()) {
-        observer.advance();
-        return observer.observedSize;
-      } else {
-        // Coder byte size observation is lazy (requires iteration for observation) so fall back to
-        // counting output stream
-        CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream());
-        coder.encode(value, os);
-        return os.getCount();
-      }
+    public Collection<? extends BoundedWindow> getWindows() {
+      return windows;
     }
-  }
 
-  /**
-   * Provides client-specific operations for working with elements that are key/value or key/values
-   * pairs.
-   */
-  public interface PairInfo {
-    Object getKeyFromInputPair(Object pair);
-
-    Object getValueFromInputPair(Object pair);
-
-    Object makeOutputPair(Object key, Object value);
-  }
-
-  /** Implements Precombine PairInfo via KVs. */
-  public static class WindowedPairInfo implements PairInfo {
-    private static WindowedPairInfo theInstance = new WindowedPairInfo();
-
-    public static WindowedPairInfo create() {
-      return theInstance;
+    public PaneInfo getPaneInfo() {
+      return paneInfo;
     }
 
-    private WindowedPairInfo() {}
+    @Override
+    public long getWeight() {
+      return weight;
+    }
 
     @Override
-    public Object getKeyFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.withValue(windowedKv.getValue().getKey());
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (!(o instanceof GroupingTableKey)) {
+        return false;
+      }
+      GroupingTableKey that = (GroupingTableKey) o;
+      return Objects.equals(structuralKey, that.structuralKey)
+          && windows.equals(that.windows)
+          && paneInfo.equals(that.paneInfo);
     }
 
     @Override
-    public Object getValueFromInputPair(Object pair) {
-      @SuppressWarnings("unchecked")
-      WindowedValue<KV<?, ?>> windowedKv = (WindowedValue<KV<?, ?>>) pair;
-      return windowedKv.getValue().getValue();
+    public int hashCode() {
+      return Objects.hash(structuralKey, windows, paneInfo);
     }
 
     @Override
-    public Object makeOutputPair(Object key, Object values) {
-      WindowedValue<?> windowedKey = (WindowedValue<?>) key;
-      return windowedKey.withValue(KV.of(windowedKey.getValue(), values));
+    public String toString() {
+      return "GroupingTableKey{"
+          + "structuralKey="
+          + structuralKey
+          + ", windows="
+          + windows
+          + ", paneInfo="
+          + paneInfo
+          + ", weight="
+          + weight
+          + '}';
     }
   }
 
-  /** Provides client-specific operations for combining values. */
-  public interface Combiner<K, InputT, AccumT, OutputT> {
-    AccumT createAccumulator(K key);
-
-    AccumT add(K key, AccumT accumulator, InputT value);
-
-    AccumT merge(K key, Iterable<AccumT> accumulators);
-
-    AccumT compact(K key, AccumT accumulator);
+  private class GroupingTableEntry implements Weighted {
+    private final GroupingTableKey groupingKey;
+    private final K userKey;
+    private final long keySize;
+    private long accumulatorSize;
+    private AccumT accumulator;
+    private boolean dirty;
+
+    private GroupingTableEntry(GroupingTableKey groupingKey, K userKey, InputT initialInputValue) {
+      this.groupingKey = groupingKey;
+      this.userKey = userKey;
+      if (groupingKey.getStructuralKey() == userKey) {
+        // This object is only storing references to the same objects that are being stored
+        // by the cache so the accounting of the size of the key is occurring already.
+        this.keySize = Caches.REFERENCE_SIZE * 2;
+      } else {
+        this.keySize = Caches.REFERENCE_SIZE + keySizer.estimateSize(userKey);
+      }
+      this.accumulator =
+          combineFn.createAccumulator(
+              options, NullSideInputReader.empty(), groupingKey.getWindows());
+      add(initialInputValue);
+      this.accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+    }
 
-    OutputT extract(K key, AccumT accumulator);
-  }
+    public GroupingTableKey getGroupingKey() {
+      return groupingKey;
+    }
 
-  /** Implements Precombine Combiner via Combine.KeyedCombineFn. */
-  public static class ValueCombiner<K, InputT, AccumT, OutputT>
-      implements Combiner<WindowedValue<K>, InputT, AccumT, OutputT> {
-    private final GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn;
-    private final SideInputReader sideInputReader;
-    private final PipelineOptions options;
-
-    private ValueCombiner(
-        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFn,
-        SideInputReader sideInputReader,
-        PipelineOptions options) {
-      this.combineFn = combineFn;
-      this.sideInputReader = sideInputReader;
-      this.options = options;
+    public K getKey() {
+      return userKey;
     }
 
-    @Override
-    public AccumT createAccumulator(WindowedValue<K> windowedKey) {
-      return this.combineFn.createAccumulator(options, sideInputReader, windowedKey.getWindows());
+    public AccumT getValue() {
+      return accumulator;
     }
 
     @Override
-    public AccumT add(WindowedValue<K> windowedKey, AccumT accumulator, InputT value) {
-      return this.combineFn.addInput(
-          accumulator, value, options, sideInputReader, windowedKey.getWindows());
+    public long getWeight() {
+      return keySize + accumulatorSize;
     }
 
-    @Override
-    public AccumT merge(WindowedValue<K> windowedKey, Iterable<AccumT> accumulators) {
-      return this.combineFn.mergeAccumulators(
-          accumulators, options, sideInputReader, windowedKey.getWindows());
+    public void compact() {
+      if (dirty) {
+        accumulator =
+            combineFn.compact(
+                accumulator, options, NullSideInputReader.empty(), groupingKey.getWindows());
+        accumulatorSize = accumulatorSizer.estimateSize(accumulator);
+        dirty = false;
+      }
     }
 
-    @Override
-    public AccumT compact(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.compact(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public void add(InputT value) {
+      dirty = true;
+      accumulator =
+          combineFn.addInput(
+              accumulator, value, options, NullSideInputReader.empty(), groupingKey.getWindows());
+      accumulatorSize = accumulatorSizer.estimateSize(accumulator);
     }
 
     @Override
-    public OutputT extract(WindowedValue<K> windowedKey, AccumT accumulator) {
-      return this.combineFn.extractOutput(
-          accumulator, options, sideInputReader, windowedKey.getWindows());
+    public String toString() {
+      return "GroupingTableEntry{"
+          + "groupingKey="
+          + groupingKey
+          + ", userKey="
+          + userKey
+          + ", keySize="
+          + keySize
+          + ", accumulatorSize="
+          + accumulatorSize
+          + ", accumulator="
+          + accumulator
+          + ", dirty="
+          + dirty
+          + '}';
     }
   }
 
-  // How many bytes a word in the JVM has.
-  private static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
-  /**
-   * The number of bytes of overhead to store an entry in the grouping table (a {@code
-   * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
-   * and values:
-   *
-   * <ul>
-   *   <li>an array element (1 word),
-   *   <li>a HashMap.Entry (4 words),
-   *   <li>a StructuralByteArray (1 words),
-   *   <li>a backing array (guessed at 1 word for the length),
-   *   <li>a KeyAndValues (2 words),
-   *   <li>an ArrayList (2 words),
-   *   <li>a backing array (1 word),
-   *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
-   * </ul>
-   */
-  private static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
-
-  /** A {@link GroupingTable} that uses the given combiner to combine values in place. */
-  // Keep the table relatively full to increase the chance of collisions.
-  private static final double TARGET_LOAD = 0.9;
-
-  private long maxSize;
-  private final GroupingKeyCreator<? super K> groupingKeyCreator;
-  private final PairInfo pairInfo;
-  private final Combiner<? super K, InputT, AccumT, ?> combiner;
-  private final SizeEstimator<? super K> keySizer;
-  private final SizeEstimator<? super AccumT> accumulatorSizer;
-
-  private long size = 0;
-  private Map<Object, GroupingTableEntry<K, InputT, AccumT>> table;
-
-  PrecombineGroupingTable(
-      long maxSize,
-      GroupingKeyCreator<? super K> groupingKeyCreator,
-      PairInfo pairInfo,
-      Combiner<? super K, InputT, AccumT, ?> combineFn,
-      SizeEstimator<? super K> keySizer,
-      SizeEstimator<? super AccumT> accumulatorSizer) {
-    this.maxSize = maxSize;
-    this.groupingKeyCreator = groupingKeyCreator;
-    this.pairInfo = pairInfo;
-    this.combiner = combineFn;
-    this.keySizer = keySizer;
-    this.accumulatorSizer = accumulatorSizer;
-    this.table = new HashMap<>();
-  }
-
-  interface GroupingTableEntry<K, InputT, AccumT> {
-    K getKey();
-
-    AccumT getValue();
-
-    void add(InputT value) throws Exception;
-
-    long getSize();
-
-    void compact() throws Exception;
-  }
-
-  private GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
-    return new GroupingTableEntry<K, InputT, AccumT>() {
-      final long keySize = keySizer.estimateSize(key);
-      AccumT accumulator = combiner.createAccumulator(key);
-      long accumulatorSize = 0; // never used before a value is added...
-
-      @Override
-      public K getKey() {
-        return key;
-      }
-
-      @Override
-      public AccumT getValue() {
-        return accumulator;
-      }
-
-      @Override
-      public long getSize() {
-        return keySize + accumulatorSize;
-      }
-
-      @Override
-      public void compact() throws Exception {
-        AccumT newAccumulator = combiner.compact(key, accumulator);
-        if (newAccumulator != accumulator) {
-          accumulator = newAccumulator;
-          accumulatorSize = accumulatorSizer.estimateSize(newAccumulator);
-        }
-      }
-
-      @Override
-      public void add(InputT value) throws Exception {
-        accumulator = combiner.add(key, accumulator, value);
-        accumulatorSize = accumulatorSizer.estimateSize(accumulator);
-      }
-    };
-  }
-
-  /** Adds a pair to this table, possibly flushing some entries to output if the table is full. */
-  @SuppressWarnings("unchecked")
-  @Override
-  public void put(Object pair, Receiver receiver) throws Exception {
-    put(
-        (K) pairInfo.getKeyFromInputPair(pair),
-        (InputT) pairInfo.getValueFromInputPair(pair),
-        receiver);
-  }
-
   /**
    * Adds the key and value to this table, possibly flushing some entries to output if the table is
    * full.
    */
-  public void put(K key, InputT value, Receiver receiver) throws Exception {
-    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
-    GroupingTableEntry<K, InputT, AccumT> entry = table.get(groupingKey);
-    if (entry == null) {
-      entry = createTableEntry(key);
-      table.put(groupingKey, entry);
-      size += PER_KEY_OVERHEAD;
-    } else {
-      size -= entry.getSize();
-    }
-    entry.add(value);
-    size += entry.getSize();
-
-    if (size >= maxSize) {
-      long targetSize = (long) (TARGET_LOAD * maxSize);
-      Iterator<GroupingTableEntry<K, InputT, AccumT>> entries = table.values().iterator();
-      while (size >= targetSize) {
-        if (!entries.hasNext()) {
-          // Should never happen, but sizes may be estimates...
-          size = 0;
-          break;
+  @VisibleForTesting
+  public void put(
+      WindowedValue<KV<K, InputT>> value, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
+      throws Exception {
+    // Ignore timestamp for grouping purposes.
+    // The Pre-combine output will inherit the timestamp of one of its inputs.
+    GroupingTableKey groupingKey =
+        new GroupingTableKey(
+            value.getValue().getKey(), value.getWindows(), value.getPane(), keyCoder, keySizer);
+
+    lruMap.compute(
+        groupingKey,
+        (key, tableEntry) -> {
+          if (tableEntry == null) {
+            weight += groupingKey.getWeight();
+            tableEntry =
+                new GroupingTableEntry(
+                    groupingKey, value.getValue().getKey(), value.getValue().getValue());
+          } else {
+            weight -= tableEntry.getWeight();
+            tableEntry.add(value.getValue().getValue());
+          }
+          weight += tableEntry.getWeight();
+          return tableEntry;
+        });
+
+    // Increase the maximum only if we require it
+    maxWeight.accumulateAndGet(weight, (current, update) -> current < update ? update : current);
+
+    // Update the cache to ensure that LRU is handled appropriately and for the cache to have an
+    // opportunity to shrink the maxWeight if necessary.
+    cache.put(Key.INSTANCE, this);
+
+    // Re-read the maximum weight, which the cache may have shrunk, and respect it.
+    long currentMax = maxWeight.get();
+
+    // Only compact and output from the bundle processing thread that is inserting elements into the
+    // grouping table. This ensures that we honor the guarantee that transforms for a single bundle
+    // execute using the same thread.
+    if (weight > currentMax) {
+      // Try to compact as many of the values as possible and only flush values if compaction
+      // wasn't enough.
+      for (GroupingTableEntry valueToCompact : lruMap.values()) {
+        long currentWeight = valueToCompact.getWeight();
+        valueToCompact.compact();
+        weight += valueToCompact.getWeight() - currentWeight;
+      }
+
+      if (weight > currentMax) {
+        Iterator<GroupingTableEntry> iterator = lruMap.values().iterator();
+        while (iterator.hasNext()) {
+          GroupingTableEntry valueToFlush = iterator.next();
+          weight -= valueToFlush.getWeight() + valueToFlush.getGroupingKey().getWeight();
+          iterator.remove();
+          output(valueToFlush, receiver);
+          if (weight <= currentMax) {
+            break;
+          }
         }
-        GroupingTableEntry<K, InputT, AccumT> toFlush = entries.next();
-        entries.remove();
-        size -= toFlush.getSize() + PER_KEY_OVERHEAD;
-        output(toFlush, receiver);
       }
     }
   }
@@ -410,41 +383,26 @@ public class PrecombineGroupingTable<K, InputT, AccumT>
   /**
    * Output the given entry. Does not actually remove it from the table or update this table's size.
    */
-  private void output(GroupingTableEntry<K, InputT, AccumT> entry, Receiver receiver)
+  private void output(
+      GroupingTableEntry entry, FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver)
       throws Exception {
     entry.compact();
-    receiver.process(pairInfo.makeOutputPair(entry.getKey(), entry.getValue()));
+    receiver.accept(
+        WindowedValue.of(
+            KV.of(entry.getKey(), entry.getValue()),
+            IGNORED,
+            entry.getGroupingKey().getWindows(),
+            entry.getGroupingKey().getPaneInfo()));
   }
 
   /** Flushes all entries in this table to output. */
-  @Override
-  public void flush(Receiver output) throws Exception {
-    for (GroupingTableEntry<K, InputT, AccumT> entry : table.values()) {
-      output(entry, output);
-    }
-    table.clear();
-    size = 0;
-  }
-
-  @VisibleForTesting
-  public void setMaxSize(long maxSize) {
-    this.maxSize = maxSize;
-  }
-
-  @VisibleForTesting
-  public long size() {
-    return size;
-  }
-
-  /** Returns the number of bytes in a JVM word. In case we failed to find the answer, returns 8. */
-  private static int getBytesPerJvmWord() {
-    String wordSizeInBits = System.getProperty("sun.arch.data.model");
-    try {
-      return Integer.parseInt(wordSizeInBits) / 8;
-    } catch (NumberFormatException e) {
-      // The JVM word size is unknown.  Assume 64-bit.
-      return 8;
+  public void flush(FnDataReceiver<WindowedValue<KV<K, AccumT>>> receiver) throws Exception {
+    cache.remove(Key.INSTANCE);
+    for (GroupingTableEntry valueToFlush : lruMap.values()) {
+      output(valueToFlush, receiver);
     }
+    lruMap.clear();
+    weight = 0;
   }
 
   ////////////////////////////////////////////////////////////////////////////
@@ -510,7 +468,7 @@ public class PrecombineGroupingTable<K, InputT, AccumT>
     }
 
     @Override
-    public long estimateSize(T element) throws Exception {
+    public long estimateSize(T element) {
       if (sampleNow()) {
         return recordSample(underlying.estimateSize(element));
       } else {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java
index 3d3e2d31cd0..56cb12b0dad 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/PrecombineGroupingTableTest.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.fn.harness;
 
+import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.in;
 import static org.hamcrest.core.Is.is;
 import static org.junit.Assert.assertEquals;
@@ -28,16 +30,28 @@ import static org.junit.Assert.assertEquals;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
-import org.apache.beam.fn.harness.GroupingTable.Receiver;
-import org.apache.beam.fn.harness.PrecombineGroupingTable.Combiner;
-import org.apache.beam.fn.harness.PrecombineGroupingTable.GroupingKeyCreator;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import org.apache.beam.fn.harness.PrecombineGroupingTable.SamplingSizeEstimator;
 import org.apache.beam.fn.harness.PrecombineGroupingTable.SizeEstimator;
+import org.apache.beam.runners.core.GlobalCombineFnRunner;
+import org.apache.beam.runners.core.GlobalCombineFnRunners;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.test.TestExecutors;
+import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.hamcrest.Description;
 import org.hamcrest.TypeSafeDiagnosingMatcher;
-import org.hamcrest.collection.IsIterableContainingInAnyOrder;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -46,79 +60,243 @@ import org.junit.runners.JUnit4;
 @RunWith(JUnit4.class)
 public class PrecombineGroupingTableTest {
 
-  private static class TestOutputReceiver implements Receiver {
-    final List<Object> outputElems = new ArrayList<>();
+  @Rule
+  public TestExecutorService executorService = TestExecutors.from(Executors.newCachedThreadPool());
+
+  private static class TestOutputReceiver<T> implements FnDataReceiver<T> {
+    final List<T> outputElems = new ArrayList<>();
 
     @Override
-    public void process(Object elem) {
+    public void accept(T elem) {
       outputElems.add(elem);
     }
   }
 
-  @Test
-  public void testCombiningGroupingTable() throws Exception {
-    Combiner<Object, Integer, Long, Long> summingCombineFn =
-        new Combiner<Object, Integer, Long, Long>() {
+  private static final CombineFn<Integer, Long, Long> COMBINE_FN =
+      new CombineFn<Integer, Long, Long>() {
 
-          @Override
-          public Long createAccumulator(Object key) {
-            return 0L;
-          }
+        @Override
+        public Long createAccumulator() {
+          return 0L;
+        }
 
-          @Override
-          public Long add(Object key, Long accumulator, Integer value) {
-            return accumulator + value;
-          }
+        @Override
+        public Long addInput(Long accumulator, Integer value) {
+          return accumulator + value;
+        }
 
-          @Override
-          public Long merge(Object key, Iterable<Long> accumulators) {
-            long sum = 0;
-            for (Long part : accumulators) {
-              sum += part;
-            }
-            return sum;
+        @Override
+        public Long mergeAccumulators(Iterable<Long> accumulators) {
+          long sum = 0;
+          for (Long part : accumulators) {
+            sum += part;
           }
+          return sum;
+        }
 
-          @Override
-          public Long compact(Object key, Long accumulator) {
-            return accumulator;
+        @Override
+        public Long compact(Long accumulator) {
+          if (accumulator % 2 == 0) {
+            return accumulator / 4;
           }
+          return accumulator;
+        }
 
-          @Override
-          public Long extract(Object key, Long accumulator) {
-            return accumulator;
-          }
-        };
+        @Override
+        public Long extractOutput(Long accumulator) {
+          return accumulator;
+        }
+      };
+
+  @Test
+  public void testCombiningGroupingTableHonorsKeyWeights() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
+
+    // Repeatedly putting the same key (estimated weight 1000) should not cause any eviction.
+    table.put(valueInGlobalWindow(KV.of("AAA", 1)), receiver);
+    table.put(valueInGlobalWindow(KV.of("AAA", 2)), receiver);
+    table.put(valueInGlobalWindow(KV.of("AAA", 4)), receiver);
+    assertThat(receiver.outputElems, empty());
+
+    // Putting in other large keys should cause eviction.
+    table.put(valueInGlobalWindow(KV.of("BBB", 9)), receiver);
+    table.put(valueInGlobalWindow(KV.of("CCC", 11)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("AAA", 1L + 2 + 4)), valueInGlobalWindow(KV.of("BBB", 9L))));
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("AAA", 1L + 2 + 4)),
+            valueInGlobalWindow(KV.of("BBB", 9L)),
+            valueInGlobalWindow(KV.of("CCC", 11L))));
+  }
 
+  @Test
+  public void testCombiningGroupingTableEvictsAllOnLargeEntry() throws Exception {
     PrecombineGroupingTable<String, Integer, Long> table =
         new PrecombineGroupingTable<>(
-            100_000_000L,
-            new IdentityGroupingKeyCreator(),
-            new KvPairInfo(),
-            summingCombineFn,
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
             new StringPowerSizeEstimator(),
             new IdentitySizeEstimator());
-    table.setMaxSize(1000);
 
-    TestOutputReceiver receiver = new TestOutputReceiver();
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
 
-    table.put("A", 1, receiver);
-    table.put("B", 2, receiver);
-    table.put("B", 3, receiver);
-    table.put("C", 4, receiver);
+    table.put(valueInGlobalWindow(KV.of("A", 1)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 3)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 6)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 7)), receiver);
     assertThat(receiver.outputElems, empty());
 
-    table.put("C", 5000, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("C", 5004L)));
+    // Adding a value beyond the maximum weight triggers compaction, which is not enough, so every
+    // entry is evicted: the largest entry is the most recently used and hence evicted last.
+    table.put(valueInGlobalWindow(KV.of("C", 9999)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 9L)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
 
-    table.put("DDDD", 6, receiver);
-    assertThat(receiver.outputElems, hasItem((Object) KV.of("DDDD", 6L)));
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1L)),
+            valueInGlobalWindow(KV.of("B", 3L + 6)),
+            valueInGlobalWindow(KV.of("C", (9999L + 7) / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTableCompactionSaves() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
+
+    // Insert three compactable values which shouldn't lead to eviction even though we are over
+    // the maximum size.
+    table.put(valueInGlobalWindow(KV.of("A", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 1004)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 1004)), receiver);
+    assertThat(receiver.outputElems, empty());
 
     table.flush(receiver);
     assertThat(
         receiver.outputElems,
-        IsIterableContainingInAnyOrder.containsInAnyOrder(
-            KV.of("A", 1L), KV.of("B", 2L + 3), KV.of("C", 5000L + 4), KV.of("DDDD", 6L)));
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1004L / 4)),
+            valueInGlobalWindow(KV.of("B", 1004L / 4)),
+            valueInGlobalWindow(KV.of("C", 1004L / 4))));
+  }
+
+  @Test
+  public void testCombiningGroupingTablePartialEviction() throws Exception {
+    PrecombineGroupingTable<String, Integer, Long> table =
+        new PrecombineGroupingTable<>(
+            PipelineOptionsFactory.create(),
+            Caches.forMaximumBytes(2500L),
+            StringUtf8Coder.of(),
+            GlobalCombineFnRunners.create(COMBINE_FN),
+            new StringPowerSizeEstimator(),
+            new IdentitySizeEstimator());
+
+    TestOutputReceiver<WindowedValue<KV<String, Long>>> receiver = new TestOutputReceiver<>();
+
+    // Insert three values that exceed the max weight even after compaction, so A and B are
+    // evicted to get back under it.
+    table.put(valueInGlobalWindow(KV.of("A", 1001)), receiver);
+    table.put(valueInGlobalWindow(KV.of("B", 1001)), receiver);
+    table.put(valueInGlobalWindow(KV.of("C", 1001)), receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1001L)), valueInGlobalWindow(KV.of("B", 1001L))));
+
+    table.flush(receiver);
+    assertThat(
+        receiver.outputElems,
+        containsInAnyOrder(
+            valueInGlobalWindow(KV.of("A", 1001L)),
+            valueInGlobalWindow(KV.of("B", 1001L)),
+            valueInGlobalWindow(KV.of("C", 1001L))));
+  }
+
+  @Test
+  public void testCombiningGroupingTableEmitsCorrectValuesUnderHighCacheContention()
+      throws Exception {
+    Long[] expectedKeys = new Long[1000];
+    for (int j = 1; j <= 1000; ++j) {
+      expectedKeys[j - 1] = (long) j;
+    }
+
+    int numThreads = 1000;
+    List<Future<?>> futures = new ArrayList<>(numThreads);
+    PipelineOptions options = PipelineOptionsFactory.create();
+    GlobalCombineFnRunner<Integer, Long, Long> combineFnRunner =
+        GlobalCombineFnRunners.create(COMBINE_FN);
+    Cache<Object, Object> cache = Caches.forMaximumBytes(numThreads * 50000L);
+    for (int i = 0; i < numThreads; ++i) {
+      final int currentI = i;
+      futures.add(
+          executorService.submit(
+              () -> {
+                ArrayListMultimap<Long, Long> values = ArrayListMultimap.create();
+                PrecombineGroupingTable<Long, Integer, Long> table =
+                    new PrecombineGroupingTable<>(
+                        options,
+                        Caches.subCache(cache, currentI),
+                        VarLongCoder.of(),
+                        combineFnRunner,
+                        new IdentitySizeEstimator(),
+                        new IdentitySizeEstimator());
+                for (int j = 1; j <= 1000; ++j) {
+                  table.put(
+                      valueInGlobalWindow(KV.of((long) j, j)),
+                      (input) ->
+                          values.put(input.getValue().getKey(), input.getValue().getValue()));
+                }
+                // A single flush drains every entry still held by the table; repeating it would
+                // only perform no-op flushes.
+                table.flush(
+                    (input) ->
+                        values.put(input.getValue().getKey(), input.getValue().getValue()));
+
+                assertThat(values.keySet(), containsInAnyOrder(expectedKeys));
+                for (Map.Entry<Long, Long> value : values.entries()) {
+                  if (value.getKey() % 2 == 0) {
+                    assertThat(value.getValue(), equalTo(value.getKey() / 4));
+                  } else {
+                    assertThat(value.getValue(), equalTo(value.getKey()));
+                  }
+                }
+                return null;
+              }));
+    }
+    for (Future<?> future : futures) {
+      future.get();
+    }
   }
 
   ////////////////////////////////////////////////////////////////////////////
@@ -258,14 +436,6 @@ public class PrecombineGroupingTableTest {
     };
   }
 
-  /** Return the key as its grouping key. */
-  private static class IdentityGroupingKeyCreator implements GroupingKeyCreator<Object> {
-    @Override
-    public Object createGroupingKey(Object key) {
-      return key;
-    }
-  }
-
   /** "Estimate" the size of longs by looking at their value. */
   private static class IdentitySizeEstimator implements SizeEstimator<Long> {
     int calls = 0;
@@ -284,22 +454,4 @@ public class PrecombineGroupingTableTest {
       return (long) Math.pow(10, element.length());
     }
   }
-
-  private static class KvPairInfo implements PrecombineGroupingTable.PairInfo {
-    @SuppressWarnings("unchecked")
-    @Override
-    public Object getKeyFromInputPair(Object pair) {
-      return ((KV<Object, ?>) pair).getKey();
-    }
-
-    @Override
-    public Object getValueFromInputPair(Object pair) {
-      return ((KV<?, ?>) pair).getValue();
-    }
-
-    @Override
-    public Object makeOutputPair(Object key, Object value) {
-      return KV.of(key, value);
-    }
-  }
 }