You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2019/06/10 07:27:07 UTC
[calcite] 01/02: [CALCITE-3123] In RelBuilder, eliminate duplicate aggregate calls
This is an automated email from the ASF dual-hosted git repository.
jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
commit e01ba5ab6e7c57348f9f7be2babf00ae007204b5
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Fri Jun 7 15:56:13 2019 -0700
[CALCITE-3123] In RelBuilder, eliminate duplicate aggregate calls
---
.../java/org/apache/calcite/tools/RelBuilder.java | 44 ++++++++++++++++++++--
.../org/apache/calcite/test/RelBuilderTest.java | 19 ++++++++++
2 files changed, 60 insertions(+), 3 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index c139f58..cdf71c4 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -1602,7 +1602,45 @@ public class RelBuilder {
for (ImmutableBitSet set : groupSets) {
assert groupSet.contains(set);
}
- RelNode aggregate = aggregateFactory.createAggregate(r,
+
+ if (Util.isDistinct(aggregateCalls)) {
+ return aggregate_(groupSet, groupSets, r, aggregateCalls,
+ registrar.extraNodes, frame.fields);
+ } else {
+ // There are duplicate aggregate calls.
+ final Set<AggregateCall> callSet = new HashSet<>();
+ final List<Integer> projects =
+ new ArrayList<>(Util.range(groupSet.cardinality()));
+ final List<AggregateCall> distinctAggregateCalls = new ArrayList<>();
+ for (AggregateCall aggregateCall : aggregateCalls) {
+ final int i;
+ if (callSet.add(aggregateCall)) {
+ i = distinctAggregateCalls.size();
+ distinctAggregateCalls.add(aggregateCall);
+ } else {
+ i = distinctAggregateCalls.indexOf(aggregateCall);
+ assert i >= 0;
+ }
+ projects.add(i);
+ }
+ aggregate_(groupSet, groupSets, r, distinctAggregateCalls,
+ registrar.extraNodes, frame.fields);
+ final List<RexNode> fields =
+ new ArrayList<>(fields(Util.range(groupSet.cardinality())));
+ for (Ord<Integer> project : Ord.zip(projects)) {
+ fields.add(alias(field(project.e), aggregateCalls.get(project.i).name));
+ }
+ return project(fields);
+ }
+ }
+
+ /** Finishes the implementation of {@link #aggregate} by creating an
+ * {@link Aggregate} and pushing it onto the stack. */
+ private RelBuilder aggregate_(ImmutableBitSet groupSet,
+ ImmutableList<ImmutableBitSet> groupSets, RelNode input,
+ List<AggregateCall> aggregateCalls, List<RexNode> extraNodes,
+ ImmutableList<Field> inFields) {
+ final RelNode aggregate = aggregateFactory.createAggregate(input,
groupSet, groupSets, aggregateCalls);
// build field list
@@ -1612,11 +1650,11 @@ public class RelBuilder {
int i = 0;
// first, group fields
for (Integer groupField : groupSet.asList()) {
- RexNode node = registrar.extraNodes.get(groupField);
+ RexNode node = extraNodes.get(groupField);
final SqlKind kind = node.getKind();
switch (kind) {
case INPUT_REF:
- fields.add(frame.fields.get(((RexInputRef) node).getIndex()));
+ fields.add(inFields.get(((RexInputRef) node).getIndex()));
break;
default:
String name = aggregateFields.get(i).getName();
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index a4f28af..c06bdeb 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -927,6 +927,25 @@ public class RelBuilderTest {
assertThat(root, hasTree(expected));
}
+ /** Tests that {@link RelBuilder#aggregate} eliminates duplicate aggregate
+ * calls (S1b repeats S1) and adds a {@code Project} that maps them back. */
+ @Test public void testAggregateEliminatesDuplicateCalls() {
+ final RelBuilder builder = RelBuilder.create(config().build());
+ RelNode root =
+ builder.scan("EMP")
+ .aggregate(builder.groupKey(), // NOTE(review): only an empty group key is covered
+ builder.sum(builder.field(1)).as("S1"),
+ builder.count().as("C"),
+ builder.sum(builder.field(2)).as("S2"),
+ builder.sum(builder.field(1)).as("S1b")) // same call as S1
+ .build();
+ final String expected = "" // S1b reuses the single SUM($1) output at $0
+ + "LogicalProject(S1=[$0], C=[$1], S2=[$2], S1b=[$0])\n"
+ + " LogicalAggregate(group=[{}], S1=[SUM($1)], C=[COUNT()], S2=[SUM($2)])\n"
+ + " LogicalTableScan(table=[[scott, EMP]])\n";
+ assertThat(root, hasTree(expected));
+ }
+
@Test public void testAggregateFilter() {
// Equivalent SQL:
// SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c