You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2019/07/15 16:39:02 UTC

[calcite] 06/06: [CALCITE-3145] RelBuilder.aggregate throws IndexOutOfBoundsException if groupKey is non-empty and there are duplicate aggregate functions

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 0cce229903a845a7b8ed36cf86d6078fd82d73d3
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Mon Jun 24 13:01:37 2019 -0700

    [CALCITE-3145] RelBuilder.aggregate throws IndexOutOfBoundsException if groupKey is non-empty and there are duplicate aggregate functions
    
    The cause is that [CALCITE-3123] did not handle the case of a non-empty groupKey.
    
    Enable RelBuilder.Config.dedupAggregateCalls by default.
---
 .../java/org/apache/calcite/tools/RelBuilder.java  | 53 +++++++++++-----------
 .../org/apache/calcite/test/RelBuilderTest.java    | 47 +++++++++++++++----
 2 files changed, 64 insertions(+), 36 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 079c3de..f19a510 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -1634,32 +1634,33 @@ public class RelBuilder {
     if (!config.dedupAggregateCalls || Util.isDistinct(aggregateCalls)) {
       return aggregate_(groupSet, groupSets, r, aggregateCalls,
           registrar.extraNodes, frame.fields);
-    } else {
-      // There are duplicate aggregate calls.
-      final Set<AggregateCall> callSet = new HashSet<>();
-      final List<Integer> projects =
-          new ArrayList<>(Util.range(groupSet.cardinality()));
-      final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
-      for (AggregateCall aggregateCall : aggregateCalls) {
-        final int i;
-        if (callSet.add(aggregateCall)) {
-          i = distinctAggregateCalls.size();
-          distinctAggregateCalls.add(aggregateCall);
-        } else {
-          i = distinctAggregateCalls.indexOf(aggregateCall);
-          assert i >= 0;
-        }
-        projects.add(i);
-      }
-      aggregate_(groupSet, groupSets, r, distinctAggregateCalls,
-          registrar.extraNodes, frame.fields);
-      final List<RexNode> fields =
-          new ArrayList<>(fields(Util.range(groupSet.cardinality())));
-      for (Ord<Integer> project : Ord.zip(projects)) {
-        fields.add(alias(field(project.e), aggregateCalls.get(project.i).name));
+    }
+
+    // There are duplicate aggregate calls. Rebuild the list to eliminate
+    // duplicates, then add a Project.
+    final Set<AggregateCall> callSet = new HashSet<>();
+    final List<Pair<Integer, String>> projects = new ArrayList<>();
+    Util.range(groupSet.cardinality())
+        .forEach(i -> projects.add(Pair.of(i, null)));
+    final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
+    for (AggregateCall aggregateCall : aggregateCalls) {
+      final int i;
+      if (callSet.add(aggregateCall)) {
+        i = distinctAggregateCalls.size();
+        distinctAggregateCalls.add(aggregateCall);
+      } else {
+        i = distinctAggregateCalls.indexOf(aggregateCall);
+        assert i >= 0;
       }
-      return project(fields);
+      projects.add(Pair.of(groupSet.cardinality() + i, aggregateCall.name));
     }
+    aggregate_(groupSet, groupSets, r, distinctAggregateCalls,
+        registrar.extraNodes, frame.fields);
+    final List<RexNode> fields = projects.stream()
+        .map(p -> p.right == null ? field(p.left)
+            : alias(field(p.left), p.right))
+        .collect(Collectors.toList());
+    return project(fields);
   }
 
   /** Finishes the implementation of {@link #aggregate} by creating an
@@ -2787,10 +2788,10 @@ public class RelBuilder {
   public static class Config {
     /** Default configuration. */
     public static final Config DEFAULT =
-        new Config(false, true);
+        new Config(true, true);
 
     /** Whether {@link RelBuilder#aggregate} should eliminate duplicate
-     * aggregate calls; default true but currently disabled. */
+     * aggregate calls; default true. */
     public final boolean dedupAggregateCalls;
 
     /** Whether to simplify expressions; default true. */
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 1178a30..1f8f844 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -967,14 +967,6 @@ public class RelBuilderTest {
   /** Tests that {@link RelBuilder#aggregate} eliminates duplicate aggregate
    * calls and creates a {@code Project} to compensate. */
   @Test public void testAggregateEliminatesDuplicateCalls() {
-    final Function<RelBuilder, RelNode> fn = builder ->
-        builder.scan("EMP")
-            .aggregate(builder.groupKey(),
-                builder.sum(builder.field(1)).as("S1"),
-                builder.count().as("C"),
-                builder.sum(builder.field(2)).as("S2"),
-                builder.sum(builder.field(1)).as("S1b"))
-            .build();
     final RelBuilder builder =
         createBuilder(configBuilder ->
             configBuilder.withDedupAggregateCalls(true));
@@ -982,7 +974,7 @@ public class RelBuilderTest {
         + "LogicalProject(S1=[$0], C=[$1], S2=[$2], S1b=[$0])\n"
         + "  LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(fn.apply(builder), hasTree(expected));
+    assertThat(buildRelWithDuplicateAggregates(builder), hasTree(expected));
 
     // Now, disable the rewrite
     final RelBuilder builder2 =
@@ -991,7 +983,42 @@ public class RelBuilderTest {
     final String expected2 = ""
         + "LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)], S1b=[SUM($1)])\n"
         + "  LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(fn.apply(builder2), hasTree(expected2));
+    assertThat(buildRelWithDuplicateAggregates(builder2), hasTree(expected2));
+  }
+
+  /** As {@link #testAggregateEliminatesDuplicateCalls()} but with a
+   * single-column GROUP BY clause. */
+  @Test public void testAggregateEliminatesDuplicateCalls2() {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    RelNode root = buildRelWithDuplicateAggregates(builder, 0);
+    final String expected = ""
+        + "LogicalProject(EMPNO=[$0], S1=[$1], C=[$2], S2=[$3], S1b=[$1])\n"
+        + "  LogicalAggregate(group=[{0}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  /** As {@link #testAggregateEliminatesDuplicateCalls()} but with a
+   * multi-column GROUP BY clause. */
+  @Test public void testAggregateEliminatesDuplicateCalls3() {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    RelNode root = buildRelWithDuplicateAggregates(builder, 2, 0, 4, 3);
+    final String expected = ""
+        + "LogicalProject(EMPNO=[$0], JOB=[$1], MGR=[$2], HIREDATE=[$3], S1=[$4], C=[$5], S2=[$6], S1b=[$4])\n"
+        + "  LogicalAggregate(group=[{0, 2, 3, 4}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(root, hasTree(expected));
+  }
+
+  private RelNode buildRelWithDuplicateAggregates(RelBuilder builder,
+      int... groupFieldOrdinals) {
+    return builder.scan("EMP")
+        .aggregate(builder.groupKey(groupFieldOrdinals),
+            builder.sum(builder.field(1)).as("S1"),
+            builder.count().as("C"),
+            builder.sum(builder.field(2)).as("S2"),
+            builder.sum(builder.field(1)).as("S1b"))
+        .build();
   }
 
   @Test public void testAggregateFilter() {