You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@calcite.apache.org by jh...@apache.org on 2019/04/12 20:28:02 UTC

[calcite] branch master updated: [CALCITE-896] Remove Aggregate if grouping columns are unique and all functions are splittable

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ed7637  [CALCITE-896] Remove Aggregate if grouping columns are unique and all functions are splittable
3ed7637 is described below

commit 3ed76375fd05b29db17d6117aa0487ccb85f45ba
Author: Haisheng Yuan <h....@alibaba-inc.com>
AuthorDate: Thu Feb 28 17:10:34 2019 -0600

    [CALCITE-896] Remove Aggregate if grouping columns are unique and all functions are splittable
    
    Close apache/calcite#1078
---
 .../calcite/rel/rules/AggregateRemoveRule.java     |  77 +++++++++----
 .../calcite/sql/SqlSplittableAggFunction.java      |  22 +++-
 .../org/apache/calcite/test/RelOptRulesTest.java   |  87 +++++++++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 121 +++++++++++++++++++++
 core/src/test/resources/sql/sub-query.iq           |   5 +-
 5 files changed, 284 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateRemoveRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateRemoveRule.java
index 11a6d1c..1cfd621 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateRemoveRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateRemoveRule.java
@@ -20,18 +20,29 @@ import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.runtime.SqlFunctions;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
 /**
  * Planner rule that removes
  * a {@link org.apache.calcite.rel.core.Aggregate}
  * if it computes no aggregate functions
- * (that is, it is implementing {@code SELECT DISTINCT})
+ * (that is, it is implementing {@code SELECT DISTINCT}),
+ * or all the aggregate functions are splittable,
  * and the underlying relational expression is already distinct.
  */
 public class AggregateRemoveRule extends RelOptRule {
@@ -51,41 +62,65 @@ public class AggregateRemoveRule extends RelOptRule {
    */
   public AggregateRemoveRule(Class<? extends Aggregate> aggregateClass,
       RelBuilderFactory relBuilderFactory) {
-    // REVIEW jvs 14-Mar-2006: We have to explicitly mention the child here
-    // to make sure the rule re-fires after the child changes (e.g. via
-    // ProjectRemoveRule), since that may change our information
-    // about whether the child is distinct.  If we clean up the inference of
-    // distinct to make it correct up-front, we can get rid of the reference
-    // to the child here.
     super(
-        operand(aggregateClass,
-            operand(RelNode.class, any())),
-        relBuilderFactory, null);
+        operandJ(aggregateClass, null, agg -> isAggregateSupported(agg),
+            any()), relBuilderFactory, null);
+  }
+
+  private static boolean isAggregateSupported(Aggregate aggregate) {
+    if (aggregate.getGroupType() != Aggregate.Group.SIMPLE
+        || aggregate.getGroupCount() == 0) {
+      return false;
+    }
+    // If any aggregate functions do not support splitting, bail out.
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      if (aggregateCall.filterArg >= 0
+          || aggregateCall.getAggregation()
+              .unwrap(SqlSplittableAggFunction.class) == null) {
+        return false;
+      }
+    }
+    return true;
   }
 
   //~ Methods ----------------------------------------------------------------
 
   public void onMatch(RelOptRuleCall call) {
     final Aggregate aggregate = call.rel(0);
-    final RelNode input = call.rel(1);
-    if (!aggregate.getAggCallList().isEmpty() || aggregate.indicator) {
-      return;
-    }
+    final RelNode input = aggregate.getInput();
     final RelMetadataQuery mq = call.getMetadataQuery();
     if (!SqlFunctions.isTrue(mq.areColumnsUnique(input, aggregate.getGroupSet()))) {
       return;
     }
-    // Distinct is "GROUP BY c1, c2" (where c1, c2 are a set of columns on
-    // which the input is unique, i.e. contain a key) and has no aggregate
-    // functions. It can be removed.
-    final RelNode newInput = convert(input, aggregate.getTraitSet().simplify());
 
-    // If aggregate was projecting a subset of columns, add a project for the
-    // same effect.
     final RelBuilder relBuilder = call.builder();
+    final RexBuilder rexBuilder = relBuilder.getRexBuilder();
+    final List<RexNode> projects = new ArrayList<>();
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      final SqlAggFunction aggregation = aggCall.getAggregation();
+      if (aggregation.getKind() == SqlKind.SUM0) {
+        // Bail out for SUM0 to avoid potential infinite rule matching,
+        // because it may be generated by transforming SUM aggregate
+        // function to SUM0 and COUNT.
+        return;
+      }
+      final SqlSplittableAggFunction splitter =
+          Objects.requireNonNull(
+              aggregation.unwrap(SqlSplittableAggFunction.class));
+      final RexNode singleton = splitter.singleton(
+          rexBuilder, input.getRowType(), aggCall);
+      projects.add(singleton);
+    }
+
+    final RelNode newInput = convert(input, aggregate.getTraitSet().simplify());
     relBuilder.push(newInput);
-    if (newInput.getRowType().getFieldCount()
+    if (!projects.isEmpty()) {
+      projects.addAll(0, relBuilder.fields(aggregate.getGroupSet().asList()));
+      relBuilder.project(projects);
+    } else if (newInput.getRowType().getFieldCount()
         > aggregate.getRowType().getFieldCount()) {
+      // If aggregate was projecting a subset of columns, and there were no
+      // aggregate functions, add a project for the same effect.
       relBuilder.project(relBuilder.fields(aggregate.getGroupSet().asList()));
     }
     call.transformTo(relBuilder.build());
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java b/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
index 43457e2..015e257 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
@@ -186,12 +186,13 @@ public interface SqlSplittableAggFunction {
       }
       final RexNode predicate =
           RexUtil.composeConjunction(rexBuilder, predicates, true);
+      final RexNode rexOne = rexBuilder.makeExactLiteral(
+          BigDecimal.ONE, aggregateCall.getType());
       if (predicate == null) {
-        return rexBuilder.makeExactLiteral(BigDecimal.ONE);
+        return rexOne;
       } else {
-        return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate,
-            rexBuilder.makeExactLiteral(BigDecimal.ONE),
-            rexBuilder.makeExactLiteral(BigDecimal.ZERO));
+        return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate, rexOne,
+            rexBuilder.makeExactLiteral(BigDecimal.ZERO, aggregateCall.getType()));
       }
     }
 
@@ -340,6 +341,19 @@ public interface SqlSplittableAggFunction {
     @Override public SqlAggFunction getMergeAggFunctionOfTopSplit() {
       return SqlStdOperatorTable.SUM0;
     }
+
+    @Override public RexNode singleton(RexBuilder rexBuilder,
+        RelDataType inputRowType, AggregateCall aggregateCall) {
+      final int arg = aggregateCall.getArgList().get(0);
+      final RelDataType type = inputRowType.getFieldList().get(arg).getType();
+      final RexNode inputRef = rexBuilder.makeInputRef(type, arg);
+      if (type.isNullable()) {
+        return rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, inputRef,
+            rexBuilder.makeExactLiteral(BigDecimal.ZERO, type));
+      } else {
+        return inputRef;
+      }
+    }
   }
 }
 
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 5e87fc8..b7f21f3 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -61,6 +61,7 @@ import org.apache.calcite.rel.rules.AggregateMergeRule;
 import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
 import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule;
 import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
+import org.apache.calcite.rel.rules.AggregateRemoveRule;
 import org.apache.calcite.rel.rules.AggregateUnionAggregateRule;
 import org.apache.calcite.rel.rules.AggregateUnionTransposeRule;
 import org.apache.calcite.rel.rules.AggregateValuesRule;
@@ -3815,6 +3816,92 @@ public class RelOptRulesTest extends RelOptTestBase {
     checkPlanning(tester, preProgram, new HepPlanner(program), sql);
   }
 
+  /**
+   * Test case for AggregateRemoveRule, should remove aggregates since
+   * empno is unique and all aggregate functions are splittable.
+   */
+  @Test public void testAggregateRemove1() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select empno, sum(sal), min(sal), max(sal), "
+        + "bit_and(distinct sal), bit_or(sal), count(distinct sal) "
+        + "from sales.emp group by empno, deptno\n";
+    checkPlanning(program, sql);
+  }
+
+  /**
+   * Test case for AggregateRemoveRule, should remove aggregates since
+   * empno is unique and there are no aggregate functions.
+   */
+  @Test public void testAggregateRemove2() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select distinct empno, deptno from sales.emp\n";
+    checkPlanning(program, sql);
+  }
+
+  /**
+   * Test case for AggregateRemoveRule, should remove aggregates since
+   * empno is unique and all aggregate functions are splittable. Count
+   * aggregate function should be transformed to CASE function call
+   * because mgr is nullable.
+   */
+  @Test public void testAggregateRemove3() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select empno, count(mgr) "
+        + "from sales.emp group by empno, deptno\n";
+    checkPlanning(program, sql);
+  }
+
+  /**
+   * Negative test case for AggregateRemoveRule, should not
+   * remove aggregate because avg is not splittable.
+   */
+  @Test public void testAggregateRemove4() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select empno, max(sal), avg(sal) "
+        + "from sales.emp group by empno, deptno\n";
+    checkPlanUnchanged(new HepPlanner(program), sql);
+  }
+
+  /**
+   * Negative test case for AggregateRemoveRule, should not
+   * remove non-simple aggregates.
+   */
+  @Test public void testAggregateRemove5() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select empno, deptno, sum(sal) "
+        + "from sales.emp group by cube(empno, deptno)\n";
+    checkPlanUnchanged(new HepPlanner(program), sql);
+  }
+
+  /**
+   * Negative test case for AggregateRemoveRule, should not
+   * remove aggregate because deptno is not unique.
+   */
+  @Test public void testAggregateRemove6() {
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateRemoveRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
+        .build();
+    final String sql = "select deptno, max(sal) "
+        + "from sales.emp group by deptno\n";
+    checkPlanUnchanged(new HepPlanner(program), sql);
+  }
+
   @Test public void testSwapOuterJoin() {
     final HepProgram program = new HepProgramBuilder()
         .addMatchLimit(1)
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 582ab00..ef28033 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -7087,6 +7087,127 @@ LogicalAggregate(group=[{}], X=[SUM($5)], Z=[MIN($5)])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testAggregateRemove1">
+        <Resource name="sql">
+            <![CDATA[select empno, sum(sal), min(sal), max(sal),
+  bit_and(distinct sal), bit_or(sal), count(distinct sal)
+  from sales.emp group by empno, ename]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[$5], EXPR$5=[$6], EXPR$6=[$7])
+  LogicalAggregate(group=[{0, 1}], EXPR$1=[SUM($2)], EXPR$2=[MIN($2)], EXPR$3=[MAX($2)], EXPR$4=[BIT_AND(DISTINCT $2)], EXPR$5=[BIT_OR($2)], EXPR$6=[COUNT(DISTINCT $2)])
+    LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[$5], EXPR$2=[$5], EXPR$3=[$5], EXPR$4=[$5], EXPR$5=[$5], EXPR$6=[1:BIGINT])
+  LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testAggregateRemove2">
+        <Resource name="sql">
+            <![CDATA[select distinct empno, deptno from sales.emp]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0, 1}])
+  LogicalProject(EMPNO=[$0], DEPTNO=[$7])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], DEPTNO=[$7])
+  LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testAggregateRemove3">
+        <Resource name="sql">
+            <![CDATA[select empno, count(mgr)
+            from sales.emp group by empno, deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[$2])
+  LogicalAggregate(group=[{0, 1}], EXPR$1=[COUNT($2)])
+    LogicalProject(EMPNO=[$0], DEPTNO=[$7], MGR=[$3])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[CASE(IS NOT NULL($3), 1:BIGINT, 0:BIGINT)])
+  LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testAggregateRemove4">
+        <Resource name="sql">
+            <![CDATA[select empno, max(sal), avg(sal)
+            from sales.emp group by empno, deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[$2], EXPR$2=[$3])
+  LogicalAggregate(group=[{0, 1}], EXPR$1=[MAX($2)], EXPR$2=[AVG($2)])
+    LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], EXPR$1=[$2], EXPR$2=[$3])
+  LogicalAggregate(group=[{0, 1}], EXPR$1=[MAX($2)], EXPR$2=[AVG($2)])
+    LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testAggregateRemove5">
+        <Resource name="sql">
+            <![CDATA[select empno, deptno, sum(sal)
+            from sales.emp group by cube(empno, deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {1}, {}]], EXPR$2=[SUM($2)])
+  LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{0, 1}], groups=[[{0, 1}, {0}, {1}, {}]], EXPR$2=[SUM($2)])
+  LogicalProject(EMPNO=[$0], DEPTNO=[$7], SAL=[$5])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testAggregateRemove6">
+        <Resource name="sql">
+            <![CDATA[select deptno, max(sal)
+            from sales.emp group by deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], SAL=[$5])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], SAL=[$5])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testReduceNullableCase2">
         <Resource name="sql">
             <![CDATA[SELECT deptno, ename, CASE WHEN 1=2 THEN substring(ename, 1, cast(2 as int)) ELSE NULL end from emp group by deptno, ename, case when 1=2 then substring(ename,1, cast(2 as int))  else null end]]>
diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq
index 77492d3..d639ae9 100644
--- a/core/src/test/resources/sql/sub-query.iq
+++ b/core/src/test/resources/sql/sub-query.iq
@@ -2052,9 +2052,8 @@ EnumerableAggregate(group=[{}], C=[COUNT()])
       EnumerableJoin(condition=[=($1, $3)], joinType=[left])
         EnumerableCalc(expr#0..7=[{inputs}], proj#0..1=[{exprs}], SAL=[$t5])
           EnumerableTableScan(table=[[scott, EMP]])
-        EnumerableAggregate(group=[{1}], c=[COUNT()], ck=[COUNT($0)])
-          EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t1)], proj#0..2=[{exprs}], $condition=[$t3])
-            EnumerableTableScan(table=[[scott, DEPT]])
+        EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1:BIGINT], expr#4=[IS NOT NULL($t1)], DNAME=[$t1], $f1=[$t3], $f2=[$t3], $condition=[$t4])
+          EnumerableTableScan(table=[[scott, DEPT]])
       EnumerableCalc(expr#0..4=[{inputs}], DEPTNO=[$t2], i=[$t3], DNAME=[$t4], SAL=[$t0])
         EnumerableJoin(condition=[=($1, $2)], joinType=[inner])
           EnumerableCalc(expr#0=[{inputs}], expr#1=[100], expr#2=[+($t0, $t1)], SAL=[$t0], $f1=[$t2])