You are viewing a plain-text version of this content; the link to the canonical version is not preserved in this copy.
Posted to commits@druid.apache.org by cw...@apache.org on 2021/05/12 08:22:17 UTC

[druid] branch master updated: add estimated byte size limit enforcement for heap based expression aggregator (#11236)

This is an automated email from the ASF dual-hosted git repository.

cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 790262e  add estimated byte size limit enforcement for heap based expression aggregator (#11236)
790262e is described below

commit 790262e5d0a8b537531291580895bc2e84946721
Author: Clint Wylie <cw...@apache.org>
AuthorDate: Wed May 12 01:21:50 2021 -0700

    add estimated byte size limit enforcement for heap based expression aggregator (#11236)
---
 .../java/org/apache/druid/math/expr/ExprEval.java  | 66 ++++++++++++++-
 .../org/apache/druid/math/expr/ExprEvalTest.java   | 99 ++++++++++++++++++----
 .../aggregation/ExpressionLambdaAggregator.java    |  9 +-
 .../ExpressionLambdaAggregatorFactory.java         |  3 +-
 .../timeseries/TimeseriesQueryRunnerTest.java      | 38 +++++++++
 5 files changed, 194 insertions(+), 21 deletions(-)

diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
index a2ef91f..8c23487 100644
--- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
+++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java
@@ -31,6 +31,7 @@ import javax.annotation.Nullable;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Objects;
 
 /**
  * Generic result holder for evaluated {@link Expr} containing the value and {@link ExprType} of the value to allow
@@ -279,7 +280,7 @@ public abstract class ExprEval<T>
     }
   }
 
-  private static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes)
+  public static void checkMaxBytes(ExprType type, int sizeBytes, int maxSizeBytes)
   {
     if (sizeBytes > maxSizeBytes) {
       throw new ISE("Unable to serialize [%s], size [%s] is larger than max [%s]", type, sizeBytes, maxSizeBytes);
@@ -287,6 +288,65 @@ public abstract class ExprEval<T>
   }
 
   /**
+   * Used to estimate the size in bytes to {@link #serialize} the {@link ExprEval} value, checking against a maximum
+   * size and failing with an {@link ISE} if the estimate is over the maximum.
+   */
+  public static void estimateAndCheckMaxBytes(ExprEval eval, int maxSizeBytes)
+  {
+    final int estimated;
+    switch (eval.type()) {
+      case STRING:
+        String stringValue = eval.asString();
+        estimated = 1 + Integer.BYTES + (stringValue == null ? 0 : StringUtils.estimatedBinaryLengthAsUTF8(stringValue));
+        break;
+      case LONG:
+      case DOUBLE:
+        estimated = 1 + (NullHandling.sqlCompatible() ? 1 + Long.BYTES : Long.BYTES);
+        break;
+      case STRING_ARRAY:
+        String[] stringArray = eval.asStringArray();
+        if (stringArray == null) {
+          estimated = 1 + Integer.BYTES;
+        } else {
+          final int elementsSize = Arrays.stream(stringArray)
+                                         .filter(Objects::nonNull)
+                                         .mapToInt(StringUtils::estimatedBinaryLengthAsUTF8)
+                                         .sum();
+          // since each value is variably sized, there is an integer per element
+          estimated = 1 + Integer.BYTES + (Integer.BYTES * stringArray.length) + elementsSize;
+        }
+        break;
+      case LONG_ARRAY:
+        Long[] longArray = eval.asLongArray();
+        if (longArray == null) {
+          estimated = 1 + Integer.BYTES;
+        } else {
+          final int elementsSize = Arrays.stream(longArray)
+                                         .filter(Objects::nonNull)
+                                         .mapToInt(x -> Long.BYTES)
+                                         .sum();
+          estimated = 1 + Integer.BYTES + (NullHandling.sqlCompatible() ? longArray.length : 0) + elementsSize;
+        }
+        break;
+      case DOUBLE_ARRAY:
+        Double[] doubleArray = eval.asDoubleArray();
+        if (doubleArray == null) {
+          estimated = 1 + Integer.BYTES;
+        } else {
+          final int elementsSize = Arrays.stream(doubleArray)
+                                         .filter(Objects::nonNull)
+                                         .mapToInt(x -> Long.BYTES)
+                                         .sum();
+          estimated = 1 + Integer.BYTES + (NullHandling.sqlCompatible() ? doubleArray.length : 0) + elementsSize;
+        }
+        break;
+      default:
+        throw new IllegalStateException("impossible");
+    }
+    checkMaxBytes(eval.type(), estimated, maxSizeBytes);
+  }
+
+  /**
    * Converts a List to an appropriate array type, optionally doing some conversion to make multi-valued strings
    * consistent across selector types, which are not consistent in treatment of null, [], and [null].
    *
@@ -1120,7 +1180,9 @@ public abstract class ExprEval<T>
     @Override
     public String[] asStringArray()
     {
-      return value == null ? null : Arrays.stream(value).map(x -> x != null ? x.toString() : null).toArray(String[]::new);
+      return value == null
+             ? null
+             : Arrays.stream(value).map(x -> x != null ? x.toString() : null).toArray(String[]::new);
     }
 
     @Nullable
diff --git a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
index b15f321..ae8b5e3 100644
--- a/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
+++ b/core/src/test/java/org/apache/druid/math/expr/ExprEvalTest.java
@@ -20,6 +20,7 @@
 package org.apache.druid.math.expr;
 
 import com.google.common.collect.ImmutableList;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.testing.InitializedNullHandlingTest;
@@ -53,7 +54,12 @@ public class ExprEvalTest extends InitializedNullHandlingTest
   public void testStringSerdeTooBig()
   {
     expectedException.expect(ISE.class);
-    expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING, 16, 10));
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.STRING,
+        16,
+        10
+    ));
     assertExpr(0, ExprEval.of("hello world"), 10);
   }
 
@@ -77,49 +83,104 @@ public class ExprEvalTest extends InitializedNullHandlingTest
   @Test
   public void testStringArraySerde()
   {
-    assertExpr(0, new String[] {"hello", "hi", "hey"});
-    assertExpr(1024, new String[] {"hello", null, "hi", "hey"});
-    assertExpr(2048, new String[] {});
+    assertExpr(0, new String[]{"hello", "hi", "hey"});
+    assertExpr(1024, new String[]{"hello", null, "hi", "hey"});
+    assertExpr(2048, new String[]{});
   }
 
   @Test
   public void testStringArraySerdeToBig()
   {
     expectedException.expect(ISE.class);
-    expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.STRING_ARRAY, 14, 10));
-    assertExpr(0, ExprEval.ofStringArray(new String[] {"hello", "hi", "hey"}), 10);
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.STRING_ARRAY,
+        14,
+        10
+    ));
+    assertExpr(0, ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
+  }
+
+  @Test
+  public void testStringArrayEvalToBig()
+  {
+    expectedException.expect(ISE.class);
+    // this has a different failure size than string serde because it doesn't check incrementally
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.STRING_ARRAY,
+        27,
+        10
+    ));
+    assertEstimatedBytes(ExprEval.ofStringArray(new String[]{"hello", "hi", "hey"}), 10);
   }
 
   @Test
   public void testLongArraySerde()
   {
-    assertExpr(0, new Long[] {1L, 2L, 3L});
-    assertExpr(1234, new Long[] {1L, 2L, null, 3L});
-    assertExpr(1234, new Long[] {});
+    assertExpr(0, new Long[]{1L, 2L, 3L});
+    assertExpr(1234, new Long[]{1L, 2L, null, 3L});
+    assertExpr(1234, new Long[]{});
   }
 
   @Test
   public void testLongArraySerdeTooBig()
   {
     expectedException.expect(ISE.class);
-    expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.LONG_ARRAY, 29, 10));
-    assertExpr(0, ExprEval.ofLongArray(new Long[] {1L, 2L, 3L}), 10);
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.LONG_ARRAY,
+        29,
+        10
+    ));
+    assertExpr(0, ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
+  }
+
+  @Test
+  public void testLongArrayEvalTooBig()
+  {
+    expectedException.expect(ISE.class);
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.LONG_ARRAY,
+        NullHandling.sqlCompatible() ? 32 : 29,
+        10
+    ));
+    assertEstimatedBytes(ExprEval.ofLongArray(new Long[]{1L, 2L, 3L}), 10);
   }
 
   @Test
   public void testDoubleArraySerde()
   {
-    assertExpr(0, new Double[] {1.1, 2.2, 3.3});
-    assertExpr(1234, new Double[] {1.1, 2.2, null, 3.3});
-    assertExpr(1234, new Double[] {});
+    assertExpr(0, new Double[]{1.1, 2.2, 3.3});
+    assertExpr(1234, new Double[]{1.1, 2.2, null, 3.3});
+    assertExpr(1234, new Double[]{});
   }
 
   @Test
   public void testDoubleArraySerdeTooBig()
   {
     expectedException.expect(ISE.class);
-    expectedException.expectMessage(StringUtils.format("Unable to serialize [%s], size [%s] is larger than max [%s]", ExprType.DOUBLE_ARRAY, 29, 10));
-    assertExpr(0, ExprEval.ofDoubleArray(new Double[] {1.1, 2.2, 3.3}), 10);
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.DOUBLE_ARRAY,
+        29,
+        10
+    ));
+    assertExpr(0, ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
+  }
+
+  @Test
+  public void testDoubleArrayEvalTooBig()
+  {
+    expectedException.expect(ISE.class);
+    expectedException.expectMessage(StringUtils.format(
+        "Unable to serialize [%s], size [%s] is larger than max [%s]",
+        ExprType.DOUBLE_ARRAY,
+        NullHandling.sqlCompatible() ? 32 : 29,
+        10
+    ));
+    assertEstimatedBytes(ExprEval.ofDoubleArray(new Double[]{1.1, 2.2, 3.3}), 10);
   }
 
   @Test
@@ -216,5 +277,11 @@ public class ExprEvalTest extends InitializedNullHandlingTest
     } else {
       Assert.assertEquals(expected.value(), ExprEval.deserialize(buffer, position).value());
     }
+    assertEstimatedBytes(expected, maxSizeBytes);
+  }
+
+  private void assertEstimatedBytes(ExprEval eval, int maxSizeBytes)
+  {
+    ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
   }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
index 0305c8a..59bd6f2 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregator.java
@@ -20,6 +20,7 @@
 package org.apache.druid.query.aggregation;
 
 import org.apache.druid.math.expr.Expr;
+import org.apache.druid.math.expr.ExprEval;
 
 import javax.annotation.Nullable;
 
@@ -27,17 +28,21 @@ public class ExpressionLambdaAggregator implements Aggregator
 {
   private final Expr lambda;
   private final ExpressionLambdaAggregatorInputBindings bindings;
+  private final int maxSizeBytes;
 
-  public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings)
+  public ExpressionLambdaAggregator(Expr lambda, ExpressionLambdaAggregatorInputBindings bindings, int maxSizeBytes)
   {
     this.lambda = lambda;
     this.bindings = bindings;
+    this.maxSizeBytes = maxSizeBytes;
   }
 
   @Override
   public void aggregate()
   {
-    bindings.accumulate(lambda.eval(bindings));
+    final ExprEval<?> eval = lambda.eval(bindings);
+    ExprEval.estimateAndCheckMaxBytes(eval, maxSizeBytes);
+    bindings.accumulate(eval);
   }
 
   @Nullable
diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
index be8b100..e40000d 100644
--- a/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
+++ b/processing/src/main/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactory.java
@@ -248,7 +248,8 @@ public class ExpressionLambdaAggregatorFactory extends AggregatorFactory
     FactorizePlan thePlan = new FactorizePlan(metricFactory);
     return new ExpressionLambdaAggregator(
         thePlan.getExpression(),
-        thePlan.getBindings()
+        thePlan.getBindings(),
+        maxSizeBytes.getBytesInInt()
     );
   }
 
diff --git a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
index bc4173a..806d5d8 100644
--- a/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
+++ b/processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryRunnerTest.java
@@ -27,6 +27,7 @@ import com.google.common.collect.Lists;
 import com.google.common.primitives.Doubles;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.HumanReadableBytes;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.granularity.Granularities;
@@ -3039,6 +3040,43 @@ public class TimeseriesQueryRunnerTest extends InitializedNullHandlingTest
   }
 
   @Test
+  public void testTimeseriesWithExpressionAggregatorTooBig()
+  {
+    // expression agg cannot vectorize
+    cannotVectorize();
+    if (!vectorize) {
+      // size bytes when it overshoots varies slightly between algorithms
+      expectedException.expectMessage("Unable to serialize [STRING_ARRAY]");
+    }
+    TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
+                                  .dataSource(QueryRunnerTestHelper.DATA_SOURCE)
+                                  .granularity(Granularities.DAY)
+                                  .intervals(QueryRunnerTestHelper.FIRST_TO_THIRD)
+                                  .aggregators(
+                                      Collections.singletonList(
+                                          new ExpressionLambdaAggregatorFactory(
+                                              "array_agg_distinct",
+                                              ImmutableSet.of(QueryRunnerTestHelper.MARKET_DIMENSION),
+                                              "acc",
+                                              "[]",
+                                              null,
+                                              "array_set_add(acc, market)",
+                                              "array_set_add_all(acc, array_agg_distinct)",
+                                              null,
+                                              null,
+                                              HumanReadableBytes.valueOf(10),
+                                              TestExprMacroTable.INSTANCE
+                                          )
+                                      )
+                                  )
+                                  .descending(descending)
+                                  .context(makeContext())
+                                  .build();
+
+    runner.run(QueryPlus.wrap(query)).toList();
+  }
+
+  @Test
   public void testTimeseriesCardinalityAggOnMultiStringExpression()
   {
     TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org