You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jc...@apache.org on 2019/03/12 22:54:04 UTC

[calcite] 01/01: [CALCITE-2912]

This is an automated email from the ASF dual-hosted git repository.

jcamacho pushed a commit to branch CALCITE-2912
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 1e81f956a4736b70c258b5594850236a4100e2a6
Author: Jesus Camacho Rodriguez <jc...@apache.org>
AuthorDate: Tue Mar 12 15:53:52 2019 -0700

    [CALCITE-2912]
---
 .../rules/AggregateProjectPullUpConstantsRule.java |  35 ++++--
 .../calcite/rel/rules/AggregateReduceRule.java     | 127 +++++++++++++++++++++
 .../apache/calcite/test/MaterializationTest.java   |  16 +++
 .../org/apache/calcite/test/RelOptRulesTest.java   |  10 +-
 .../org/apache/calcite/test/RelOptRulesTest.xml    |  31 ++---
 5 files changed, 192 insertions(+), 27 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java
index a12e6d1..9a65488 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateProjectPullUpConstantsRule.java
@@ -31,11 +31,13 @@ import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Pair;
 
+import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.NavigableMap;
@@ -100,7 +102,7 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule {
 
     assert !aggregate.indicator : "predicate ensured no grouping sets";
     final int groupCount = aggregate.getGroupCount();
-    if (groupCount == 1) {
+    if (groupCount < 1) {
       // No room for optimization since we cannot convert from non-empty
       // GROUP BY list to the empty one.
       return;
@@ -127,14 +129,7 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule {
       return;
     }
 
-    if (groupCount == map.size()) {
-      // At least a single item in group by is required.
-      // Otherwise "GROUP BY 1, 2" might be altered to "GROUP BY ()".
-      // Removing of the first element is not optimal here,
-      // however it will allow us to use fast path below (just trim
-      // groupCount).
-      map.remove(map.navigableKeySet().first());
-    }
+    final boolean empty = groupCount == map.size();
 
     ImmutableBitSet newGroupSet = aggregate.getGroupSet();
     for (int key : map.keySet()) {
@@ -154,7 +149,25 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule {
           aggCall.adaptTo(input, aggCall.getArgList(), aggCall.filterArg,
               groupCount, newGroupCount));
     }
-    relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls);
+
+    // Create aggregate operator.
+    if (empty) {
+      // If the group set becomes empty, add a COUNT(*) so that empty input can be filtered out
+      Aggregate tmpAggregate = (Aggregate) relBuilder
+          .aggregate(relBuilder.groupKey(), relBuilder.countStar(null))
+          .build();
+      newAggCalls.add(tmpAggregate.getAggCallList().get(0));
+      // Reset stack and create new aggregate call
+      relBuilder.push(tmpAggregate.getInput());
+      relBuilder.aggregate(relBuilder.groupKey(), newAggCalls);
+      // Add a filter on the new count(*) != 0
+      relBuilder.filter(
+          rexBuilder.makeCall(SqlStdOperatorTable.NOT_EQUALS,
+              relBuilder.field(relBuilder.peek().getRowType().getFieldCount() - 1),
+              rexBuilder.makeBigintLiteral(BigDecimal.ZERO)));
+    } else {
+      relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls);
+    }
 
     // Create a projection back again.
     List<Pair<RexNode, String>> projects = new ArrayList<>();
@@ -186,6 +199,8 @@ public class AggregateProjectPullUpConstantsRule extends RelOptRule {
       projects.add(Pair.of(expr, field.getName()));
     }
     relBuilder.project(Pair.left(projects), Pair.right(projects)); // inverse
+    // Create top Project fixing nullability of fields
+    relBuilder.convert(aggregate.getRowType(), false);
     call.transformTo(relBuilder.build());
   }
 
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java
new file mode 100644
index 0000000..4d237de
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceRule.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.RelFactories;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.tools.RelBuilder;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Planner rule that reduces aggregate functions in
+ * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms.
+ *
+ * <p>Rewrites:
+ * <ul>
+ *
+ * <li>COUNT(x) &rarr; COUNT(*) if x is not nullable
+ * </ul>
+ *
+ * <p>It also removes duplicate aggregate calls.
+ */
+public class AggregateReduceRule extends RelOptRule {
+
+  /** The singleton. */
+  public static final AggregateReduceRule INSTANCE =
+      new AggregateReduceRule();
+
+  /** Private constructor. */
+  private AggregateReduceRule() {
+    super(operand(LogicalAggregate.class, any()),
+        RelFactories.LOGICAL_BUILDER, null);
+  }
+
+  @Override public void onMatch(RelOptRuleCall call) {
+    final RelBuilder relBuilder = call.builder();
+    final Aggregate aggRel = call.rel(0);
+    final RexBuilder rexBuilder = aggRel.getCluster().getRexBuilder();
+
+    // We try to rewrite COUNT(x) into COUNT(*) if x is not nullable.
+    // We remove duplicate aggregate calls as well.
+    boolean rewrite = false;
+    boolean identity = true;
+    final Map<AggregateCall, Integer> mapping = new HashMap<>();
+    final List<Integer> indexes = new ArrayList<>();
+    final List<AggregateCall> aggCalls = aggRel.getAggCallList();
+    final List<AggregateCall> newAggCalls = new ArrayList<>(aggCalls.size());
+    int nextIdx = aggRel.getGroupCount() + aggRel.getIndicatorCount();
+    for (int i = 0; i < aggCalls.size(); i++) {
+      AggregateCall aggCall = aggCalls.get(i);
+      if (aggCall.getAggregation().getKind() == SqlKind.COUNT && !aggCall.isDistinct()) {
+        final List<Integer> args = aggCall.getArgList();
+        final List<Integer> nullableArgs = new ArrayList<>(args.size());
+        for (int arg : args) {
+          if (aggRel.getInput().getRowType().getFieldList().get(arg).getType().isNullable()) {
+            nullableArgs.add(arg);
+          }
+        }
+        if (nullableArgs.size() != args.size()) {
+          aggCall = aggCall.copy(nullableArgs, aggCall.filterArg, aggCall.collation);
+          rewrite = true;
+        }
+      }
+      Integer idx = mapping.get(aggCall);
+      if (idx == null) {
+        newAggCalls.add(aggCall);
+        idx = nextIdx++;
+        mapping.put(aggCall, idx);
+      } else {
+        rewrite = true;
+        identity = false;
+      }
+      indexes.add(idx);
+    }
+
+    if (rewrite) {
+      // We trigger the transform
+      final Aggregate newAggregate = aggRel.copy(aggRel.getTraitSet(), aggRel.getInput(),
+          aggRel.indicator, aggRel.getGroupSet(), aggRel.getGroupSets(),
+          newAggCalls);
+      if (identity) {
+        call.transformTo(newAggregate);
+      } else {
+        final int offset = aggRel.getGroupCount() + aggRel.getIndicatorCount();
+        final List<RexNode> projList = new ArrayList<>();
+        for (int i = 0; i < offset; ++i) {
+          projList.add(
+              rexBuilder.makeInputRef(
+                  aggRel.getRowType().getFieldList().get(i).getType(), i));
+        }
+        for (int i = offset; i < aggRel.getRowType().getFieldCount(); ++i) {
+          projList.add(
+              rexBuilder.makeInputRef(
+                  aggRel.getRowType().getFieldList().get(i).getType(), indexes.get(i - offset)));
+        }
+        call.transformTo(relBuilder.push(newAggregate).project(projList).build());
+      }
+    }
+  }
+
+}
+
+// End AggregateReduceRule.java
diff --git a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
index 3126db2..5192bbe 100644
--- a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
+++ b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
@@ -2193,6 +2193,22 @@ public class MaterializationTest {
                 + "    EnumerableTableScan(table=[[hr, m0]]"));
   }
 
+//  @Test public void testAggregateMaterializationWithConstantFilter() {
+//    checkMaterialize(
+//        "select \"deptno\", \"name\", count(*) as c\n"
+//            + "from \"emps\" group by \"deptno\", \"name\"",
+//        "select \"name\", count(*) as c\n"
+//            + "from \"emps\" where \"name\" = 'a_name' group by \"name\"");
+//  }
+//
+//  @Test public void testAggregateMaterializationWithConstantFilter2() {
+//    checkMaterialize(
+//        "select \"deptno\", \"name\", \"salary\", count(*) as c\n"
+//            + "from \"emps\" group by \"deptno\", \"name\", \"salary\"",
+//        "select \"deptno\", \"name\", count(*) as c\n"
+//            + "from \"emps\" where \"name\" = 'a_name' group by \"deptno\", \"name\"");
+//  }
+
   @Test public void testMaterializationSubstitution() {
     String q = "select *\n"
         + "from (select * from \"emps\" where \"empid\" < 300)\n"
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index ef2b705..a8ec4fd 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -56,6 +56,7 @@ import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
 import org.apache.calcite.rel.rules.AggregateProjectMergeRule;
 import org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule;
 import org.apache.calcite.rel.rules.AggregateReduceFunctionsRule;
+import org.apache.calcite.rel.rules.AggregateReduceRule;
 import org.apache.calcite.rel.rules.AggregateUnionAggregateRule;
 import org.apache.calcite.rel.rules.AggregateUnionTransposeRule;
 import org.apache.calcite.rel.rules.AggregateValuesRule;
@@ -3822,17 +3823,20 @@ public class RelOptRulesTest extends RelOptTestBase {
     checkPlanning(new HepPlanner(program), sql);
   }
 
-  /** Tests {@link AggregateProjectPullUpConstantsRule} where reduction is not
-   * possible because "deptno" is the only key. */
+  /** Tests {@link AggregateProjectPullUpConstantsRule} where all columns can be
+   * reduced. */
   @Test public void testAggregateConstantKeyRule2() {
     final HepProgram program = new HepProgramBuilder()
         .addRuleInstance(AggregateProjectPullUpConstantsRule.INSTANCE2)
+        .addRuleInstance(AggregateReduceRule.INSTANCE)
+        .addRuleInstance(FilterProjectTransposeRule.INSTANCE)
+        .addRuleInstance(ProjectMergeRule.INSTANCE)
         .build();
     final String sql = "select count(*) as c\n"
         + "from sales.emp\n"
         + "where deptno = 10\n"
         + "group by deptno";
-    checkPlanUnchanged(new HepPlanner(program), sql);
+    checkPlanning(new HepPlanner(program), sql);
   }
 
   /** Tests {@link AggregateProjectPullUpConstantsRule} where both keys are
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index c94842c..fb03df2 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -4047,10 +4047,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], MGR=[$3])
-      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[$0])
+  LogicalFilter(condition=[<>($1, 0)])
+    LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()])
+      LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], MGR=[$3])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
     </TestCase>
@@ -4069,10 +4070,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], FIVE=[5])
-      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[CAST($0):INTEGER NOT NULL])
+  LogicalFilter(condition=[<>($1, 0)])
+    LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()])
+      LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], FIVE=[5])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
     </TestCase>
@@ -4091,10 +4093,11 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], $f2=[5])
-      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], EXPR$2=[CAST($0):INTEGER NOT NULL])
+  LogicalFilter(condition=[<>($1, 0)])
+    LogicalAggregate(group=[{}], EXPR$2=[MAX($2)], agg#1=[COUNT()])
+      LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], $f2=[5])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
     </TestCase>
@@ -7977,8 +7980,8 @@ LogicalProject(C=[$1])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalProject(C=[$1])
-  LogicalAggregate(group=[{0}], C=[COUNT()])
+LogicalFilter(condition=[<>($0, 0)])
+  LogicalAggregate(group=[{}], C=[COUNT()])
     LogicalProject(DEPTNO=[$7])
       LogicalFilter(condition=[=($7, 10)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])