You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2015/09/03 00:16:15 UTC

[27/50] incubator-calcite git commit: [CALCITE-751] Push aggregate with aggregate functions through join

[CALCITE-751] Push aggregate with aggregate functions through join

In this iteration, it is not safe to use the extended rule (that can handle
aggregate functions) in the Volcano planner, only in the Hep planner. The
extended rule requires metadata that can handle cyclic relational expressions
(to be fixed in [CALCITE-794]).


Project: http://git-wip-us.apache.org/repos/asf/incubator-calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-calcite/commit/cf7a7a97
Tree: http://git-wip-us.apache.org/repos/asf/incubator-calcite/tree/cf7a7a97
Diff: http://git-wip-us.apache.org/repos/asf/incubator-calcite/diff/cf7a7a97

Branch: refs/heads/branch-release
Commit: cf7a7a97368d6e72b2a413f9f8857f54c2970e61
Parents: 41541d4
Author: Julian Hyde <jh...@apache.org>
Authored: Thu Jun 4 17:22:28 2015 -0700
Committer: Julian Hyde <jh...@apache.org>
Committed: Thu Jul 23 12:46:58 2015 -0700

----------------------------------------------------------------------
 .../calcite/plan/SubstitutionVisitor.java       |  11 +-
 .../apache/calcite/rel/core/AggregateCall.java  |   5 +-
 .../rel/metadata/RelMdColumnUniqueness.java     |  66 ++++
 .../rel/rules/AggregateJoinTransposeRule.java   | 295 +++++++++++++++--
 .../rel/rules/AggregateReduceFunctionsRule.java |   8 +-
 .../java/org/apache/calcite/rex/RexBuilder.java |  28 +-
 .../org/apache/calcite/sql/SqlAggFunction.java  |   7 +-
 .../calcite/sql/SqlSplittableAggFunction.java   | 267 +++++++++++++++
 .../calcite/sql/fun/SqlCountAggFunction.java    |   8 +
 .../calcite/sql/fun/SqlMinMaxAggFunction.java   |   7 +
 .../calcite/sql/fun/SqlSumAggFunction.java      |   9 +
 .../sql/fun/SqlSumEmptyIsZeroAggFunction.java   |   8 +
 .../apache/calcite/test/RelOptRulesTest.java    | 111 +++++--
 .../org/apache/calcite/test/RelOptRulesTest.xml | 120 ++++++-
 core/src/test/resources/sql/agg.oq              | 323 ++++++++++++++++++-
 core/src/test/resources/sql/join.oq             |  11 +-
 .../apache/calcite/adapter/tpcds/TpcdsTest.java | 103 ++++++
 17 files changed, 1276 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
index 14e836a..28d7b32 100644
--- a/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
+++ b/core/src/main/java/org/apache/calcite/plan/SubstitutionVisitor.java
@@ -69,7 +69,6 @@ import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.LinkedHashMultimap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -1070,16 +1069,10 @@ public class SubstitutionVisitor {
   }
 
   public static MutableAggregate permute(MutableAggregate aggregate,
-      MutableRel input, final Mapping mapping) {
+      MutableRel input, Mapping mapping) {
     ImmutableBitSet groupSet = Mappings.apply(mapping, aggregate.getGroupSet());
     ImmutableList<ImmutableBitSet> groupSets =
-        ImmutableList.copyOf(
-            Iterables.transform(aggregate.getGroupSets(),
-                new Function<ImmutableBitSet, ImmutableBitSet>() {
-                  public ImmutableBitSet apply(ImmutableBitSet input1) {
-                    return Mappings.apply(mapping, input1);
-                  }
-                }));
+        Mappings.apply2(mapping, aggregate.getGroupSets());
     List<AggregateCall> aggregateCalls =
         apply(mapping, aggregate.getAggCallList());
     return MutableAggregate.of(input, aggregate.indicator, groupSet, groupSets,

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
index 2b56522..d35f7e1 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/AggregateCall.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.type.SqlTypeUtil;
+import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.base.Objects;
@@ -286,8 +287,8 @@ public class AggregateCall {
   /** Creates a copy of this aggregate call, applying a mapping to its
    * arguments. */
   public AggregateCall transform(Mappings.TargetMapping mapping) {
-    return copy(Mappings.permute(argList, mapping),
-        Mappings.apply(mapping, filterArg));
+    return copy(Mappings.apply2((Mapping) mapping, argList),
+        filterArg < 0 ? -1 : Mappings.apply(mapping, filterArg));
   }
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnUniqueness.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnUniqueness.java b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnUniqueness.java
index 527ab3f..9e0e0bd 100644
--- a/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnUniqueness.java
+++ b/core/src/main/java/org/apache/calcite/rel/metadata/RelMdColumnUniqueness.java
@@ -16,6 +16,8 @@
  */
 package org.apache.calcite.rel.metadata;
 
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.plan.volcano.RelSubset;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.Correlate;
@@ -35,6 +37,8 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.util.BuiltInMethod;
 import org.apache.calcite.util.ImmutableBitSet;
 
+import com.google.common.base.Predicate;
+
 import java.util.List;
 
 /**
@@ -258,6 +262,68 @@ public class RelMdColumnUniqueness {
     // no information available
     return null;
   }
+
+  public Boolean areColumnsUnique(
+      boolean dummy, // prevent method from being used
+      HepRelVertex rel,
+      ImmutableBitSet columns,
+      boolean ignoreNulls) {
+    return RelMetadataQuery.areColumnsUnique(
+        rel.getCurrentRel(),
+        columns,
+        ignoreNulls);
+  }
+
+  public Boolean areColumnsUnique(
+      boolean dummy, // prevent method from being used
+      RelSubset rel,
+      ImmutableBitSet columns,
+      boolean ignoreNulls) {
+    int nullCount = 0;
+    for (RelNode rel2 : rel.getRels()) {
+      if (rel2 instanceof Aggregate || simplyProjects(rel2, columns)) {
+        final Boolean unique =
+            RelMetadataQuery.areColumnsUnique(rel2, columns, ignoreNulls);
+        if (unique != null) {
+          if (unique) {
+            return true;
+          }
+        } else {
+          ++nullCount;
+        }
+      }
+    }
+    return nullCount == 0 ? false : null;
+  }
+
+  private boolean simplyProjects(RelNode rel, ImmutableBitSet columns) {
+    if (!(rel instanceof Project)) {
+      return false;
+    }
+    Project project = (Project) rel;
+    final List<RexNode> projects = project.getProjects();
+    for (int column : columns) {
+      if (column >= projects.size()) {
+        return false;
+      }
+      if (!(projects.get(column) instanceof RexInputRef)) {
+        return false;
+      }
+      final RexInputRef ref = (RexInputRef) projects.get(column);
+      if (ref.getIndex() != column) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /** Aggregate and Project are "safe" children of a RelSubset to delve into. */
+  private static final Predicate<RelNode> SAFE_REL =
+      new Predicate<RelNode>() {
+        public boolean apply(RelNode r) {
+          return r instanceof Aggregate || r instanceof Project;
+        }
+      };
 }
 
 // End RelMdColumnUniqueness.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
index 3d310a5..673d579 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
@@ -16,25 +16,42 @@
  */
 package org.apache.calcite.rel.rules;
 
+import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalJoin;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.Mapping;
 import org.apache.calcite.util.mapping.Mappings;
 
 import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
 
 /**
  * Planner rule that pushes an
@@ -48,29 +65,56 @@ public class AggregateJoinTransposeRule extends RelOptRule {
           LogicalJoin.class,
           RelFactories.DEFAULT_JOIN_FACTORY);
 
+  /** Extended instance of the rule that can push down aggregate functions. */
+  public static final AggregateJoinTransposeRule EXTENDED =
+      new AggregateJoinTransposeRule(LogicalAggregate.class,
+          RelFactories.DEFAULT_AGGREGATE_FACTORY,
+          LogicalJoin.class,
+          RelFactories.DEFAULT_JOIN_FACTORY, true);
+
   private final RelFactories.AggregateFactory aggregateFactory;
 
   private final RelFactories.JoinFactory joinFactory;
 
+  private final boolean allowFunctions;
+
   /** Creates an AggregateJoinTransposeRule. */
   public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass,
       RelFactories.AggregateFactory aggregateFactory,
       Class<? extends Join> joinClass,
       RelFactories.JoinFactory joinFactory) {
+    this(aggregateClass, aggregateFactory, joinClass, joinFactory, false);
+  }
+
+  /** Creates an AggregateJoinTransposeRule that may push down functions. */
+  public AggregateJoinTransposeRule(Class<? extends Aggregate> aggregateClass,
+      RelFactories.AggregateFactory aggregateFactory,
+      Class<? extends Join> joinClass,
+      RelFactories.JoinFactory joinFactory,
+      boolean allowFunctions) {
     super(
         operand(aggregateClass, null, Aggregate.IS_SIMPLE,
             operand(joinClass, any())));
     this.aggregateFactory = aggregateFactory;
     this.joinFactory = joinFactory;
+    this.allowFunctions = allowFunctions;
   }
 
   public void onMatch(RelOptRuleCall call) {
-    Aggregate aggregate = call.rel(0);
-    Join join = call.rel(1);
+    final Aggregate aggregate = call.rel(0);
+    final Join join = call.rel(1);
+    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
 
-    // If aggregate functions are present, we bail out
-    if (!aggregate.getAggCallList().isEmpty()) {
-      return;
+    // If any aggregate functions do not support splitting, bail out
+    // If any aggregate call has a filter, bail out
+    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+      if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class)
+          == null) {
+        return;
+      }
+      if (aggregateCall.filterArg >= 0) {
+        return;
+      }
     }
 
     // If it is not an inner join, we do not push the
@@ -79,14 +123,20 @@ public class AggregateJoinTransposeRule extends RelOptRule {
       return;
     }
 
+    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
+      return;
+    }
+
     // Do the columns used by the join appear in the output of the aggregate?
     final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
+    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
+        RelMetadataQuery.getPulledUpPredicates(join).pulledUpPredicates);
     final ImmutableBitSet joinColumns =
         RelOptUtil.InputFinder.bits(join.getCondition());
-    boolean allColumnsInAggregate = aggregateColumns.contains(joinColumns);
-    if (!allColumnsInAggregate) {
-      return;
-    }
+    final boolean allColumnsInAggregate =
+        keyColumns.contains(joinColumns);
+    final ImmutableBitSet belowAggregateColumns =
+        aggregateColumns.union(joinColumns);
 
     // Split join condition
     final List<Integer> leftKeys = Lists.newArrayList();
@@ -99,32 +149,233 @@ public class AggregateJoinTransposeRule extends RelOptRule {
       return;
     }
 
-    // Create new aggregate operators below join
-    final ImmutableBitSet leftKeysBitSet = ImmutableBitSet.of(leftKeys);
-    RelNode newLeftInput = aggregateFactory.createAggregate(join.getLeft(),
-        false, leftKeysBitSet, null, aggregate.getAggCallList());
-    final ImmutableBitSet rightKeysBitSet = ImmutableBitSet.of(rightKeys);
-    RelNode newRightInput = aggregateFactory.createAggregate(join.getRight(),
-        false, rightKeysBitSet, null, aggregate.getAggCallList());
+    // Push each aggregate function down to each side that contains all of its
+    // arguments. Note that COUNT(*), because it has no arguments, can go to
+    // both sides.
+    final Map<Integer, Integer> map = new HashMap<>();
+    final List<Side> sides = new ArrayList<>();
+    int uniqueCount = 0;
+    int offset = 0;
+    int belowOffset = 0;
+    for (int s = 0; s < 2; s++) {
+      final Side side = new Side();
+      final RelNode joinInput = join.getInput(s);
+      int fieldCount = joinInput.getRowType().getFieldCount();
+      final ImmutableBitSet fieldSet =
+          ImmutableBitSet.range(offset, offset + fieldCount);
+      final ImmutableBitSet belowAggregateKeyNotShifted =
+          belowAggregateColumns.intersect(fieldSet);
+      for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
+        map.put(c.e, belowOffset + c.i);
+      }
+      final ImmutableBitSet belowAggregateKey =
+          belowAggregateKeyNotShifted.shift(-offset);
+      final boolean unique;
+      if (!allowFunctions) {
+        assert aggregate.getAggCallList().isEmpty();
+        // If there are no functions, it doesn't matter as much whether we
+        // aggregate the inputs before the join, because there will not be
+        // any functions experiencing a cartesian product effect.
+        //
+        // But finding out whether the input is already unique requires a call
+        // to areColumnsUnique that currently (until [CALCITE-794] "Detect
+        // cycles when computing statistics" is fixed) places a heavy load on
+        // the metadata system.
+        //
+        // So we choose to imagine that the input is already unique, which is
+        // untrue but harmless.
+        //
+        unique = true;
+      } else {
+        final Boolean unique0 =
+            RelMetadataQuery.areColumnsUnique(joinInput, belowAggregateKey);
+        unique = unique0 != null && unique0;
+      }
+      if (unique) {
+        ++uniqueCount;
+        side.aggregate = false;
+        side.newInput = joinInput;
+      } else {
+        side.aggregate = true;
+        List<AggregateCall> belowAggCalls = new ArrayList<>();
+        final SqlSplittableAggFunction.Registry<AggregateCall>
+            belowAggCallRegistry = registry(belowAggCalls);
+        final Mappings.TargetMapping mapping =
+            s == 0
+                ? Mappings.createIdentity(fieldCount)
+                : Mappings.createShiftMapping(fieldCount + offset, 0, offset,
+                    fieldCount);
+        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+          final SqlAggFunction aggregation = aggCall.e.getAggregation();
+          final SqlSplittableAggFunction splitter =
+              Preconditions.checkNotNull(
+                  aggregation.unwrap(SqlSplittableAggFunction.class));
+          final AggregateCall call1;
+          if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
+            call1 = splitter.split(aggCall.e, mapping);
+          } else {
+            call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
+          }
+          if (call1 != null) {
+            side.split.put(aggCall.i,
+                belowAggregateKey.cardinality()
+                    + belowAggCallRegistry.register(call1));
+          }
+        }
+        side.newInput = aggregateFactory.createAggregate(joinInput, false,
+            belowAggregateKey, null, belowAggCalls);
+      }
+      offset += fieldCount;
+      belowOffset += side.newInput.getRowType().getFieldCount();
+      sides.add(side);
+    }
+
+    if (uniqueCount == 2) {
+      // Both inputs to the join are unique. There is nothing to be gained by
+      // this rule. In fact, this aggregate+join may be the result of a previous
+      // invocation of this rule; if we continue we might loop forever.
+      return;
+    }
 
     // Update condition
-    final Mappings.TargetMapping mapping = Mappings.target(
+    final Mapping mapping = (Mapping) Mappings.target(
         new Function<Integer, Integer>() {
           public Integer apply(Integer a0) {
-            return aggregateColumns.indexOf(a0);
+            return map.get(a0);
           }
         },
         join.getRowType().getFieldCount(),
-        aggregateColumns.cardinality());
+        belowOffset);
     final RexNode newCondition =
         RexUtil.apply(mapping, join.getCondition());
 
     // Create new join
-    RelNode newJoin = joinFactory.createJoin(newLeftInput, newRightInput,
-        newCondition, join.getJoinType(),
+    RelNode newJoin = joinFactory.createJoin(sides.get(0).newInput,
+        sides.get(1).newInput, newCondition, join.getJoinType(),
         join.getVariablesStopped(), join.isSemiJoinDone());
 
-    call.transformTo(newJoin);
+    // Aggregate above to sum up the sub-totals
+    final List<AggregateCall> newAggCalls = new ArrayList<>();
+    final int groupIndicatorCount =
+        aggregate.getGroupCount() + aggregate.getIndicatorCount();
+    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
+    final List<RexNode> projects =
+        new ArrayList<>(rexBuilder.identityProjects(newJoin.getRowType()));
+    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
+      final SqlAggFunction aggregation = aggCall.e.getAggregation();
+      final SqlSplittableAggFunction splitter =
+          Preconditions.checkNotNull(
+              aggregation.unwrap(SqlSplittableAggFunction.class));
+      final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
+      final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
+      newAggCalls.add(
+          splitter.topSplit(rexBuilder, registry(projects),
+              groupIndicatorCount, newJoin.getRowType(), aggCall.e,
+              leftSubTotal == null ? -1 : leftSubTotal,
+              rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
+    }
+    RelNode r = newJoin;
+  b:
+    if (allColumnsInAggregate && newAggCalls.isEmpty()) {
+      // no need to aggregate
+    } else {
+      r = RelOptUtil.createProject(r, projects, null, true);
+      if (allColumnsInAggregate) {
+        // let's see if we can convert
+        List<RexNode> projects2 = new ArrayList<>();
+        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
+          projects2.add(rexBuilder.makeInputRef(r, key));
+        }
+        for (AggregateCall newAggCall : newAggCalls) {
+          final SqlSplittableAggFunction splitter =
+              newAggCall.getAggregation()
+                  .unwrap(SqlSplittableAggFunction.class);
+          if (splitter != null) {
+            projects2.add(
+                splitter.singleton(rexBuilder, r.getRowType(), newAggCall));
+          }
+        }
+        if (projects2.size()
+            == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
+          // We successfully converted agg calls into projects.
+          r = RelOptUtil.createProject(r, projects2, null, true);
+          break b;
+        }
+      }
+      r = aggregateFactory.createAggregate(r, aggregate.indicator,
+          Mappings.apply(mapping, aggregate.getGroupSet()),
+          Mappings.apply2(mapping, aggregate.getGroupSets()), newAggCalls);
+    }
+    call.transformTo(r);
+  }
+
+  /** Computes the closure of a set of columns according to a given list of
+   * constraints. Each 'x = y' constraint causes bit y to be set if bit x is
+   * set, and vice versa. */
+  private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns,
+      ImmutableList<RexNode> predicates) {
+    SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
+    for (RexNode pred : predicates) {
+      populateEquivalences(equivalence, pred);
+    }
+    ImmutableBitSet keyColumns = aggregateColumns;
+    for (Integer aggregateColumn : aggregateColumns) {
+      final BitSet bitSet = equivalence.get(aggregateColumn);
+      if (bitSet != null) {
+        keyColumns = keyColumns.union(bitSet);
+      }
+    }
+    return keyColumns;
+  }
+
+  private static void populateEquivalences(Map<Integer, BitSet> equivalence,
+      RexNode predicate) {
+    switch (predicate.getKind()) {
+    case EQUALS:
+      RexCall call = (RexCall) predicate;
+      final List<RexNode> operands = call.getOperands();
+      if (operands.get(0) instanceof RexInputRef) {
+        final RexInputRef ref0 = (RexInputRef) operands.get(0);
+        if (operands.get(1) instanceof RexInputRef) {
+          final RexInputRef ref1 = (RexInputRef) operands.get(1);
+          populateEquivalence(equivalence, ref0.getIndex(), ref1.getIndex());
+          populateEquivalence(equivalence, ref1.getIndex(), ref0.getIndex());
+        }
+      }
+    }
+  }
+
+  private static void populateEquivalence(Map<Integer, BitSet> equivalence,
+      int i0, int i1) {
+    BitSet bitSet = equivalence.get(i0);
+    if (bitSet == null) {
+      bitSet = new BitSet();
+      equivalence.put(i0, bitSet);
+    }
+    bitSet.set(i1);
+  }
+
+  /** Creates a {@link org.apache.calcite.sql.SqlSplittableAggFunction.Registry}
+   * that is a view of a list. */
+  private static <E> SqlSplittableAggFunction.Registry<E>
+  registry(final List<E> list) {
+    return new SqlSplittableAggFunction.Registry<E>() {
+      public int register(E e) {
+        int i = list.indexOf(e);
+        if (i < 0) {
+          i = list.size();
+          list.add(e);
+        }
+        return i;
+      }
+    };
+  }
+
+  /** Work space for an input to a join. */
+  private static class Side {
+    final Map<Integer, Integer> split = new HashMap<>();
+    RelNode newInput;
+    boolean aggregate;
   }
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 3a33e16..752c212 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -153,12 +153,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
     // will add an expression to the end, and we will create an extra
     // project.
     RelNode input = oldAggRel.getInput();
-    final List<RexNode> inputExprs = new ArrayList<>();
-    for (RelDataTypeField field : input.getRowType().getFieldList()) {
-      inputExprs.add(
-          rexBuilder.makeInputRef(
-              field.getType(), inputExprs.size()));
-    }
+    final List<RexNode> inputExprs =
+        new ArrayList<>(rexBuilder.identityProjects(input.getRowType()));
 
     // create new agg function calls and rest of project list together
     for (AggregateCall oldCall : oldCalls) {

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
index 2cb47d5..ce716c8 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexBuilder.java
@@ -119,7 +119,7 @@ public class RexBuilder {
 
   /** Creates a list of {@link org.apache.calcite.rex.RexInputRef} expressions,
    * projecting the fields of a given record type. */
-  public List<RexInputRef> identityProjects(final RelDataType rowType) {
+  public List<? extends RexNode> identityProjects(final RelDataType rowType) {
     return Lists.transform(rowType.getFieldList(), TO_INPUT_REF);
   }
 
@@ -533,28 +533,23 @@ public class RexBuilder {
   }
 
   private RexNode makeCastExactToBoolean(RelDataType toType, RexNode exp) {
-    return makeCall(
-        toType,
+    return makeCall(toType,
         SqlStdOperatorTable.NOT_EQUALS,
         ImmutableList.of(exp, makeZeroLiteral(exp.getType())));
   }
 
   private RexNode makeCastBooleanToExact(RelDataType toType, RexNode exp) {
-    final RexNode casted = makeCall(
-        SqlStdOperatorTable.CASE,
+    final RexNode casted = makeCall(SqlStdOperatorTable.CASE,
         exp,
         makeExactLiteral(BigDecimal.ONE, toType),
         makeZeroLiteral(toType));
     if (!exp.getType().isNullable()) {
       return casted;
     }
-    return makeCall(
-        toType,
+    return makeCall(toType,
         SqlStdOperatorTable.CASE,
-        ImmutableList.<RexNode>of(
-            makeCall(SqlStdOperatorTable.IS_NOT_NULL, exp),
-            casted,
-            makeNullLiteral(toType.getSqlTypeName())));
+        ImmutableList.of(makeCall(SqlStdOperatorTable.IS_NOT_NULL, exp),
+            casted, makeNullLiteral(toType.getSqlTypeName())));
   }
 
   private RexNode makeCastIntervalToExact(RelDataType toType, RexNode exp) {
@@ -605,14 +600,12 @@ public class RexBuilder {
     BigDecimal multiplier = BigDecimal.valueOf(endUnit.multiplier)
         .divide(BigDecimal.TEN.pow(scale));
     RelDataType decimalType =
-        getTypeFactory().createSqlType(
-            SqlTypeName.DECIMAL,
+        getTypeFactory().createSqlType(SqlTypeName.DECIMAL,
             scale + intervalType.getPrecision(),
             scale);
     RexNode value = decodeIntervalOrDecimal(ensureType(decimalType, exp, true));
     if (multiplier.longValue() != 1) {
-      value = makeCall(
-          SqlStdOperatorTable.MULTIPLY,
+      value = makeCall(SqlStdOperatorTable.MULTIPLY,
           value, makeExactLiteral(multiplier));
     }
     return encodeIntervalOrDecimal(value, toType, false);
@@ -639,8 +632,7 @@ public class RexBuilder {
       RelDataType type,
       boolean checkOverflow) {
     RelDataType bigintType =
-        typeFactory.createSqlType(
-            SqlTypeName.BIGINT);
+        typeFactory.createSqlType(SqlTypeName.BIGINT);
     RexNode cast = ensureType(bigintType, value, true);
     return makeReinterpretCast(type, cast, makeLiteral(checkOverflow));
   }
@@ -771,6 +763,8 @@ public class RexBuilder {
    * @param input Input relational expression
    * @param i    Ordinal of field
    * @return Reference to field
+   *
+   * @see #identityProjects(RelDataType)
    */
   public RexInputRef makeInputRef(RelNode input, int i) {
     return makeInputRef(input.getRowType().getFieldList().get(i).getType(), i);

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/SqlAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlAggFunction.java b/core/src/main/java/org/apache/calcite/sql/SqlAggFunction.java
index 4f51be2..944309a 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlAggFunction.java
@@ -16,6 +16,7 @@
  */
 package org.apache.calcite.sql;
 
+import org.apache.calcite.plan.Context;
 import org.apache.calcite.sql.type.SqlOperandTypeChecker;
 import org.apache.calcite.sql.type.SqlOperandTypeInference;
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
@@ -26,7 +27,7 @@ import org.apache.calcite.sql.validate.SqlValidatorScope;
  * Abstract base class for the definition of an aggregate function: an operator
  * which aggregates sets of values into a result.
  */
-public abstract class SqlAggFunction extends SqlFunction {
+public abstract class SqlAggFunction extends SqlFunction implements Context {
   //~ Constructors -----------------------------------------------------------
 
   /** Creates a built-in SqlAggFunction. */
@@ -57,6 +58,10 @@ public abstract class SqlAggFunction extends SqlFunction {
 
   //~ Methods ----------------------------------------------------------------
 
+  public <T> T unwrap(Class<T> clazz) {
+    return clazz.isInstance(this) ? clazz.cast(this) : null;
+  }
+
   @Override public boolean isAggregator() {
     return true;
   }

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java b/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
new file mode 100644
index 0000000..465a4d6
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/sql/SqlSplittableAggFunction.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql;
+
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.mapping.Mappings;
+
+import com.google.common.collect.ImmutableList;
+
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Aggregate function that can be split into partial aggregates.
+ *
+ * <p>For example, {@code COUNT(x)} can be split into {@code COUNT(x)} on
+ * subsets followed by {@code SUM} to combine those counts.
+ */
+public interface SqlSplittableAggFunction {
+  AggregateCall split(AggregateCall aggregateCall,
+      Mappings.TargetMapping mapping);
+
+  /** Called to generate an aggregate for the side of the join opposite to
+   * the one that the aggregate call's arguments come from. Returns null if
+   * no aggregate is required. */
+  AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e);
+
+  /** Generates an aggregate call to merge sub-totals.
+   *
+   * <p>Most implementations will add a single aggregate call to
+   * {@code aggCalls}, and return a {@link RexInputRef} that points to it.
+   *
+   * @param rexBuilder Rex builder
+   * @param extra Place to define extra input expressions
+   * @param offset Offset due to grouping columns (and indicator columns if
+   *     applicable)
+   * @param inputRowType Input row type
+   * @param aggregateCall Source aggregate call
+   * @param leftSubTotal Ordinal of the sub-total coming from the left side of
+   *     the join, or -1 if there is no such sub-total
+   * @param rightSubTotal Ordinal of the sub-total coming from the right side
+   *     of the join, or -1 if there is no such sub-total
+   *
+   * @return Aggregate call
+   */
+  AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra,
+      int offset, RelDataType inputRowType, AggregateCall aggregateCall,
+      int leftSubTotal, int rightSubTotal);
+
+  /** Generates an expression for the value of the aggregate function when
+   * applied to a single row.
+   *
+   * <p>For example, if there is one row:
+   * <ul>
+   *   <li>{@code SUM(x)} is {@code x}
+   *   <li>{@code MIN(x)} is {@code x}
+   *   <li>{@code MAX(x)} is {@code x}
+   *   <li>{@code COUNT(x)} is {@code CASE WHEN x IS NOT NULL THEN 1 ELSE 0 END}
+   *   which can be simplified to {@code 1} if {@code x} is never null
+   *   <li>{@code COUNT(*)} is 1
+   * </ul>
+   *
+   * @param rexBuilder Rex builder
+   * @param inputRowType Input row type
+   * @param aggregateCall Aggregate call
+   *
+   * @return Expression for single row
+   */
+  RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
+      AggregateCall aggregateCall);
+
+  /** Collection in which one can register an element. Registering may return
+   * a reference to an existing element. */
+  interface Registry<E> {
+    int register(E e);
+  }
+
+  /** Splitting strategy for {@code COUNT}.
+   *
+   * <p>COUNT splits into itself followed by SUM. (Actually
+   * SUM0, because the total needs to be 0, not null, if there are 0 rows.)
+   * This rule works for any number of arguments to COUNT, including COUNT(*).
+   */
+  class CountSplitter implements SqlSplittableAggFunction {
+    public static final CountSplitter INSTANCE = new CountSplitter();
+
+    public AggregateCall split(AggregateCall aggregateCall,
+        Mappings.TargetMapping mapping) {
+      return aggregateCall.transform(mapping);
+    }
+
+    public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
+      return AggregateCall.create(SqlStdOperatorTable.COUNT, false,
+          ImmutableIntList.of(), -1,
+          typeFactory.createSqlType(SqlTypeName.BIGINT), null);
+    }
+
+    public AggregateCall topSplit(RexBuilder rexBuilder,
+        Registry<RexNode> extra, int offset, RelDataType inputRowType,
+        AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal) {
+      final List<RexNode> merges = new ArrayList<>();
+      if (leftSubTotal >= 0) {
+        merges.add(
+            rexBuilder.makeInputRef(aggregateCall.type, leftSubTotal));
+      }
+      if (rightSubTotal >= 0) {
+        merges.add(
+            rexBuilder.makeInputRef(aggregateCall.type, rightSubTotal));
+      }
+      RexNode node;
+      switch (merges.size()) {
+      case 1:
+        node = merges.get(0);
+        break;
+      case 2:
+        node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
+        break;
+      default:
+        throw new AssertionError("unexpected count " + merges);
+      }
+      int ordinal = extra.register(node);
+      return AggregateCall.create(SqlStdOperatorTable.SUM0, false,
+          ImmutableList.of(ordinal), -1, aggregateCall.type,
+          aggregateCall.name);
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * COUNT(*) and COUNT applied to all NOT NULL arguments become {@code 1};
+     * otherwise {@code CASE WHEN arg0 IS NOT NULL THEN 1 ELSE 0 END}.
+     */
+    public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
+        AggregateCall aggregateCall) {
+      final List<RexNode> predicates = new ArrayList<>();
+      for (Integer arg : aggregateCall.getArgList()) {
+        final RelDataType type = inputRowType.getFieldList().get(arg).getType();
+        if (type.isNullable()) {
+          predicates.add(
+              rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
+                  rexBuilder.makeInputRef(type, arg)));
+        }
+      }
+      final RexNode predicate =
+          RexUtil.composeConjunction(rexBuilder, predicates, true);
+      if (predicate == null) {
+        return rexBuilder.makeExactLiteral(BigDecimal.ONE);
+      } else {
+        return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate,
+            rexBuilder.makeExactLiteral(BigDecimal.ONE),
+            rexBuilder.makeExactLiteral(BigDecimal.ZERO));
+      }
+    }
+  }
+
+  /** Aggregate function that splits into two applications of itself.
+   *
+   * <p>Examples are MIN and MAX. */
+  class SelfSplitter implements SqlSplittableAggFunction {
+    public static final SelfSplitter INSTANCE = new SelfSplitter();
+
+    public RexNode singleton(RexBuilder rexBuilder,
+        RelDataType inputRowType, AggregateCall aggregateCall) {
+      final int arg = aggregateCall.getArgList().get(0);
+      final RelDataTypeField field = inputRowType.getFieldList().get(arg);
+      return rexBuilder.makeInputRef(field.getType(), arg);
+    }
+
+    public AggregateCall split(AggregateCall aggregateCall,
+        Mappings.TargetMapping mapping) {
+      return aggregateCall.transform(mapping);
+    }
+
+    public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
+      return null; // no aggregate function required on other side
+    }
+
+    public AggregateCall topSplit(RexBuilder rexBuilder,
+        Registry<RexNode> extra, int offset, RelDataType inputRowType,
+        AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal) {
+      assert (leftSubTotal >= 0) != (rightSubTotal >= 0);
+      final int arg = leftSubTotal >= 0 ? leftSubTotal : rightSubTotal;
+      return aggregateCall.copy(ImmutableIntList.of(arg), -1);
+    }
+  }
+
+  /** Splitting strategy for {@code SUM}. */
+  class SumSplitter implements SqlSplittableAggFunction {
+    public static final SumSplitter INSTANCE = new SumSplitter();
+
+    public RexNode singleton(RexBuilder rexBuilder,
+        RelDataType inputRowType, AggregateCall aggregateCall) {
+      final int arg = aggregateCall.getArgList().get(0);
+      final RelDataTypeField field = inputRowType.getFieldList().get(arg);
+      return rexBuilder.makeInputRef(field.getType(), arg);
+    }
+
+    public AggregateCall split(AggregateCall aggregateCall,
+        Mappings.TargetMapping mapping) {
+      return aggregateCall.transform(mapping);
+    }
+
+    public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
+      return AggregateCall.create(SqlStdOperatorTable.COUNT, false,
+          ImmutableIntList.of(), -1,
+          typeFactory.createSqlType(SqlTypeName.BIGINT), null);
+    }
+
+    public AggregateCall topSplit(RexBuilder rexBuilder,
+        Registry<RexNode> extra, int offset, RelDataType inputRowType,
+        AggregateCall aggregateCall, int leftSubTotal, int rightSubTotal) {
+      final List<RexNode> merges = new ArrayList<>();
+      final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
+      if (leftSubTotal >= 0) {
+        final RelDataType type = fieldList.get(leftSubTotal).getType();
+        merges.add(rexBuilder.makeInputRef(type, leftSubTotal));
+      }
+      if (rightSubTotal >= 0) {
+        final RelDataType type = fieldList.get(rightSubTotal).getType();
+        merges.add(rexBuilder.makeInputRef(type, rightSubTotal));
+      }
+      RexNode node;
+      switch (merges.size()) {
+      case 1:
+        node = merges.get(0);
+        break;
+      case 2:
+        node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
+        node = rexBuilder.makeAbstractCast(aggregateCall.type, node);
+        break;
+      default:
+        throw new AssertionError("unexpected count " + merges);
+      }
+      int ordinal = extra.register(node);
+      return AggregateCall.create(SqlStdOperatorTable.SUM, false,
+          ImmutableList.of(ordinal), -1, aggregateCall.type,
+          aggregateCall.name);
+    }
+  }
+}
+
+// End SqlSplittableAggFunction.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
index a08a96e..3feefc2 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
@@ -22,6 +22,7 @@ import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.sql.SqlSyntax;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
@@ -83,6 +84,13 @@ public class SqlCountAggFunction extends SqlAggFunction {
     }
     return super.deriveType(validator, scope, call);
   }
+
+  @Override public <T> T unwrap(Class<T> clazz) {
+    if (clazz == SqlSplittableAggFunction.class) {
+      return clazz.cast(SqlSplittableAggFunction.CountSplitter.INSTANCE);
+    }
+    return super.unwrap(clazz);
+  }
 }
 
 // End SqlCountAggFunction.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/fun/SqlMinMaxAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlMinMaxAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlMinMaxAggFunction.java
index 74b3c99..2ce8d6f 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlMinMaxAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlMinMaxAggFunction.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.util.Util;
@@ -112,6 +113,12 @@ public class SqlMinMaxAggFunction extends SqlAggFunction {
     }
   }
 
+  @Override public <T> T unwrap(Class<T> clazz) {
+    if (clazz == SqlSplittableAggFunction.class) {
+      return clazz.cast(SqlSplittableAggFunction.SelfSplitter.INSTANCE);
+    }
+    return super.unwrap(clazz);
+  }
 }
 
 // End SqlMinMaxAggFunction.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/fun/SqlSumAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlSumAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlSumAggFunction.java
index 9b2835d..41b9d1d 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlSumAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlSumAggFunction.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 
@@ -35,6 +36,7 @@ import java.util.List;
  * is the same type.
  */
 public class SqlSumAggFunction extends SqlAggFunction {
+
   //~ Instance fields --------------------------------------------------------
 
   private final RelDataType type;
@@ -65,6 +67,13 @@ public class SqlSumAggFunction extends SqlAggFunction {
   public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
     return type;
   }
+
+  @Override public <T> T unwrap(Class<T> clazz) {
+    if (clazz == SqlSplittableAggFunction.class) {
+      return clazz.cast(SqlSplittableAggFunction.SumSplitter.INSTANCE);
+    }
+    return super.unwrap(clazz);
+  }
 }
 
 // End SqlSumAggFunction.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/main/java/org/apache/calcite/sql/fun/SqlSumEmptyIsZeroAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlSumEmptyIsZeroAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlSumEmptyIsZeroAggFunction.java
index 50fa487..965ae75 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlSumEmptyIsZeroAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlSumEmptyIsZeroAggFunction.java
@@ -21,6 +21,7 @@ import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.sql.type.SqlTypeName;
@@ -59,6 +60,13 @@ public class SqlSumEmptyIsZeroAggFunction extends SqlAggFunction {
     return typeFactory.createTypeWithNullability(
         typeFactory.createSqlType(SqlTypeName.ANY), true);
   }
+
+  @Override public <T> T unwrap(Class<T> clazz) {
+    if (clazz == SqlSplittableAggFunction.class) {
+      return clazz.cast(SqlSplittableAggFunction.SumSplitter.INSTANCE);
+    }
+    return super.unwrap(clazz);
+  }
 }
 
 // End SqlSumEmptyIsZeroAggFunction.java

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index 3739af5..1901f89 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -1563,12 +1563,13 @@ public class RelOptRulesTest extends RelOptTestBase {
   @Test public void testPushFilterWithRank() throws Exception {
     HepProgram program = new HepProgramBuilder().addRuleInstance(
         FilterProjectTransposeRule.INSTANCE).build();
-    checkPlanning(program, "select e1.ename, r\n"
+    final String sql = "select e1.ename, r\n"
         + "from (\n"
         + "  select ename, "
         + "  rank() over(partition by  deptno order by sal) as r "
         + "  from emp) e1\n"
-        + "where r < 2");
+        + "where r < 2";
+    checkPlanning(program, sql);
   }
 
   @Test public void testPushFilterWithRankExpr() throws Exception {
@@ -1587,14 +1588,13 @@ public class RelOptRulesTest extends RelOptTestBase {
         .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
         .build();
     final HepProgram program = new HepProgramBuilder()
-        .addRuleInstance(AggregateJoinTransposeRule.INSTANCE)
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
         .build();
-    checkPlanning(tester, preProgram,
-        new HepPlanner(program),
-        "select e.empno,d.deptno \n"
-                + "from (select * from sales.emp where empno = 10) as e "
-                + "join sales.dept as d on e.empno = d.deptno "
-                + "group by e.empno,d.deptno");
+    final String sql = "select e.empno,d.deptno \n"
+        + "from (select * from sales.emp where empno = 10) as e "
+        + "join sales.dept as d on e.empno = d.deptno "
+        + "group by e.empno,d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
   }
 
   @Test public void testPushAggregateThroughJoin2() throws Exception {
@@ -1602,15 +1602,14 @@ public class RelOptRulesTest extends RelOptTestBase {
         .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
         .build();
     final HepProgram program = new HepProgramBuilder()
-        .addRuleInstance(AggregateJoinTransposeRule.INSTANCE)
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
         .build();
-    checkPlanning(tester, preProgram,
-        new HepPlanner(program),
-        "select e.empno,d.deptno \n"
-                + "from (select * from sales.emp where empno = 10) as e "
-                + "join sales.dept as d on e.empno = d.deptno "
-                + "and e.deptno + e.empno = d.deptno + 5 "
-                + "group by e.empno,d.deptno");
+    final String sql = "select e.empno,d.deptno \n"
+        + "from (select * from sales.emp where empno = 10) as e "
+        + "join sales.dept as d on e.empno = d.deptno "
+        + "and e.deptno + e.empno = d.deptno + 5 "
+        + "group by e.empno,d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
   }
 
   @Test public void testPushAggregateThroughJoin3() throws Exception {
@@ -1618,29 +1617,77 @@ public class RelOptRulesTest extends RelOptTestBase {
         .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
         .build();
     final HepProgram program = new HepProgramBuilder()
-        .addRuleInstance(AggregateJoinTransposeRule.INSTANCE)
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
         .build();
-    checkPlanning(tester, preProgram,
-        new HepPlanner(program),
-        "select e.empno,d.deptno \n"
-                + "from (select * from sales.emp where empno = 10) as e "
-                + "join sales.dept as d on e.empno < d.deptno "
-                + "group by e.empno,d.deptno");
+    final String sql = "select e.empno,d.deptno \n"
+        + "from (select * from sales.emp where empno = 10) as e "
+        + "join sales.dept as d on e.empno < d.deptno "
+        + "group by e.empno,d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
   }
 
-  @Test public void testPushAggregateThroughJoin4() throws Exception {
+  /** SUM is the easiest aggregate function to split. */
+  @Test public void testPushAggregateSumThroughJoin() throws Exception {
     final HepProgram preProgram = new HepProgramBuilder()
         .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
         .build();
     final HepProgram program = new HepProgramBuilder()
-        .addRuleInstance(AggregateJoinTransposeRule.INSTANCE)
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
         .build();
-    checkPlanning(tester, preProgram,
-        new HepPlanner(program),
-        "select e.empno,sum(sal) \n"
-                + "from (select * from sales.emp where empno = 10) as e "
-                + "join sales.dept as d on e.empno = d.deptno "
-                + "group by e.empno,d.deptno");
+    final String sql = "select e.empno,sum(sal) \n"
+        + "from (select * from sales.emp where empno = 10) as e "
+        + "join sales.dept as d on e.empno = d.deptno "
+        + "group by e.empno,d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
+  }
+
+  /** Push a variety of aggregate functions. */
+  @Test public void testPushAggregateFunctionsThroughJoin() throws Exception {
+    final HepProgram preProgram = new HepProgramBuilder()
+        .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
+        .build();
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
+        .build();
+    final String sql = "select e.empno,\n"
+        + "  min(sal) as min_sal, min(e.deptno) as min_deptno,\n"
+        + "  sum(sal) + 1 as sum_sal_plus, max(sal) as max_sal,\n"
+        + "  sum(sal) as sum_sal_2, count(sal) as count_sal\n"
+        + "from sales.emp as e\n"
+        + "join sales.dept as d on e.empno = d.deptno\n"
+        + "group by e.empno,d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
+  }
+
+  /** Push aggregate functions into a relation that is unique on the join
+   * key. */
+  @Test public void testPushAggregateThroughJoinDistinct() throws Exception {
+    final HepProgram preProgram = new HepProgramBuilder()
+        .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
+        .build();
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
+        .build();
+    final String sql = "select d.deptno,\n"
+        + "  sum(sal) as sum_sal, count(*) as c\n"
+        + "from sales.emp as e\n"
+        + "join (select distinct deptno from sales.dept) as d\n"
+        + "  on e.empno = d.deptno\n"
+        + "group by d.deptno";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
+  }
+
+  /** Push count(*) through join, no GROUP BY. */
+  @Test public void testPushAggregateSumNoGroup() throws Exception {
+    final HepProgram preProgram = new HepProgramBuilder()
+        .addRuleInstance(AggregateProjectMergeRule.INSTANCE)
+        .build();
+    final HepProgram program = new HepProgramBuilder()
+        .addRuleInstance(AggregateJoinTransposeRule.EXTENDED)
+        .build();
+    final String sql =
+        "select count(*) from sales.emp join sales.dept using (deptno)";
+    checkPlanning(tester, preProgram, new HepPlanner(program), sql);
   }
 
   @Test public void testSwapOuterJoin() throws Exception {

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
----------------------------------------------------------------------
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index d01a8d1..044b6b2 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -3383,14 +3383,16 @@ LogicalAggregate(group=[{0, 10}])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{0, 10}])
-  LogicalJoin(condition=[AND(=($0, $10), =($9, $12))], joinType=[inner])
-    LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+($7, $0)])
-      LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
-        LogicalFilter(condition=[=($0, 10)])
-          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
-    LogicalProject(DEPTNO=[$0], NAME=[$1], $f2=[+($0, 5)])
-      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+LogicalAggregate(group=[{0, 2}])
+  LogicalJoin(condition=[AND(=($0, $2), =($1, $3))], joinType=[inner])
+    LogicalAggregate(group=[{0, 9}])
+      LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8], $f9=[+($7, $0)])
+        LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+          LogicalFilter(condition=[=($0, 10)])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0, 2}])
+      LogicalProject(DEPTNO=[$0], NAME=[$1], $f2=[+($0, 5)])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>
@@ -3420,7 +3422,7 @@ LogicalAggregate(group=[{0, 9}])
 ]]>
         </Resource>
     </TestCase>
-    <TestCase name="testPushAggregateThroughJoin4">
+    <TestCase name="testPushAggregateSumThroughJoin">
         <Resource name="sql">
             <![CDATA[select e.empno,sum(sal) 
 from (select * from sales.emp where empno = 10) as e join sales.dept as d on e.empno = d.deptno group by e.empno,d.deptno]]>
@@ -3439,12 +3441,15 @@ LogicalProject(EMPNO=[$0], EXPR$1=[$2])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(EMPNO=[$0], EXPR$1=[$2])
-  LogicalAggregate(group=[{0, 9}], EXPR$1=[SUM($5)])
-    LogicalJoin(condition=[=($0, $9)], joinType=[inner])
-      LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
-        LogicalFilter(condition=[=($0, 10)])
-          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
-      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+  LogicalProject($f0=[$0], $f1=[$2], $f2=[$4])
+    LogicalProject($f0=[$0], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[CAST(*($1, $3)):INTEGER NOT NULL])
+      LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+        LogicalAggregate(group=[{0}], EXPR$1=[SUM($5)])
+          LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+            LogicalFilter(condition=[=($0, 10)])
+              LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+        LogicalAggregate(group=[{0}], agg#0=[COUNT()])
+          LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
 ]]>
         </Resource>
     </TestCase>
@@ -3772,4 +3777,89 @@ LogicalProject(DEPTNO=[$0])
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testPushAggregateFunctionsThroughJoin">
+        <Resource name="sql">
+            <![CDATA[select e.empno,
+  min(sal) as min_sal, min(e.deptno) as min_deptno,
+  sum(sal) + 1 as sum_sal_plus, max(sal) as max_sal,
+  sum(sal) as sum_sal_2, count(sal) as count_sal
+from sales.emp as e
+join sales.dept as d on e.empno = d.deptno
+group by e.empno,d.deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], MIN_SAL=[$2], MIN_DEPTNO=[$3], SUM_SAL_PLUS=[+($4, 1)], MAX_SAL=[$5], SUM_SAL_2=[$4], COUNT_SAL=[$6])
+  LogicalAggregate(group=[{0, 9}], MIN_SAL=[MIN($5)], MIN_DEPTNO=[MIN($7)], SUM_SAL_2=[SUM($5)], MAX_SAL=[MAX($5)], COUNT_SAL=[COUNT()])
+    LogicalJoin(condition=[=($0, $9)], joinType=[inner])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], MIN_SAL=[$2], MIN_DEPTNO=[$3], SUM_SAL_PLUS=[+($4, 1)], MAX_SAL=[$5], SUM_SAL_2=[$4], COUNT_SAL=[$6])
+  LogicalProject($f0=[$0], $f1=[$6], $f2=[$1], $f3=[$2], $f4=[$8], $f5=[$4], $f6=[$9])
+    LogicalProject($f0=[$0], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6], $f7=[$7], $f8=[CAST(*($3, $7)):INTEGER NOT NULL], $f9=[*($5, $7)])
+      LogicalJoin(condition=[=($0, $6)], joinType=[inner])
+        LogicalAggregate(group=[{0}], MIN_SAL=[MIN($5)], MIN_DEPTNO=[MIN($7)], SUM_SAL_2=[SUM($5)], MAX_SAL=[MAX($5)], COUNT_SAL=[COUNT()])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+        LogicalAggregate(group=[{0}], agg#0=[COUNT()])
+          LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testPushAggregateSumNoGroup">
+        <Resource name="sql">
+            <![CDATA[select count(*) from sales.emp join sales.dept using (deptno)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
+  LogicalJoin(condition=[=($7, $9)], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
+  LogicalProject($f0=[$0], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[*($1, $3)])
+    LogicalJoin(condition=[=($0, $2)], joinType=[inner])
+      LogicalAggregate(group=[{7}], EXPR$0=[COUNT()])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
+        LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testPushAggregateThroughJoinDistinct">
+        <Resource name="sql">
+            <![CDATA[select d.deptno,
+  sum(sal) as sum_sal, count(*) as c
+from sales.emp as e
+join (select distinct deptno from sales.dept) as d
+  on e.empno = d.deptno
+group by d.deptno]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalAggregate(group=[{9}], SUM_SAL=[SUM($5)], C=[COUNT()])
+  LogicalJoin(condition=[=($0, $9)], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0}])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject($f0=[$3], $f1=[$1], $f2=[$2])
+  LogicalJoin(condition=[=($0, $3)], joinType=[inner])
+    LogicalAggregate(group=[{0}], SUM_SAL=[SUM($5)], C=[COUNT()])
+      LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0}])
+      LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
+]]>
+        </Resource>
+    </TestCase>
 </Root>

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/test/resources/sql/agg.oq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/agg.oq b/core/src/test/resources/sql/agg.oq
index aeecc04..771ff86 100644
--- a/core/src/test/resources/sql/agg.oq
+++ b/core/src/test/resources/sql/agg.oq
@@ -52,7 +52,6 @@ select count(deptno) as c from emp;
 !ok
 
 # composite count
-!if (false) {
 select count(deptno, ename, 1, deptno) as c from emp;
 +---+
 | C |
@@ -62,7 +61,6 @@ select count(deptno, ename, 1, deptno) as c from emp;
 (1 row)
 
 !ok
-!}
 
 select city, gender as c from emps;
 +---------------+---+
@@ -773,6 +771,327 @@ select avg(comm) as a, count(comm) as c from "scott".emp where empno < 7844;
 
 !ok
 
+# [CALCITE-751] Aggregate join transpose
+select count(*)
+from "scott".emp join "scott".dept using (deptno);
++--------+
+| EXPR$0 |
++--------+
+|     14 |
++--------+
+(1 row)
+
+!ok
+EnumerableAggregate(group=[{}], EXPR$0=[COUNT()])
+  EnumerableJoin(condition=[=($0, $2)], joinType=[inner])
+    EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+      EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+      EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum: splits into sum * count
+select sum(sal)
+from "scott".emp join "scott".dept using (deptno);
++----------+
+| EXPR$0   |
++----------+
+| 29025.00 |
++----------+
+(1 row)
+
+!ok
+EnumerableAggregate(group=[{}], EXPR$0=[SUM($2)])
+  EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+    EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+      EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+      EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum; no aggregate needed after join
+select sum(sal)
+from "scott".emp join "scott".dept using (deptno)
+group by emp.deptno, dept.deptno;
++----------+
+| EXPR$0   |
++----------+
+| 10875.00 |
+|  8750.00 |
+|  9400.00 |
++----------+
+(3 rows)
+
+!ok
+EnumerableCalc(expr#0..2=[{inputs}], EXPR$0=[$t2])
+  EnumerableAggregate(group=[{0, 3}], EXPR$0=[SUM($2)])
+    EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum; group by only one of the join keys
+select sum(sal)
+from "scott".emp join "scott".dept using (deptno)
+group by emp.deptno;
++----------+
+| EXPR$0   |
++----------+
+| 10875.00 |
+|  8750.00 |
+|  9400.00 |
++----------+
+(3 rows)
+
+!ok
+EnumerableCalc(expr#0..1=[{inputs}], EXPR$0=[$t1])
+  EnumerableAggregate(group=[{3}], EXPR$0=[SUM($2)])
+    EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push min; Join-Aggregate is optimized to SemiJoin
+select min(sal)
+from "scott".emp join "scott".dept using (deptno)
+group by emp.deptno;
++---------+
+| EXPR$0  |
++---------+
+| 1300.00 |
+|  800.00 |
+|  950.00 |
++---------+
+(3 rows)
+
+!ok
+EnumerableCalc(expr#0..1=[{inputs}], EXPR$0=[$t1])
+  EnumerableAggregate(group=[{3}], EXPR$0=[MIN($2)])
+    EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum and count
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno);
++----+----------+
+| C  | S        |
++----+----------+
+| 14 | 29025.00 |
++----+----------+
+(1 row)
+
+!ok
+EnumerableAggregate(group=[{}], C=[COUNT()], S=[SUM($2)])
+  EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+    EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+      EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+      EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum and count, group by join key
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno) group by emp.deptno;
++---+----------+
+| C | S        |
++---+----------+
+| 3 |  8750.00 |
+| 5 | 10875.00 |
+| 6 |  9400.00 |
++---+----------+
+(3 rows)
+
+!ok
+# No aggregate on top, because output of join is unique
+EnumerableCalc(expr#0..2=[{inputs}], C=[$t1], S=[$t2])
+  EnumerableAggregate(group=[{3}], C=[COUNT()], S=[SUM($2)])
+    EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum and count, group by join key plus another column
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno) group by emp.job, dept.deptno;
++---+---------+
+| C | S       |
++---+---------+
+| 1 | 1300.00 |
+| 1 | 2450.00 |
+| 1 | 2850.00 |
+| 1 | 2975.00 |
+| 1 | 5000.00 |
+| 1 |  950.00 |
+| 2 | 1900.00 |
+| 2 | 6000.00 |
+| 4 | 5600.00 |
++---+---------+
+(9 rows)
+
+!ok
+EnumerableCalc(expr#0..3=[{inputs}], C=[$t2], S=[$t3])
+  EnumerableAggregate(group=[{0, 2}], C=[COUNT()], S=[SUM($3)])
+    EnumerableJoin(condition=[=($0, $4)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum and count, group by non-join column
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno) group by emp.job;
++---+---------+
+| C | S       |
++---+---------+
+| 1 | 5000.00 |
+| 2 | 6000.00 |
+| 3 | 8275.00 |
+| 4 | 4150.00 |
+| 4 | 5600.00 |
++---+---------+
+(5 rows)
+
+!ok
+EnumerableCalc(expr#0..2=[{inputs}], C=[$t1], S=[$t2])
+  EnumerableAggregate(group=[{2}], C=[COUNT()], S=[SUM($3)])
+    EnumerableJoin(condition=[=($0, $4)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push count and sum, group by superset of join key
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno) group by emp.job, dept.deptno;
++---+---------+
+| C | S       |
++---+---------+
+| 1 | 5000.00 |
+| 2 | 6000.00 |
+| 4 | 5600.00 |
+| 1 | 1300.00 |
+| 1 | 2450.00 |
+| 1 | 2850.00 |
+| 1 | 2975.00 |
+| 1 |  950.00 |
+| 2 | 1900.00 |
++---+---------+
+(9 rows)
+
+!ok
+EnumerableCalc(expr#0..3=[{inputs}], C=[$t2], S=[$t3])
+  EnumerableAggregate(group=[{0, 2}], C=[COUNT()], S=[SUM($3)])
+    EnumerableJoin(condition=[=($0, $4)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], JOB=[$t2], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push count and sum, group by a column being aggregated
+select count(*) as c, sum(sal) as s
+from "scott".emp join "scott".dept using (deptno) group by emp.sal;
++---+---------+
+| C | S       |
++---+---------+
+| 1 | 5000.00 |
+| 2 | 6000.00 |
+| 1 | 1100.00 |
+| 1 | 1300.00 |
+| 1 | 1500.00 |
+| 1 | 1600.00 |
+| 1 | 2450.00 |
+| 1 | 2850.00 |
+| 1 | 2975.00 |
+| 1 |  800.00 |
+| 1 |  950.00 |
+| 2 | 2500.00 |
++---+---------+
+(12 rows)
+
+!ok
+EnumerableCalc(expr#0..2=[{inputs}], C=[$t1], S=[$t2])
+  EnumerableAggregate(group=[{2}], C=[COUNT()], S=[SUM($2)])
+    EnumerableJoin(condition=[=($0, $3)], joinType=[inner])
+      EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+        EnumerableTableScan(table=[[scott, DEPT]])
+      EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7])
+        EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# Push sum, self-join, returning one row with a null value
+select sum(e.sal) as s
+from "scott".emp e join "scott".emp m on e.mgr = e.empno;
++---+
+| S |
++---+
+|   |
++---+
+(1 row)
+
+!ok
+
+# Push sum, self-join
+select sum(e.sal) as s
+from "scott".emp e join "scott".emp m on e.mgr = m.empno;
++----------+
+| S        |
++----------+
+| 24025.00 |
++----------+
+(1 row)
+
+!ok
+
+# Push sum, self-join, aggregate by column on "many" side
+select sum(e.sal) as s
+from "scott".emp e join "scott".emp m on e.mgr = m.empno
+group by m.empno;
++---------+
+| S       |
++---------+
+| 1100.00 |
+| 1300.00 |
+| 6000.00 |
+| 6550.00 |
+|  800.00 |
+| 8275.00 |
++---------+
+(6 rows)
+
+!ok
+
+# Push sum, self-join, aggregate by column on "one" side.
+# Note inflated totals due to cartesian product.
+select sum(m.sal) as s
+from "scott".emp e join "scott".emp m on e.mgr = m.empno
+group by m.empno;
++----------+
+| S        |
++----------+
+| 14250.00 |
+| 15000.00 |
+|  2450.00 |
+|  3000.00 |
+|  3000.00 |
+|  5950.00 |
++----------+
+(6 rows)
+
+!ok
+
 # [CALCITE-729] IndexOutOfBoundsException in ROLLUP query on JDBC data source
 !use jdbc_scott
 select deptno, job, count(*) as c

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/core/src/test/resources/sql/join.oq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/join.oq b/core/src/test/resources/sql/join.oq
index bb273dc..0c69e2a 100644
--- a/core/src/test/resources/sql/join.oq
+++ b/core/src/test/resources/sql/join.oq
@@ -122,11 +122,12 @@ from "scott".emp join "scott".dept using (deptno);
 (3 rows)
 
 !ok
-EnumerableJoin(condition=[=($0, $1)], joinType=[inner])
-  EnumerableAggregate(group=[{0}])
-    EnumerableTableScan(table=[[scott, DEPT]])
-  EnumerableAggregate(group=[{7}])
-    EnumerableTableScan(table=[[scott, EMP]])
+EnumerableAggregate(group=[{0, 2}])
+  EnumerableJoin(condition=[=($0, $2)], joinType=[inner])
+    EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+      EnumerableTableScan(table=[[scott, DEPT]])
+    EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], DEPTNO=[$t7])
+      EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
 select distinct dept.deptno

http://git-wip-us.apache.org/repos/asf/incubator-calcite/blob/cf7a7a97/plus/src/test/java/org/apache/calcite/adapter/tpcds/TpcdsTest.java
----------------------------------------------------------------------
diff --git a/plus/src/test/java/org/apache/calcite/adapter/tpcds/TpcdsTest.java b/plus/src/test/java/org/apache/calcite/adapter/tpcds/TpcdsTest.java
index 9f855f2..15fc496 100644
--- a/plus/src/test/java/org/apache/calcite/adapter/tpcds/TpcdsTest.java
+++ b/plus/src/test/java/org/apache/calcite/adapter/tpcds/TpcdsTest.java
@@ -16,11 +16,21 @@
  */
 package org.apache.calcite.adapter.tpcds;
 
+import org.apache.calcite.jdbc.CalciteConnection;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelTraitDef;
 import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.runtime.Hook;
+import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.test.CalciteAssert;
+import org.apache.calcite.tools.Frameworks;
 import org.apache.calcite.tools.Program;
 import org.apache.calcite.tools.Programs;
+import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.Bug;
 import org.apache.calcite.util.Holder;
 import org.apache.calcite.util.Pair;
@@ -141,6 +151,10 @@ public class TpcdsTest {
                 + "                EnumerableTableAccessRel(table=[[TPCDS, CATALOG_SALES]]): rowcount = 1441548.0, cumulative cost = {1441548.0 rows, 1441549.0 cpu, 0.0 io}\n"));
   }
 
+  @Test public void testQuery27() {
+    checkQuery(27).runs();
+  }
+
   @Test public void testQuery58() {
     checkQuery(58).explainContains("PLAN").runs();
   }
@@ -189,6 +203,95 @@ public class TpcdsTest {
     return with()
         .query(sql.replaceAll("tpcds\\.", "tpcds_01."));
   }
+
+  public Frameworks.ConfigBuilder config() throws Exception {
+    final Holder<SchemaPlus> root = Holder.of(null);
+    CalciteAssert.model(TPCDS_MODEL)
+        .doWithConnection(
+            new Function<CalciteConnection, Object>() {
+              public Object apply(CalciteConnection input) {
+                root.set(input.getRootSchema().getSubSchema("TPCDS"));
+                return null;
+              }
+            });
+    return Frameworks.newConfigBuilder()
+        .parserConfig(SqlParser.Config.DEFAULT)
+        .defaultSchema(root.get())
+        .traitDefs((List<RelTraitDef>) null)
+        .programs(Programs.heuristicJoinOrder(Programs.RULE_SET, true, 2));
+  }
+
+  /**
+   * Builds query 27 using {@link RelBuilder}.
+   *
+   * <blockquote><pre>
+   *   select  i_item_id,
+   *         s_state, grouping(s_state) g_state,
+   *         avg(ss_quantity) agg1,
+   *         avg(ss_list_price) agg2,
+   *         avg(ss_coupon_amt) agg3,
+   *         avg(ss_sales_price) agg4
+   * from store_sales, customer_demographics, date_dim, store, item
+   * where ss_sold_date_sk = d_date_sk and
+   *        ss_item_sk = i_item_sk and
+   *        ss_store_sk = s_store_sk and
+   *        ss_cdemo_sk = cd_demo_sk and
+   *        cd_gender = 'dist(gender, 1, 1)' and
+   *        cd_marital_status = 'dist(marital_status, 1, 1)' and
+   *        cd_education_status = 'dist(education, 1, 1)' and
+   *        d_year = 1998 and
+   *        s_state in ('distmember(fips_county,[STATENUMBER.1], 3)',
+   *              'distmember(fips_county,[STATENUMBER.2], 3)',
+   *              'distmember(fips_county,[STATENUMBER.3], 3)',
+   *              'distmember(fips_county,[STATENUMBER.4], 3)',
+   *              'distmember(fips_county,[STATENUMBER.5], 3)',
+   *              'distmember(fips_county,[STATENUMBER.6], 3)')
+   *  group by rollup (i_item_id, s_state)
+   *  order by i_item_id
+   *          ,s_state
+   *  LIMIT 100
+   * </pre></blockquote>
+   */
+  @Test public void testQuery27Builder() throws Exception {
+    final RelBuilder builder = RelBuilder.create(config().build());
+    final RelNode root =
+        builder.scan("STORE_SALES")
+            .scan("CUSTOMER_DEMOGRAPHICS")
+            .scan("DATE_DIM")
+            .scan("STORE")
+            .scan("ITEM")
+            .join(JoinRelType.INNER)
+            .join(JoinRelType.INNER)
+            .join(JoinRelType.INNER)
+            .join(JoinRelType.INNER)
+            .filter(
+                builder.equals(builder.field("SS_SOLD_DATE_SK"), builder.field("D_DATE_SK")),
+                builder.equals(builder.field("SS_ITEM_SK"), builder.field("I_ITEM_SK")),
+                builder.equals(builder.field("SS_STORE_SK"), builder.field("S_STORE_SK")),
+                builder.equals(builder.field("SS_CDEMO_SK"), builder.field("CD_DEMO_SK")),
+                builder.equals(builder.field("CD_GENDER"), builder.literal("M")),
+                builder.equals(builder.field("CD_MARITAL_STATUS"), builder.literal("S")),
+                builder.equals(builder.field("CD_EDUCATION_STATUS"),
+                    builder.literal("HIGH SCHOOL")),
+                builder.equals(builder.field("D_YEAR"), builder.literal(1998)),
+                builder.call(SqlStdOperatorTable.IN,
+                    builder.field("S_STATE"),
+                    builder.call(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR,
+                        builder.literal("CA"),
+                        builder.literal("OR"),
+                        builder.literal("WA"),
+                        builder.literal("TX"),
+                        builder.literal("OK"),
+                        builder.literal("MD"))))
+            .aggregate(builder.groupKey("I_ITEM_ID", "S_STATE"),
+                builder.avg(false, "AGG1", builder.field("SS_QUANTITY")),
+                builder.avg(false, "AGG2", builder.field("SS_LIST_PRICE")),
+                builder.avg(false, "AGG3", builder.field("SS_COUPON_AMT")),
+                builder.avg(false, "AGG4", builder.field("SS_SALES_PRICE")))
+            .sortLimit(0, 100, builder.field("I_ITEM_ID"), builder.field("S_STATE"))
+            .build();
+    System.out.println(RelOptUtil.toString(root));
+  }
 }
 
 // End TpcdsTest.java