You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by ab...@apache.org on 2022/10/03 06:38:32 UTC

[druid] branch master updated: Update ClusterByStatisticsCollectorImpl to use bytes instead of keys (#12998)

This is an automated email from the ASF dual-hosted git repository.

abhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 92d2633ae6 Update ClusterByStatisticsCollectorImpl to use bytes instead of keys (#12998)
92d2633ae6 is described below

commit 92d2633ae6808116c52d5b813403fd1f5b84309c
Author: Adarsh Sanjeev <ad...@gmail.com>
AuthorDate: Mon Oct 3 12:08:23 2022 +0530

    Update ClusterByStatisticsCollectorImpl to use bytes instead of keys (#12998)
    
    * Update clusterByStatistics to use bytes instead of keys
    
    * Address review comments
    
    * Resolve checkstyle
    
    * Increase test coverage
    
    * Update test
    
    * Update thresholds
    
    * Update retained keys function
    
    * Update docs
    
    * Fix spelling
---
 docs/multi-stage-query/concepts.md                 |  3 +
 .../apache/druid/msq/kernel/StageDefinition.java   |  4 +-
 .../ClusterByStatisticsCollectorImpl.java          | 78 +++++++++++-----------
 .../msq/statistics/DelegateOrMinKeyCollector.java  | 10 +++
 .../druid/msq/statistics/DistinctKeyCollector.java | 32 ++++++---
 .../apache/druid/msq/statistics/KeyCollector.java  |  6 ++
 .../statistics/QuantilesSketchKeyCollector.java    | 38 +++++++++--
 .../QuantilesSketchKeyCollectorFactory.java        | 12 ++--
 .../QuantilesSketchKeyCollectorSnapshot.java       | 20 ++++--
 .../ClusterByStatisticsCollectorImplTest.java      |  4 +-
 .../statistics/DelegateOrMinKeyCollectorTest.java  | 28 +++++---
 .../msq/statistics/DistinctKeyCollectorTest.java   | 14 ++--
 .../QuantilesSketchKeyCollectorSnapshotTest.java   | 39 +++++++++++
 .../QuantilesSketchKeyCollectorTest.java           | 45 ++++++++++++-
 .../java/org/apache/druid/frame/key/RowKey.java    |  5 ++
 .../org/apache/druid/frame/key/RowKeyTest.java     | 13 ++++
 16 files changed, 261 insertions(+), 90 deletions(-)

diff --git a/docs/multi-stage-query/concepts.md b/docs/multi-stage-query/concepts.md
index edf2d9111f..5955b5fc14 100644
--- a/docs/multi-stage-query/concepts.md
+++ b/docs/multi-stage-query/concepts.md
@@ -252,6 +252,9 @@ Worker tasks use both JVM heap memory and off-heap ("direct") memory.
 On Peons launched by Middle Managers, the bulk of the JVM heap (75%) is split up into two bundles of equal size: one
 processor bundle and one worker bundle. Each one comprises 37.5% of the available JVM heap.
 
+Depending on the type of query, each worker and controller task can use a sketch for generating partition boundaries.
+Each sketch uses at most approximately 300 MB.
+
 The processor memory bundle is used for query processing and segment generation. Each processor bundle must also
 provides space to buffer I/O between stages. Specifically, each downstream stage requires 1 MB of buffer space for each
 upstream worker. For example, if you have 100 workers running in stage 0, and stage 1 reads from stage 0, then each
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
index 21fc56bd34..c01506054d 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/StageDefinition.java
@@ -74,7 +74,7 @@ import java.util.function.Supplier;
  */
 public class StageDefinition
 {
-  private static final int PARTITION_STATS_MAX_KEYS = 2 << 15; // Avoid immediate downsample of single-bucket collectors
+  private static final int PARTITION_STATS_MAX_BYTES = 300_000_000; // Avoid immediate downsample of single-bucket collectors
   private static final int PARTITION_STATS_MAX_BUCKETS = 5_000; // Limit for TooManyBuckets
   private static final int MAX_PARTITIONS = 25_000; // Limit for TooManyPartitions
 
@@ -289,7 +289,7 @@ public class StageDefinition
     return ClusterByStatisticsCollectorImpl.create(
         shuffleSpec.getClusterBy(),
         signature,
-        PARTITION_STATS_MAX_KEYS,
+        PARTITION_STATS_MAX_BYTES,
         PARTITION_STATS_MAX_BUCKETS,
         shuffleSpec.doesAggregateByClusterKey(),
         shuffleCheckHasMultipleValues
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImpl.java
index 02d1036cc1..9e033c8749 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImpl.java
@@ -56,17 +56,15 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
 
   private final boolean[] hasMultipleValues;
 
-  // This can be reworked to accommodate maxSize instead of maxRetainedKeys to account for the skewness in the size of hte
-  // keys depending on the datasource
-  private final int maxRetainedKeys;
+  private final int maxRetainedBytes;
   private final int maxBuckets;
-  private int totalRetainedKeys;
+  private double totalRetainedBytes;
 
   private ClusterByStatisticsCollectorImpl(
       final ClusterBy clusterBy,
       final RowKeyReader keyReader,
       final KeyCollectorFactory<?, ?> keyCollectorFactory,
-      final int maxRetainedKeys,
+      final int maxRetainedBytes,
       final int maxBuckets,
       final boolean checkHasMultipleValues
   )
@@ -74,21 +72,21 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
     this.clusterBy = clusterBy;
     this.keyReader = keyReader;
     this.keyCollectorFactory = keyCollectorFactory;
-    this.maxRetainedKeys = maxRetainedKeys;
+    this.maxRetainedBytes = maxRetainedBytes;
     this.buckets = new TreeMap<>(clusterBy.bucketComparator());
     this.maxBuckets = maxBuckets;
     this.checkHasMultipleValues = checkHasMultipleValues;
     this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null;
 
-    if (maxBuckets > maxRetainedKeys) {
-      throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedKeys[%s]", maxBuckets, maxRetainedKeys);
+    if (maxBuckets > maxRetainedBytes) {
+      throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedBytes[%s]", maxBuckets, maxRetainedBytes);
     }
   }
 
   public static ClusterByStatisticsCollector create(
       final ClusterBy clusterBy,
       final RowSignature signature,
-      final int maxRetainedKeys,
+      final int maxRetainedBytes,
       final int maxBuckets,
       final boolean aggregate,
       final boolean checkHasMultipleValues
@@ -101,7 +99,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
         clusterBy,
         keyReader,
         keyCollectorFactory,
-        maxRetainedKeys,
+        maxRetainedBytes,
         maxBuckets,
         checkHasMultipleValues
     );
@@ -126,8 +124,8 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
 
     bucketHolder.keyCollector.add(key, weight);
 
-    totalRetainedKeys += bucketHolder.updateRetainedKeys();
-    if (totalRetainedKeys > maxRetainedKeys) {
+    totalRetainedBytes += bucketHolder.updateRetainedBytes();
+    if (totalRetainedBytes > maxRetainedBytes) {
       downSample();
     }
 
@@ -147,15 +145,15 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
         //noinspection rawtypes, unchecked
         ((KeyCollector) bucketHolder.keyCollector).addAll(otherBucketEntry.getValue().keyCollector);
 
-        totalRetainedKeys += bucketHolder.updateRetainedKeys();
-        if (totalRetainedKeys > maxRetainedKeys) {
+        totalRetainedBytes += bucketHolder.updateRetainedBytes();
+        if (totalRetainedBytes > maxRetainedBytes) {
           downSample();
         }
       }
 
       if (checkHasMultipleValues) {
         for (int i = 0; i < clusterBy.getColumns().size(); i++) {
-          hasMultipleValues[i] |= that.hasMultipleValues[i];
+          hasMultipleValues[i] = hasMultipleValues[i] || that.hasMultipleValues[i];
         }
       }
     } else {
@@ -178,8 +176,8 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
       //noinspection rawtypes, unchecked
       ((KeyCollector) bucketHolder.keyCollector).addAll(otherKeyCollector);
 
-      totalRetainedKeys += bucketHolder.updateRetainedKeys();
-      if (totalRetainedKeys > maxRetainedKeys) {
+      totalRetainedBytes += bucketHolder.updateRetainedBytes();
+      if (totalRetainedBytes > maxRetainedBytes) {
         downSample();
       }
     }
@@ -221,7 +219,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
   public ClusterByStatisticsCollector clear()
   {
     buckets.clear();
-    totalRetainedKeys = 0;
+    totalRetainedBytes = 0;
     return this;
   }
 
@@ -232,7 +230,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
       throw new IAE("Target weight must be positive");
     }
 
-    assertRetainedKeyCountsAreTrackedCorrectly();
+    assertRetainedByteCountsAreTrackedCorrectly();
 
     if (buckets.isEmpty()) {
       return ClusterByPartitions.oneUniversalPartition();
@@ -315,7 +313,7 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
   @Override
   public ClusterByStatisticsSnapshot snapshot()
   {
-    assertRetainedKeyCountsAreTrackedCorrectly();
+    assertRetainedByteCountsAreTrackedCorrectly();
 
     final List<ClusterByStatisticsSnapshot.Bucket> bucketSnapshots = new ArrayList<>();
 
@@ -365,20 +363,20 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
   }
 
   /**
-   * Reduce the number of retained keys by about half, if possible. May reduce by less than that, or keep the
+   * Reduce the number of retained bytes by about half, if possible. May reduce by less than that, or keep the
    * number the same, if downsampling is not possible. (For example: downsampling is not possible if all buckets
    * have been downsampled all the way to one key each.)
    */
   private void downSample()
   {
-    int newTotalRetainedKeys = totalRetainedKeys;
-    final int targetTotalRetainedKeys = totalRetainedKeys / 2;
+    double newTotalRetainedBytes = totalRetainedBytes;
+    final double targetTotalRetainedBytes = totalRetainedBytes / 2;
 
     final List<BucketHolder> sortedHolders = new ArrayList<>(buckets.size());
 
     // Only consider holders with more than one retained key. Holders with a single retained key cannot be downsampled.
     for (final BucketHolder holder : buckets.values()) {
-      if (holder.retainedKeys > 1) {
+      if (holder.keyCollector.estimatedRetainedKeys() > 1) {
         sortedHolders.add(holder);
       }
     }
@@ -386,54 +384,54 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl
     // Downsample least-dense buckets first. (They're less likely to need high resolution.)
     sortedHolders.sort(
         Comparator.comparing((BucketHolder holder) ->
-                                 (double) holder.keyCollector.estimatedTotalWeight() / holder.retainedKeys)
+                                 (double) holder.keyCollector.estimatedTotalWeight() / holder.keyCollector.estimatedRetainedKeys())
     );
 
     int i = 0;
-    while (i < sortedHolders.size() && newTotalRetainedKeys > targetTotalRetainedKeys) {
+    while (i < sortedHolders.size() && newTotalRetainedBytes > targetTotalRetainedBytes) {
       final BucketHolder bucketHolder = sortedHolders.get(i);
 
       // Ignore false return, because we wrap all collectors in DelegateOrMinKeyCollector and can be assured that
       // it will downsample all the way to one if needed. Can't do better than that.
       bucketHolder.keyCollector.downSample();
-      newTotalRetainedKeys += bucketHolder.updateRetainedKeys();
+      newTotalRetainedBytes += bucketHolder.updateRetainedBytes();
 
-      if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedKeys > bucketHolder.retainedKeys) {
+      if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedBytes > bucketHolder.retainedBytes) {
         i++;
       }
     }
 
-    totalRetainedKeys = newTotalRetainedKeys;
+    totalRetainedBytes = newTotalRetainedBytes;
   }
 
-  private void assertRetainedKeyCountsAreTrackedCorrectly()
+  private void assertRetainedByteCountsAreTrackedCorrectly()
   {
     // Check cached value of retainedKeys in each holder.
     assert buckets.values()
                   .stream()
-                  .allMatch(holder -> holder.retainedKeys == holder.keyCollector.estimatedRetainedKeys());
+                  .allMatch(holder -> holder.retainedBytes == holder.keyCollector.estimatedRetainedBytes());
 
-    // Check cached value of totalRetainedKeys.
-    assert totalRetainedKeys ==
-           buckets.values().stream().mapToInt(holder -> holder.keyCollector.estimatedRetainedKeys()).sum();
+    // Check cached value of totalRetainedBytes.
+    assert totalRetainedBytes ==
+           buckets.values().stream().mapToDouble(holder -> holder.keyCollector.estimatedRetainedBytes()).sum();
   }
 
   private static class BucketHolder
   {
     private final KeyCollector<?> keyCollector;
-    private int retainedKeys;
+    private double retainedBytes;
 
     public BucketHolder(final KeyCollector<?> keyCollector)
     {
       this.keyCollector = keyCollector;
-      this.retainedKeys = keyCollector.estimatedRetainedKeys();
+      this.retainedBytes = keyCollector.estimatedRetainedBytes();
     }
 
-    public int updateRetainedKeys()
+    public double updateRetainedBytes()
     {
-      final int newRetainedKeys = keyCollector.estimatedRetainedKeys();
-      final int difference = newRetainedKeys - retainedKeys;
-      retainedKeys = newRetainedKeys;
+      final double newRetainedBytes = keyCollector.estimatedRetainedBytes();
+      final double difference = newRetainedBytes - retainedBytes;
+      retainedBytes = newRetainedBytes;
       return difference;
     }
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollector.java
index 32936e41c2..179c2bc3ae 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollector.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollector.java
@@ -127,6 +127,16 @@ public class DelegateOrMinKeyCollector<TDelegate extends KeyCollector<TDelegate>
     }
   }
 
+  @Override
+  public double estimatedRetainedBytes()
+  {
+    if (delegate != null) {
+      return delegate.estimatedRetainedBytes();
+    } else {
+      return minKey != null ? minKey.getNumberOfBytes() : 0;
+    }
+  }
+
   @Override
   public boolean downSample()
   {
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DistinctKeyCollector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DistinctKeyCollector.java
index c27bef375f..6868597667 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DistinctKeyCollector.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/DistinctKeyCollector.java
@@ -43,8 +43,8 @@ import java.util.Map;
  */
 public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
 {
-  static final int INITIAL_MAX_KEYS = 2 << 15 /* 65,536 */;
-  static final int SMALLEST_MAX_KEYS = 16;
+  static final int INITIAL_MAX_BYTES = 134_217_728;
+  static final int SMALLEST_MAX_BYTES = 5000;
   private static final int MISSING_KEY_WEIGHT = 0;
 
   private final Comparator<RowKey> comparator;
@@ -71,7 +71,8 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
    * collector type, which is based on a more solid statistical foundation.
    */
   private final Object2LongSortedMap<RowKey> retainedKeys;
-  private int maxKeys;
+  private int maxBytes;
+  private int retainedBytes;
 
   /**
    * Each key is retained with probability 2^(-spaceReductionFactor). This value is incremented on calls to
@@ -92,7 +93,7 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
     this.comparator = Preconditions.checkNotNull(comparator, "comparator");
     this.retainedKeys = Preconditions.checkNotNull(retainedKeys, "retainedKeys");
     this.retainedKeys.defaultReturnValue(MISSING_KEY_WEIGHT);
-    this.maxKeys = INITIAL_MAX_KEYS;
+    this.maxBytes = INITIAL_MAX_BYTES;
     this.spaceReductionFactor = spaceReductionFactor;
     this.totalWeightUnadjusted = 0;
 
@@ -120,14 +121,16 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
       if (isNewMin && !retainedKeys.isEmpty() && !isKeySelected(retainedKeys.firstKey())) {
         // Old min should be kicked out.
         totalWeightUnadjusted -= retainedKeys.removeLong(retainedKeys.firstKey());
+        retainedBytes -= retainedKeys.firstKey().getNumberOfBytes();
       }
 
       if (retainedKeys.putIfAbsent(key, weight) == MISSING_KEY_WEIGHT) {
         // We did add this key. (Previous value was zero, meaning absent.)
         totalWeightUnadjusted += weight;
+        retainedBytes += key.getNumberOfBytes();
       }
 
-      while (retainedKeys.size() >= maxKeys) {
+      while (retainedBytes >= maxBytes) {
         increaseSpaceReductionFactorIfPossible();
       }
     }
@@ -168,6 +171,12 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
     return retainedKeys.size();
   }
 
+  @Override
+  public double estimatedRetainedBytes()
+  {
+    return retainedBytes;
+  }
+
   @Override
   public RowKey minKey()
   {
@@ -182,13 +191,13 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
       return true;
     }
 
-    if (maxKeys == SMALLEST_MAX_KEYS) {
+    if (maxBytes <= SMALLEST_MAX_BYTES) {
       return false;
     }
 
-    maxKeys /= 2;
+    maxBytes /= 2;
 
-    while (retainedKeys.size() >= maxKeys) {
+    while (retainedBytes >= maxBytes) {
       if (!increaseSpaceReductionFactorIfPossible()) {
         return false;
       }
@@ -242,10 +251,10 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
     return retainedKeys;
   }
 
-  @JsonProperty("maxKeys")
-  int getMaxKeys()
+  @JsonProperty("maxBytes")
+  int getMaxBytes()
   {
-    return maxKeys;
+    return maxBytes;
   }
 
   @JsonProperty("spaceReductionFactor")
@@ -296,6 +305,7 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
 
       if (!isKeySelected(key)) {
         totalWeightUnadjusted -= entry.getLongValue();
+        retainedBytes -= entry.getKey().getNumberOfBytes();
         iterator.remove();
       }
     }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/KeyCollector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/KeyCollector.java
index 1aada32a21..48287e74be 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/KeyCollector.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/KeyCollector.java
@@ -53,6 +53,12 @@ public interface KeyCollector<CollectorType extends KeyCollector<CollectorType>>
    */
   int estimatedRetainedKeys();
 
+  /**
+   * Returns an estimate of the number of bytes currently retained by this collector. This may change over time as
+   * more keys are added.
+   */
+  double estimatedRetainedBytes();
+
   /**
    * Downsample this collector, dropping about half of the keys that are currently retained. Returns true if
    * the collector was downsampled, or if it is already retaining zero or one keys. Returns false if the collector is
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollector.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollector.java
index 99fb8a23e8..950f9419af 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollector.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollector.java
@@ -37,28 +37,39 @@ import java.util.NoSuchElementException;
 
 /**
  * A key collector that is used when not aggregating. It uses a quantiles sketch to track keys.
+ *
+ * The collector maintains the averageKeyLength for all keys added through {@link #add(RowKey, long)} or
+ * {@link #addAll(QuantilesSketchKeyCollector)}. The average is calculated as a running average and accounts for
+ * the weight of each key added. The averageKeyLength is assumed to be unaffected by {@link #downSample()}.
  */
 public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketchKeyCollector>
 {
   private final Comparator<RowKey> comparator;
   private ItemsSketch<RowKey> sketch;
+  private double averageKeyLength;
 
   QuantilesSketchKeyCollector(
       final Comparator<RowKey> comparator,
-      @Nullable final ItemsSketch<RowKey> sketch
+      @Nullable final ItemsSketch<RowKey> sketch,
+      double averageKeyLength
   )
   {
     this.comparator = comparator;
     this.sketch = sketch;
+    this.averageKeyLength = averageKeyLength;
   }
 
   @Override
   public void add(RowKey key, long weight)
   {
+    double estimatedTotalSketchSizeInBytes = averageKeyLength * sketch.getN();
+    // The key is added "weight" times to the sketch, so add its total size (key bytes * weight) in one step.
+    estimatedTotalSketchSizeInBytes += key.getNumberOfBytes() * weight;
     for (int i = 0; i < weight; i++) {
       // Add the same key multiple times to make it "heavier".
       sketch.update(key);
     }
+    averageKeyLength = (estimatedTotalSketchSizeInBytes / sketch.getN());
   }
 
   @Override
@@ -69,6 +80,10 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
         comparator
     );
 
+    double sketchBytesCount = averageKeyLength * sketch.getN();
+    double otherBytesCount = other.averageKeyLength * other.getSketch().getN();
+    averageKeyLength = ((sketchBytesCount + otherBytesCount) / (sketch.getN() + other.sketch.getN()));
+
     union.update(sketch);
     union.update(other.sketch);
     sketch = union.getResultAndReset();
@@ -87,14 +102,15 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
   }
 
   @Override
-  public int estimatedRetainedKeys()
+  public double estimatedRetainedBytes()
   {
-    // Rough estimation of retained keys for a given K for ~billions of total items, based on the table from
-    // https://datasketches.apache.org/docs/Quantiles/OrigQuantilesSketch.html.
-    final int estimatedMaxRetainedKeys = 11 * sketch.getK();
+    return averageKeyLength * estimatedRetainedKeys();
+  }
 
-    // Cast to int is safe because estimatedMaxRetainedKeys is always within int range.
-    return (int) Math.min(sketch.getN(), estimatedMaxRetainedKeys);
+  @Override
+  public int estimatedRetainedKeys()
+  {
+    return sketch.getRetainedItems();
   }
 
   @Override
@@ -165,4 +181,12 @@ public class QuantilesSketchKeyCollector implements KeyCollector<QuantilesSketch
   {
     return sketch;
   }
+
+  /**
+   * Retrieves the average key length. Exists for usage by {@link QuantilesSketchKeyCollectorFactory}.
+   */
+  double getAverageKeyLength()
+  {
+    return averageKeyLength;
+  }
 }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorFactory.java
index 613a7dc497..cfc2bd9a54 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorFactory.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorFactory.java
@@ -38,9 +38,9 @@ import java.util.Comparator;
 public class QuantilesSketchKeyCollectorFactory
     implements KeyCollectorFactory<QuantilesSketchKeyCollector, QuantilesSketchKeyCollectorSnapshot>
 {
-  // smallest value with normalized rank error < 0.1%; retain up to ~86k elements
+  // Maximum value of K possible.
   @VisibleForTesting
-  static final int SKETCH_INITIAL_K = 1 << 12;
+  static final int SKETCH_INITIAL_K = 1 << 15;
 
   private final Comparator<RowKey> comparator;
 
@@ -57,7 +57,7 @@ public class QuantilesSketchKeyCollectorFactory
   @Override
   public QuantilesSketchKeyCollector newKeyCollector()
   {
-    return new QuantilesSketchKeyCollector(comparator, ItemsSketch.getInstance(SKETCH_INITIAL_K, comparator));
+    return new QuantilesSketchKeyCollector(comparator, ItemsSketch.getInstance(SKETCH_INITIAL_K, comparator), 0);
   }
 
   @Override
@@ -79,7 +79,7 @@ public class QuantilesSketchKeyCollectorFactory
   {
     final String encodedSketch =
         StringUtils.encodeBase64String(collector.getSketch().toByteArray(RowKeySerde.INSTANCE));
-    return new QuantilesSketchKeyCollectorSnapshot(encodedSketch);
+    return new QuantilesSketchKeyCollectorSnapshot(encodedSketch, collector.getAverageKeyLength());
   }
 
   @Override
@@ -89,7 +89,7 @@ public class QuantilesSketchKeyCollectorFactory
     final byte[] bytes = StringUtils.decodeBase64String(encodedSketch);
     final ItemsSketch<RowKey> sketch =
         ItemsSketch.getInstance(Memory.wrap(bytes), comparator, RowKeySerde.INSTANCE);
-    return new QuantilesSketchKeyCollector(comparator, sketch);
+    return new QuantilesSketchKeyCollector(comparator, sketch, snapshot.getAverageKeyLength());
   }
 
   private static class RowKeySerde extends ArrayOfItemsSerDe<RowKey>
@@ -106,7 +106,7 @@ public class QuantilesSketchKeyCollectorFactory
       int serializedSize = Integer.BYTES * items.length;
 
       for (final RowKey key : items) {
-        serializedSize += key.array().length;
+        serializedSize += key.getNumberOfBytes();
       }
 
       final byte[] serializedBytes = new byte[serializedSize];
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshot.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshot.java
index 4e9fce437f..1b555ac3f9 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshot.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshot.java
@@ -20,7 +20,7 @@
 package org.apache.druid.msq.statistics;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
-import com.fasterxml.jackson.annotation.JsonValue;
+import com.fasterxml.jackson.annotation.JsonProperty;
 
 import java.util.Objects;
 
@@ -28,18 +28,27 @@ public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
 {
   private final String encodedSketch;
 
+  private final double averageKeyLength;
+
   @JsonCreator
-  public QuantilesSketchKeyCollectorSnapshot(String encodedSketch)
+  public QuantilesSketchKeyCollectorSnapshot(@JsonProperty("encodedSketch") String encodedSketch, @JsonProperty("averageKeyLength") double averageKeyLength)
   {
     this.encodedSketch = encodedSketch;
+    this.averageKeyLength = averageKeyLength;
   }
 
-  @JsonValue
+  @JsonProperty("encodedSketch")
   public String getEncodedSketch()
   {
     return encodedSketch;
   }
 
+  @JsonProperty("averageKeyLength")
+  public double getAverageKeyLength()
+  {
+    return averageKeyLength;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -50,12 +59,13 @@ public class QuantilesSketchKeyCollectorSnapshot implements KeyCollectorSnapshot
       return false;
     }
     QuantilesSketchKeyCollectorSnapshot that = (QuantilesSketchKeyCollectorSnapshot) o;
-    return Objects.equals(encodedSketch, that.encodedSketch);
+    return Objects.equals(encodedSketch, that.encodedSketch)
+           && Double.compare(that.averageKeyLength, averageKeyLength) == 0;
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(encodedSketch);
+    return Objects.hash(encodedSketch, averageKeyLength);
   }
 }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
index 6976aa687f..17aa0f204d 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/ClusterByStatisticsCollectorImplTest.java
@@ -80,7 +80,7 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
   );
 
   // These numbers are roughly 10x lower than authentic production numbers. (See StageDefinition.)
-  private static final int MAX_KEYS = 5000;
+  private static final int MAX_BYTES = 1_000_000;
   private static final int MAX_BUCKETS = 1000;
 
   @Test
@@ -598,7 +598,7 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
   private ClusterByStatisticsCollectorImpl makeCollector(final ClusterBy clusterBy, final boolean aggregate)
   {
     return (ClusterByStatisticsCollectorImpl)
-        ClusterByStatisticsCollectorImpl.create(clusterBy, SIGNATURE, MAX_KEYS, MAX_BUCKETS, aggregate, false);
+        ClusterByStatisticsCollectorImpl.create(clusterBy, SIGNATURE, MAX_BYTES, MAX_BUCKETS, aggregate, false);
   }
 
   private static void verifyPartitions(
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
index 09b52a3715..e054dcf98b 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DelegateOrMinKeyCollectorTest.java
@@ -58,7 +58,7 @@ public class DelegateOrMinKeyCollectorTest
     Assert.assertTrue(collector.getDelegate().isPresent());
     Assert.assertTrue(collector.isEmpty());
     Assert.assertThrows(NoSuchElementException.class, collector::minKey);
-    Assert.assertEquals(0, collector.estimatedRetainedKeys());
+    Assert.assertEquals(0, collector.estimatedRetainedBytes(), 0);
     Assert.assertEquals(0, collector.estimatedTotalWeight());
     MatcherAssert.assertThat(collector.getDelegate().get(), CoreMatchers.instanceOf(QuantilesSketchKeyCollector.class));
   }
@@ -83,12 +83,13 @@ public class DelegateOrMinKeyCollectorTest
             QuantilesSketchKeyCollectorFactory.create(clusterBy)
         ).newKeyCollector();
 
-    collector.add(createKey(1L), 1);
+    RowKey key = createKey(1L);
+    collector.add(key, 1);
 
     Assert.assertTrue(collector.getDelegate().isPresent());
     Assert.assertFalse(collector.isEmpty());
-    Assert.assertEquals(createKey(1L), collector.minKey());
-    Assert.assertEquals(1, collector.estimatedRetainedKeys());
+    Assert.assertEquals(key, collector.minKey());
+    Assert.assertEquals(key.getNumberOfBytes(), collector.estimatedRetainedBytes(), 0);
     Assert.assertEquals(1, collector.estimatedTotalWeight());
   }
 
@@ -101,13 +102,15 @@ public class DelegateOrMinKeyCollectorTest
             QuantilesSketchKeyCollectorFactory.create(clusterBy)
         ).newKeyCollector();
 
-    collector.add(createKey(1L), 1);
+    RowKey key = createKey(1L);
+
+    collector.add(key, 1);
     Assert.assertTrue(collector.downSample());
 
     Assert.assertTrue(collector.getDelegate().isPresent());
     Assert.assertFalse(collector.isEmpty());
-    Assert.assertEquals(createKey(1L), collector.minKey());
-    Assert.assertEquals(1, collector.estimatedRetainedKeys());
+    Assert.assertEquals(key, collector.minKey());
+    Assert.assertEquals(key.getNumberOfBytes(), collector.estimatedRetainedBytes(), 0);
     Assert.assertEquals(1, collector.estimatedTotalWeight());
 
     // Should not have actually downsampled, because the quantiles-based collector does nothing when
@@ -127,23 +130,26 @@ public class DelegateOrMinKeyCollectorTest
             QuantilesSketchKeyCollectorFactory.create(clusterBy)
         ).newKeyCollector();
 
-    collector.add(createKey(1L), 1);
-    collector.add(createKey(1L), 1);
+    RowKey key = createKey(1L);
+    collector.add(key, 1);
+    collector.add(key, 1);
+    int expectedRetainedBytes = 2 * key.getNumberOfBytes();
 
     Assert.assertTrue(collector.getDelegate().isPresent());
     Assert.assertFalse(collector.isEmpty());
     Assert.assertEquals(createKey(1L), collector.minKey());
-    Assert.assertEquals(2, collector.estimatedRetainedKeys());
+    Assert.assertEquals(expectedRetainedBytes, collector.estimatedRetainedBytes(), 0);
     Assert.assertEquals(2, collector.estimatedTotalWeight());
 
     while (collector.getDelegate().isPresent()) {
       Assert.assertTrue(collector.downSample());
     }
+    expectedRetainedBytes = key.getNumberOfBytes();
 
     Assert.assertFalse(collector.getDelegate().isPresent());
     Assert.assertFalse(collector.isEmpty());
     Assert.assertEquals(createKey(1L), collector.minKey());
-    Assert.assertEquals(1, collector.estimatedRetainedKeys());
+    Assert.assertEquals(expectedRetainedBytes, collector.estimatedRetainedBytes(), 0);
     Assert.assertEquals(1, collector.estimatedTotalWeight());
   }
 
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
index d853dc994f..6d3622612d 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/DistinctKeyCollectorTest.java
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.statistics;
 
 import com.google.common.collect.ImmutableList;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
@@ -43,6 +44,10 @@ public class DistinctKeyCollectorTest
   private final Comparator<RowKey> comparator = clusterBy.keyComparator();
   private final int numKeys = 500_000;
 
+  static {
+    NullHandling.initializeForTests();
+  }
+
   @Test
   public void test_empty()
   {
@@ -127,11 +132,11 @@ public class DistinctKeyCollectorTest
             // Intentionally empty loop body.
           }
 
-          Assert.assertEquals(DistinctKeyCollector.SMALLEST_MAX_KEYS, collector.getMaxKeys());
+          Assert.assertTrue(DistinctKeyCollector.SMALLEST_MAX_BYTES >= collector.getMaxBytes());
           MatcherAssert.assertThat(
               testName,
-              collector.estimatedRetainedKeys(),
-              Matchers.lessThanOrEqualTo(DistinctKeyCollector.SMALLEST_MAX_KEYS)
+              (int) collector.estimatedRetainedBytes(),
+              Matchers.lessThanOrEqualTo(DistinctKeyCollector.SMALLEST_MAX_BYTES)
           );
 
           // Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly
@@ -230,8 +235,7 @@ public class DistinctKeyCollectorTest
       final NavigableMap<RowKey, List<Integer>> sortedKeyWeights
   )
   {
-    Assert.assertEquals(collector.getRetainedKeys().size(), collector.estimatedRetainedKeys());
-    MatcherAssert.assertThat(collector.getRetainedKeys().size(), Matchers.lessThan(collector.getMaxKeys()));
+    MatcherAssert.assertThat((int) collector.estimatedRetainedBytes(), Matchers.lessThan(collector.getMaxBytes()));
 
     KeyCollectorTestUtils.verifyCollector(
         collector,
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshotTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshotTest.java
new file mode 100644
index 0000000000..e05f295188
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorSnapshotTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.statistics;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.druid.jackson.DefaultObjectMapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class QuantilesSketchKeyCollectorSnapshotTest
+{
+  private final ObjectMapper jsonMapper = new DefaultObjectMapper();
+
+  @Test
+  public void testSnapshotSerde() throws JsonProcessingException
+  {
+    QuantilesSketchKeyCollectorSnapshot snapshot = new QuantilesSketchKeyCollectorSnapshot("sketchString", 100);
+    String jsonStr = jsonMapper.writeValueAsString(snapshot);
+    Assert.assertEquals(snapshot, jsonMapper.readValue(jsonStr, QuantilesSketchKeyCollectorSnapshot.class));
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
index 974f79c7bf..0f8147eb92 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/statistics/QuantilesSketchKeyCollectorTest.java
@@ -24,9 +24,12 @@ import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyTestUtils;
 import org.apache.druid.frame.key.RowKey;
 import org.apache.druid.frame.key.SortColumn;
 import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -119,7 +122,7 @@ public class QuantilesSketchKeyCollectorTest
           }
 
           Assert.assertEquals(testName, 2, collector.getSketch().getK());
-          Assert.assertEquals(testName, 22, collector.estimatedRetainedKeys());
+          Assert.assertEquals(testName, 14, collector.estimatedRetainedKeys());
 
           // Don't use verifyCollector, since this collector is downsampled so aggressively that it can't possibly
           // hope to pass those tests. Grade on a curve.
@@ -161,6 +164,46 @@ public class QuantilesSketchKeyCollectorTest
     );
   }
 
+  @Test
+  public void testAverageKeyLength()
+  {
+    final QuantilesSketchKeyCollector collector = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
+
+    final QuantilesSketchKeyCollector other = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
+
+    RowSignature smallKeySignature = KeyTestUtils.createKeySignature(
+        new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0).getColumns(),
+        RowSignature.builder().add("x", ColumnType.LONG).build()
+    );
+    RowKey smallKey = KeyTestUtils.createKey(smallKeySignature, 1L);
+
+    RowSignature largeKeySignature = KeyTestUtils.createKeySignature(
+        new ClusterBy(
+            ImmutableList.of(
+                new SortColumn("x", false),
+                new SortColumn("y", false),
+                new SortColumn("z", false)
+            ),
+            0).getColumns(),
+        RowSignature.builder()
+                    .add("x", ColumnType.LONG)
+                    .add("y", ColumnType.LONG)
+                    .add("z", ColumnType.LONG)
+                    .build()
+    );
+    RowKey largeKey = KeyTestUtils.createKey(largeKeySignature, 1L, 2L, 3L);
+
+
+    collector.add(smallKey, 3);
+    Assert.assertEquals(smallKey.getNumberOfBytes(), collector.getAverageKeyLength(), 0);
+
+    other.add(largeKey, 5);
+    Assert.assertEquals(largeKey.getNumberOfBytes(), other.getAverageKeyLength(), 0);
+
+    collector.addAll(other);
+    Assert.assertEquals((smallKey.getNumberOfBytes() * 3 + largeKey.getNumberOfBytes() * 5) / 8.0, collector.getAverageKeyLength(), 0);
+  }
+
   @Test
   public void test_uniformRandomKeys_inverseBarbellWeighted()
   {
diff --git a/processing/src/main/java/org/apache/druid/frame/key/RowKey.java b/processing/src/main/java/org/apache/druid/frame/key/RowKey.java
index 498a23a46d..aa3701ba90 100644
--- a/processing/src/main/java/org/apache/druid/frame/key/RowKey.java
+++ b/processing/src/main/java/org/apache/druid/frame/key/RowKey.java
@@ -108,4 +108,9 @@ public class RowKey
   {
     return Arrays.toString(key);
   }
+
+  public int getNumberOfBytes()
+  {
+    return array().length;
+  }
 }
diff --git a/processing/src/test/java/org/apache/druid/frame/key/RowKeyTest.java b/processing/src/test/java/org/apache/druid/frame/key/RowKeyTest.java
index 0aa6a87a98..20e8fb981a 100644
--- a/processing/src/test/java/org/apache/druid/frame/key/RowKeyTest.java
+++ b/processing/src/test/java/org/apache/druid/frame/key/RowKeyTest.java
@@ -91,4 +91,17 @@ public class RowKeyTest extends InitializedNullHandlingTest
         KeyTestUtils.createKey(signatureLongString, 1L, "def").hashCode()
     );
   }
+
+  @Test
+  public void testGetNumberOfBytes()
+  {
+    final RowSignature signatureLong = RowSignature.builder().add("1", ColumnType.LONG).build();
+    final RowKey longKey = KeyTestUtils.createKey(signatureLong, 1L, "abc");
+    Assert.assertEquals(longKey.array().length, longKey.getNumberOfBytes());
+
+    final RowSignature signatureLongString =
+        RowSignature.builder().add("1", ColumnType.LONG).add("2", ColumnType.STRING).build();
+    final RowKey longStringKey = KeyTestUtils.createKey(signatureLongString, 1L, "abc");
+    Assert.assertEquals(longStringKey.array().length, longStringKey.getNumberOfBytes());
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org