You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2016/01/21 23:39:01 UTC

[24/50] [abbrv] calcite git commit: [CALCITE-816] Represent sub-query as a RexNode

http://git-wip-us.apache.org/repos/asf/calcite/blob/505a9064/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 2f1d6b9..2812851 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -31,12 +31,16 @@ import org.apache.calcite.rel.BiRel;
 import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelShuttleImpl;
-import org.apache.calcite.rel.RelVisitor;
 import org.apache.calcite.rel.core.Aggregate;
 import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Correlate;
 import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.Join;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.core.Values;
 import org.apache.calcite.rel.logical.LogicalAggregate;
 import org.apache.calcite.rel.logical.LogicalCorrelate;
 import org.apache.calcite.rel.logical.LogicalFilter;
@@ -58,6 +62,7 @@ import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexSubQuery;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.rex.RexVisitorImpl;
 import org.apache.calcite.sql.SqlExplainLevel;
@@ -67,20 +72,25 @@ import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.fun.SqlCountAggFunction;
 import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.Bug;
 import org.apache.calcite.util.Holder;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Litmus;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.ReflectUtil;
-import org.apache.calcite.util.ReflectiveVisitDispatcher;
 import org.apache.calcite.util.ReflectiveVisitor;
+import org.apache.calcite.util.Stacks;
 import org.apache.calcite.util.Util;
 import org.apache.calcite.util.mapping.Mappings;
 import org.apache.calcite.util.trace.CalciteTrace;
 
+import com.google.common.base.Preconditions;
 import com.google.common.base.Supplier;
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.ImmutableSortedMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multimap;
@@ -96,6 +106,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.NavigableMap;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
@@ -127,10 +138,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
   //~ Instance fields --------------------------------------------------------
 
+  private final RelBuilder relBuilder;
+
   // map built during translation
   private CorelMap cm;
 
-  private final DecorrelateRelVisitor decorrelateVisitor;
+  private final ReflectUtil.MethodDispatcher<Frame> dispatcher =
+      ReflectUtil.createMethodDispatcher(Frame.class, this, "decorrelateRel",
+          RelNode.class);
 
   private final RexBuilder rexBuilder;
 
@@ -139,31 +154,24 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
   private final Context context;
 
-  // maps built during decorrelation
-  private final Map<RelNode, RelNode> mapOldToNewRel = Maps.newHashMap();
-
-  // map rel to all the newly created correlated variables in its output
-  private final Map<RelNode, SortedMap<Correlation, Integer>>
-  mapNewRelToMapCorVarToOutputPos = Maps.newHashMap();
-
-  // another map to map old input positions to new input positions
-  // this is from the view point of the parent rel of a new rel.
-  private final Map<RelNode, Map<Integer, Integer>>
-  mapNewRelToMapOldToNewOutputPos = Maps.newHashMap();
+  /** Built during decorrelation. Maps each rel to a frame recording the newly
+   * created correlated variables in its output and its old-to-new input
+   * positions. This is from the view point of the parent rel of a new rel. */
+  private final Map<RelNode, Frame> map = new HashMap<>();
 
   private final HashSet<LogicalCorrelate> generatedCorRels = Sets.newHashSet();
 
   //~ Constructors -----------------------------------------------------------
 
   private RelDecorrelator(
-      RexBuilder rexBuilder,
+      RelOptCluster cluster,
       CorelMap cm,
       Context context) {
     this.cm = cm;
-    this.rexBuilder = rexBuilder;
+    this.rexBuilder = cluster.getRexBuilder();
     this.context = context;
+    relBuilder = RelFactories.LOGICAL_BUILDER.create(cluster, null);
 
-    decorrelateVisitor = new DecorrelateRelVisitor();
   }
 
   //~ Methods ----------------------------------------------------------------
@@ -178,18 +186,16 @@ public class RelDecorrelator implements ReflectiveVisitor {
    * {@link org.apache.calcite.rel.logical.LogicalCorrelate} instances removed
    */
   public static RelNode decorrelateQuery(RelNode rootRel) {
-    final CorelMap corelMap = CorelMap.build(rootRel);
+    final CorelMap corelMap = new CorelMapBuilder().build(rootRel);
     if (!corelMap.hasCorrelation()) {
       return rootRel;
     }
 
     final RelOptCluster cluster = rootRel.getCluster();
-    final RexBuilder rexBuilder = cluster.getRexBuilder();
     final RelDecorrelator decorrelator =
-        new RelDecorrelator(rexBuilder, corelMap,
+        new RelDecorrelator(cluster, corelMap,
             cluster.getPlanner().getContext());
 
-
     RelNode newRootRel = decorrelator.removeCorrelationViaRule(rootRel);
 
     if (SQL2REL_LOGGER.isLoggable(Level.FINE)) {
@@ -211,7 +217,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
   private void setCurrent(RelNode root, LogicalCorrelate corRel) {
     currentRel = corRel;
     if (corRel != null) {
-      cm = CorelMap.build(Util.first(root, corRel));
+      cm = new CorelMapBuilder().build(Util.first(root, corRel));
     }
   }
 
@@ -231,13 +237,10 @@ public class RelDecorrelator implements ReflectiveVisitor {
     root = planner.findBestExp();
 
     // Perform decorrelation.
-    mapOldToNewRel.clear();
-    mapNewRelToMapCorVarToOutputPos.clear();
-    mapNewRelToMapOldToNewOutputPos.clear();
-
-    decorrelateVisitor.visit(root, 0, null);
+    map.clear();
 
-    if (mapOldToNewRel.containsKey(root)) {
+    final Frame frame = getInvoke(root, null);
+    if (frame != null) {
       // has been rewritten; apply rules post-decorrelation
       final HepProgram program2 = HepProgram.builder()
           .addRuleInstance(FilterJoinRule.FILTER_ON_JOIN)
@@ -245,7 +248,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
           .build();
 
       final HepPlanner planner2 = createPlanner(program2);
-      final RelNode newRoot = mapOldToNewRel.get(root);
+      final RelNode newRoot = frame.r;
       planner2.setRoot(newRoot);
       return planner2.findBestExp();
     }
@@ -265,7 +268,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
           LogicalCorrelate oldCor = (LogicalCorrelate) oldNode;
           CorrelationId c = oldCor.getCorrelationId();
           if (cm.mapCorVarToCorRel.get(c) == oldNode) {
-            cm.mapCorVarToCorRel.put(c, (LogicalCorrelate) newNode);
+            cm.mapCorVarToCorRel.put(c, newNode);
           }
 
           if (generatedCorRels.contains(oldNode)) {
@@ -298,9 +301,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
     HepPlanner planner = createPlanner(program);
 
     planner.setRoot(root);
-    RelNode newRootRel = planner.findBestExp();
-
-    return newRootRel;
+    return planner.findBestExp();
   }
 
   protected RexNode decorrelateExpr(RexNode exp) {
@@ -312,9 +313,8 @@ public class RelDecorrelator implements ReflectiveVisitor {
       RexNode exp,
       boolean projectPulledAboveLeftCorrelator) {
     RemoveCorrelationRexShuttle shuttle =
-        new RemoveCorrelationRexShuttle(
-            rexBuilder,
-            projectPulledAboveLeftCorrelator);
+        new RemoveCorrelationRexShuttle(rexBuilder,
+            projectPulledAboveLeftCorrelator, null, ImmutableSet.<Integer>of());
     return exp.accept(shuttle);
   }
 
@@ -323,10 +323,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
       boolean projectPulledAboveLeftCorrelator,
       RexInputRef nullIndicator) {
     RemoveCorrelationRexShuttle shuttle =
-        new RemoveCorrelationRexShuttle(
-            rexBuilder,
-            projectPulledAboveLeftCorrelator,
-            nullIndicator);
+        new RemoveCorrelationRexShuttle(rexBuilder,
+            projectPulledAboveLeftCorrelator, nullIndicator,
+            ImmutableSet.<Integer>of());
     return exp.accept(shuttle);
   }
 
@@ -335,30 +334,27 @@ public class RelDecorrelator implements ReflectiveVisitor {
       boolean projectPulledAboveLeftCorrelator,
       Set<Integer> isCount) {
     RemoveCorrelationRexShuttle shuttle =
-        new RemoveCorrelationRexShuttle(
-            rexBuilder,
-            projectPulledAboveLeftCorrelator,
-            isCount);
+        new RemoveCorrelationRexShuttle(rexBuilder,
+            projectPulledAboveLeftCorrelator, null, isCount);
     return exp.accept(shuttle);
   }
 
-  public void decorrelateRelGeneric(RelNode rel) {
+  /** Fallback if none of the other {@code decorrelateRel} methods match. */
+  public Frame decorrelateRel(RelNode rel) {
     RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs());
 
     if (rel.getInputs().size() > 0) {
       List<RelNode> oldInputs = rel.getInputs();
       List<RelNode> newInputs = Lists.newArrayList();
       for (int i = 0; i < oldInputs.size(); ++i) {
-        RelNode newInputRel = mapOldToNewRel.get(oldInputs.get(i));
-        if ((newInputRel == null)
-            || mapNewRelToMapCorVarToOutputPos.containsKey(newInputRel)) {
-          // if child is not rewritten, or if it produces correlated
+        final Frame frame = getInvoke(oldInputs.get(i), rel);
+        if (frame == null || !frame.corVarOutputPos.isEmpty()) {
+          // if input is not rewritten, or if it produces correlated
           // variables, terminate rewrite
-          return;
-        } else {
-          newInputs.add(newInputRel);
-          newRel.replaceInput(i, newInputRel);
+          return null;
         }
+        newInputs.add(frame.r);
+        newRel.replaceInput(i, frame.r);
       }
 
       if (!Util.equalShallow(oldInputs, newInputs)) {
@@ -368,12 +364,8 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
     // the output position should not change since there are no corVars
     // coming from below.
-    Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
-    for (int i = 0; i < rel.getRowType().getFieldCount(); i++) {
-      mapOldToNewOutputPos.put(i, i);
-    }
-    mapOldToNewRel.put(rel, newRel);
-    mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos);
+    return register(rel, newRel, identityMap(rel.getRowType().getFieldCount()),
+        ImmutableSortedMap.<Correlation, Integer>of());
   }
 
   /**
@@ -381,7 +373,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel Sort to be rewritten
    */
-  public void decorrelateRel(Sort rel) {
+  public Frame decorrelateRel(Sort rel) {
     //
     // Rewrite logic:
     //
@@ -397,33 +389,39 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // Its output does not change the input ordering, so there's no
     // need to call propagateExpr.
 
-    RelNode oldChildRel = rel.getInput();
-
-    RelNode newChildRel = mapOldToNewRel.get(oldChildRel);
-    if (newChildRel == null) {
-      // If child has not been rewritten, do not rewrite this rel.
-      return;
+    final RelNode oldInput = rel.getInput();
+    final Frame frame = getInvoke(oldInput, rel);
+    if (frame == null) {
+      // If input has not been rewritten, do not rewrite this rel.
+      return null;
     }
+    final RelNode newInput = frame.r;
 
-    Map<Integer, Integer> childMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newChildRel);
-    assert childMapOldToNewOutputPos != null;
     Mappings.TargetMapping mapping =
         Mappings.target(
-            childMapOldToNewOutputPos,
-            oldChildRel.getRowType().getFieldCount(),
-            newChildRel.getRowType().getFieldCount());
+            frame.oldToNewOutputPos,
+            oldInput.getRowType().getFieldCount(),
+            newInput.getRowType().getFieldCount());
 
     RelCollation oldCollation = rel.getCollation();
     RelCollation newCollation = RexUtil.apply(mapping, oldCollation);
 
-    final Sort newRel =
-        LogicalSort.create(newChildRel, newCollation, rel.offset, rel.fetch);
-
-    mapOldToNewRel.put(rel, newRel);
+    final Sort newSort =
+        LogicalSort.create(newInput, newCollation, rel.offset, rel.fetch);
 
     // Sort does not change input ordering
-    mapNewRelToMapOldToNewOutputPos.put(newRel, childMapOldToNewOutputPos);
+    return register(rel, newSort, frame.oldToNewOutputPos,
+        frame.corVarOutputPos);
+  }
+
+  /**
+   * Rewrites a {@link Values}.
+   *
+   * @param rel Values to be rewritten
+   */
+  public Frame decorrelateRel(Values rel) {
+    // There are no inputs, so rel does not need to be changed.
+    return null;
   }
 
   /**
@@ -431,7 +429,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel Aggregate to rewrite
    */
-  public void decorrelateRel(LogicalAggregate rel) {
+  public Frame decorrelateRel(LogicalAggregate rel) {
     if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
       throw new AssertionError(Bug.CALCITE_461_FIXED);
     }
@@ -439,7 +437,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // Rewrite logic:
     //
     // 1. Permute the group by keys to the front.
-    // 2. If the child of an aggregate produces correlated variables,
+    // 2. If the input of an aggregate produces correlated variables,
     //    add them to the group list.
     // 3. Change aggCalls to reference the new project.
     //
@@ -447,117 +445,107 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // Aggregate itself should not reference cor vars.
     assert !cm.mapRefRelToCorVar.containsKey(rel);
 
-    RelNode oldChildRel = rel.getInput();
-
-    RelNode newChildRel = mapOldToNewRel.get(oldChildRel);
-    if (newChildRel == null) {
-      // If child has not been rewritten, do not rewrite this rel.
-      return;
+    final RelNode oldInput = rel.getInput();
+    final Frame frame = getInvoke(oldInput, rel);
+    if (frame == null) {
+      // If input has not been rewritten, do not rewrite this rel.
+      return null;
     }
+    assert !frame.corVarOutputPos.isEmpty();
+    final RelNode newInput = frame.r;
 
-    Map<Integer, Integer> childMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newChildRel);
-    assert childMapOldToNewOutputPos != null;
-
-    // map from newChildRel
-    Map<Integer, Integer> mapNewChildToProjOutputPos = Maps.newHashMap();
+    // map from newInput
+    Map<Integer, Integer> mapNewInputToProjOutputPos = Maps.newHashMap();
     final int oldGroupKeyCount = rel.getGroupSet().cardinality();
 
-    // LogicalProject projects the original expressions,
-    // plus any correlated variables the child wants to pass along.
+    // Project projects the original expressions,
+    // plus any correlated variables the input wants to pass along.
     final List<Pair<RexNode, String>> projects = Lists.newArrayList();
 
-    List<RelDataTypeField> newChildOutput =
-        newChildRel.getRowType().getFieldList();
+    List<RelDataTypeField> newInputOutput =
+        newInput.getRowType().getFieldList();
 
-    int newPos;
+    int newPos = 0;
 
-    // oldChildRel has the original group by keys in the front.
-    for (newPos = 0; newPos < oldGroupKeyCount; newPos++) {
-      int newChildPos = childMapOldToNewOutputPos.get(newPos);
-      projects.add(RexInputRef.of2(newChildPos, newChildOutput));
-      mapNewChildToProjOutputPos.put(newChildPos, newPos);
+    // oldInput has the original group by keys in the front.
+    final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
+    for (int i = 0; i < oldGroupKeyCount; i++) {
+      final RexLiteral constant = projectedLiteral(newInput, i);
+      if (constant != null) {
+        // Exclude constants. Aggregate({true}) occurs because Aggregate({})
+        // would generate 1 row even when applied to an empty table.
+        omittedConstants.put(i, constant);
+        continue;
+      }
+      int newInputPos = frame.oldToNewOutputPos.get(i);
+      projects.add(RexInputRef.of2(newInputPos, newInputOutput));
+      mapNewInputToProjOutputPos.put(newInputPos, newPos);
+      newPos++;
     }
 
-    SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap();
-
-    boolean produceCorVar =
-        mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel);
-    if (produceCorVar) {
-      // If child produces correlated variables, move them to the front,
-      // right after any existing groupby fields.
+    final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>();
+    if (!frame.corVarOutputPos.isEmpty()) {
+      // If input produces correlated variables, move them to the front,
+      // right after any existing GROUP BY fields.
 
-      SortedMap<Correlation, Integer> childMapCorVarToOutputPos =
-          mapNewRelToMapCorVarToOutputPos.get(newChildRel);
-
-      // Now add the corVars from the child, starting from
+      // Now add the corVars from the input, starting from
       // position oldGroupKeyCount.
-      for (Correlation corVar
-          : childMapCorVarToOutputPos.keySet()) {
-        int newChildPos = childMapCorVarToOutputPos.get(corVar);
-        projects.add(RexInputRef.of2(newChildPos, newChildOutput));
+      for (Map.Entry<Correlation, Integer> entry
+          : frame.corVarOutputPos.entrySet()) {
+        projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
 
-        mapCorVarToOutputPos.put(corVar, newPos);
-        mapNewChildToProjOutputPos.put(newChildPos, newPos);
+        mapCorVarToOutputPos.put(entry.getKey(), newPos);
+        mapNewInputToProjOutputPos.put(entry.getValue(), newPos);
         newPos++;
       }
     }
 
     // add the remaining fields
     final int newGroupKeyCount = newPos;
-    for (int i = 0; i < newChildOutput.size(); i++) {
-      if (!mapNewChildToProjOutputPos.containsKey(i)) {
-        projects.add(RexInputRef.of2(i, newChildOutput));
-        mapNewChildToProjOutputPos.put(i, newPos);
+    for (int i = 0; i < newInputOutput.size(); i++) {
+      if (!mapNewInputToProjOutputPos.containsKey(i)) {
+        projects.add(RexInputRef.of2(i, newInputOutput));
+        mapNewInputToProjOutputPos.put(i, newPos);
         newPos++;
       }
     }
 
-    assert newPos == newChildOutput.size();
+    assert newPos == newInputOutput.size();
 
-    // This LogicalProject will be what the old child maps to,
-    // replacing any previous mapping from old child).
-    RelNode newProjectRel =
-        RelOptUtil.createProject(newChildRel, projects, false);
+    // This Project will be what the old input maps to,
+    // replacing any previous mapping from old input.
+    RelNode newProject =
+        RelOptUtil.createProject(newInput, projects, false);
 
     // update mappings:
-    // oldChildRel ----> newChildRel
+    // oldInput ----> newInput
     //
-    //                   newProjectRel
-    //                        |
-    // oldChildRel ---->  newChildRel
+    //                newProject
+    //                   |
+    // oldInput ----> newInput
     //
     // is transformed to
     //
-    // oldChildRel ----> newProjectRel
-    //                        |
-    //                   newChildRel
+    // oldInput ----> newProject
+    //                   |
+    //                newInput
     Map<Integer, Integer> combinedMap = Maps.newHashMap();
 
-    for (Integer oldChildPos : childMapOldToNewOutputPos.keySet()) {
-      combinedMap.put(
-          oldChildPos,
-          mapNewChildToProjOutputPos.get(
-              childMapOldToNewOutputPos.get(oldChildPos)));
+    for (Integer oldInputPos : frame.oldToNewOutputPos.keySet()) {
+      combinedMap.put(oldInputPos,
+          mapNewInputToProjOutputPos.get(
+              frame.oldToNewOutputPos.get(oldInputPos)));
     }
 
-    mapOldToNewRel.put(oldChildRel, newProjectRel);
-    mapNewRelToMapOldToNewOutputPos.put(newProjectRel, combinedMap);
+    register(oldInput, newProject, combinedMap, mapCorVarToOutputPos);
 
-    if (produceCorVar) {
-      mapNewRelToMapCorVarToOutputPos.put(
-          newProjectRel,
-          mapCorVarToOutputPos);
-    }
-
-    // now it's time to rewrite LogicalAggregate
+    // now it's time to rewrite the Aggregate
+    final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
     List<AggregateCall> newAggCalls = Lists.newArrayList();
     List<AggregateCall> oldAggCalls = rel.getAggCallList();
 
-    // LogicalAggregate.Call oldAggCall;
-    int oldChildOutputFieldCount = oldChildRel.getRowType().getFieldCount();
-    int newChildOutputFieldCount =
-        newProjectRel.getRowType().getFieldCount();
+    int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
+    int newInputOutputFieldCount = newGroupSet.cardinality();
 
     int i = -1;
     for (AggregateCall oldAggCall : oldAggCalls) {
@@ -567,7 +555,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
       List<Integer> aggArgs = Lists.newArrayList();
 
       // Adjust the aggregator argument positions.
-      // Note aggregator does not change input ordering, so the child
+      // Note aggregator does not change input ordering, so the input
       // output position mapping can be used to derive the new positions
       // for the argument.
       for (int oldPos : oldAggArgs) {
@@ -577,34 +565,57 @@ public class RelDecorrelator implements ReflectiveVisitor {
           : combinedMap.get(oldAggCall.filterArg);
 
       newAggCalls.add(
-          oldAggCall.adaptTo(newProjectRel, aggArgs, filterArg,
+          oldAggCall.adaptTo(newProject, aggArgs, filterArg,
               oldGroupKeyCount, newGroupKeyCount));
 
       // The old to new output position mapping will be the same as that
-      // of newProjectRel, plus any aggregates that the oldAgg produces.
+      // of newProject, plus any aggregates that the oldAgg produces.
       combinedMap.put(
-          oldChildOutputFieldCount + i,
-          newChildOutputFieldCount + i);
+          oldInputOutputFieldCount + i,
+          newInputOutputFieldCount + i);
     }
 
-    LogicalAggregate newAggregate =
-        LogicalAggregate.create(newProjectRel,
+    relBuilder.push(
+        LogicalAggregate.create(newProject,
             false,
-            ImmutableBitSet.range(newGroupKeyCount),
+            newGroupSet,
             null,
-            newAggCalls);
+            newAggCalls));
+
+    if (!omittedConstants.isEmpty()) {
+      final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
+      for (Map.Entry<Integer, RexLiteral> entry
+          : omittedConstants.descendingMap().entrySet()) {
+        postProjects.add(entry.getKey() + frame.corVarOutputPos.size(),
+            entry.getValue());
+      }
+      relBuilder.project(postProjects);
+    }
 
-    mapOldToNewRel.put(rel, newAggregate);
+    // Aggregate does not change input ordering so corVars will be
+    // located at the same position as the input newProject.
+    return register(rel, relBuilder.build(), combinedMap, mapCorVarToOutputPos);
+  }
 
-    mapNewRelToMapOldToNewOutputPos.put(newAggregate, combinedMap);
+  public Frame getInvoke(RelNode r, RelNode parent) {
+    final Frame frame = dispatcher.invoke(r);
+    if (frame != null) {
+      map.put(r, frame);
+    }
+    currentRel = parent;
+    return frame;
+  }
 
-    if (produceCorVar) {
-      // LogicalAggregate does not change input ordering so corVars will be
-      // located at the same position as the input newProjectRel.
-      mapNewRelToMapCorVarToOutputPos.put(
-          newAggregate,
-          mapCorVarToOutputPos);
+  /** Returns a literal output field, or null if it is not literal. */
+  private static RexLiteral projectedLiteral(RelNode rel, int i) {
+    if (rel instanceof Project) {
+      final Project project = (Project) rel;
+      final RexNode node = project.getProjects().get(i);
+      if (node instanceof RexLiteral) {
+        return (RexLiteral) node;
+      }
     }
+    return null;
   }
 
   /**
@@ -612,34 +623,24 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel the project rel to rewrite
    */
-  public void decorrelateRel(LogicalProject rel) {
+  public Frame decorrelateRel(LogicalProject rel) {
     //
     // Rewrite logic:
     //
-    // 1. Pass along any correlated variables coming from the child.
+    // 1. Pass along any correlated variables coming from the input.
     //
 
-    RelNode oldChildRel = rel.getInput();
-
-    RelNode newChildRel = mapOldToNewRel.get(oldChildRel);
-    if (newChildRel == null) {
-      // If child has not been rewritten, do not rewrite this rel.
-      return;
+    final RelNode oldInput = rel.getInput();
+    Frame frame = getInvoke(oldInput, rel);
+    if (frame == null) {
+      // If input has not been rewritten, do not rewrite this rel.
+      return null;
     }
-    List<RexNode> oldProj = rel.getProjects();
-    List<RelDataTypeField> relOutput = rel.getRowType().getFieldList();
-
-    Map<Integer, Integer> childMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newChildRel);
-    assert childMapOldToNewOutputPos != null;
-
-    Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
-
-    boolean produceCorVar =
-        mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel);
+    final List<RexNode> oldProjects = rel.getProjects();
+    final List<RelDataTypeField> relOutput = rel.getRowType().getFieldList();
 
     // LogicalProject projects the original expressions,
-    // plus any correlated variables the child wants to pass along.
+    // plus any correlated variables the input wants to pass along.
     final List<Pair<RexNode, String>> projects = Lists.newArrayList();
 
     // If this LogicalProject has correlated reference, create value generator
@@ -647,55 +648,38 @@ public class RelDecorrelator implements ReflectiveVisitor {
     if (cm.mapRefRelToCorVar.containsKey(rel)) {
       decorrelateInputWithValueGenerator(rel);
 
-      // The old child should be mapped to the LogicalJoin created by
+      // The old input should be mapped to the LogicalJoin created by
       // rewriteInputWithValueGenerator().
-      newChildRel = mapOldToNewRel.get(oldChildRel);
-      produceCorVar = true;
+      frame = map.get(oldInput);
     }
 
     // LogicalProject projects the original expressions
+    final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
     int newPos;
-    for (newPos = 0; newPos < oldProj.size(); newPos++) {
+    for (newPos = 0; newPos < oldProjects.size(); newPos++) {
       projects.add(
           newPos,
           Pair.of(
-              decorrelateExpr(oldProj.get(newPos)),
+              decorrelateExpr(oldProjects.get(newPos)),
               relOutput.get(newPos).getName()));
       mapOldToNewOutputPos.put(newPos, newPos);
     }
 
-    SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap();
-
-    // Project any correlated variables the child wants to pass along.
-    if (produceCorVar) {
-      SortedMap<Correlation, Integer> childMapCorVarToOutputPos =
-          mapNewRelToMapCorVarToOutputPos.get(newChildRel);
-
-      // propagate cor vars from the new child
-      List<RelDataTypeField> newChildOutput =
-          newChildRel.getRowType().getFieldList();
-      for (Correlation corVar
-          : childMapCorVarToOutputPos.keySet()) {
-        int corVarPos = childMapCorVarToOutputPos.get(corVar);
-        projects.add(RexInputRef.of2(corVarPos, newChildOutput));
-        mapCorVarToOutputPos.put(corVar, newPos);
-        newPos++;
-      }
+    // Project any correlated variables the input wants to pass along.
+    final SortedMap<Correlation, Integer> mapCorVarToOutputPos = new TreeMap<>();
+    for (Map.Entry<Correlation, Integer> entry : frame.corVarOutputPos.entrySet()) {
+      projects.add(
+          RexInputRef.of2(entry.getValue(),
+              frame.r.getRowType().getFieldList()));
+      mapCorVarToOutputPos.put(entry.getKey(), newPos);
+      newPos++;
     }
 
-    RelNode newProjectRel =
-        RelOptUtil.createProject(newChildRel, projects, false);
-
-    mapOldToNewRel.put(rel, newProjectRel);
-    mapNewRelToMapOldToNewOutputPos.put(
-        newProjectRel,
-        mapOldToNewOutputPos);
+    RelNode newProject =
+        RelOptUtil.createProject(frame.r, projects, false);
 
-    if (produceCorVar) {
-      mapNewRelToMapCorVarToOutputPos.put(
-          newProjectRel,
-          mapCorVarToOutputPos);
-    }
+    return register(rel, newProject, mapOldToNewOutputPos,
+        mapCorVarToOutputPos);
   }
 
   /**
@@ -712,44 +696,37 @@ public class RelDecorrelator implements ReflectiveVisitor {
       Iterable<Correlation> correlations,
       int valueGenFieldOffset,
       SortedMap<Correlation, Integer> mapCorVarToOutputPos) {
-    RelNode resultRel = null;
+    final Map<RelNode, List<Integer>> mapNewInputToOutputPos =
+        new HashMap<>();
 
-    Map<RelNode, List<Integer>> mapNewInputRelToOutputPos = Maps.newHashMap();
-
-    Map<RelNode, Integer> mapNewInputRelToNewOffset = Maps.newHashMap();
-
-    RelNode oldInputRel;
-    RelNode newInputRel;
-    List<Integer> newLocalOutputPosList;
+    final Map<RelNode, Integer> mapNewInputToNewOffset = new HashMap<>();
 
     // inputRel provides the definition of a correlated variable.
     // Add to map all the referenced positions(relative to each input rel)
     for (Correlation corVar : correlations) {
-      int oldCorVarOffset = corVar.field;
+      final int oldCorVarOffset = corVar.field;
 
-      oldInputRel = cm.mapCorVarToCorRel.get(corVar.corr).getInput(0);
-      assert oldInputRel != null;
-      newInputRel = mapOldToNewRel.get(oldInputRel);
-      assert newInputRel != null;
+      final RelNode oldInput = getCorRel(corVar);
+      assert oldInput != null;
+      final Frame frame = map.get(oldInput);
+      assert frame != null;
+      final RelNode newInput = frame.r;
 
-      if (!mapNewInputRelToOutputPos.containsKey(newInputRel)) {
+      final List<Integer> newLocalOutputPosList;
+      if (!mapNewInputToOutputPos.containsKey(newInput)) {
         newLocalOutputPosList = Lists.newArrayList();
       } else {
         newLocalOutputPosList =
-            mapNewInputRelToOutputPos.get(newInputRel);
+            mapNewInputToOutputPos.get(newInput);
       }
 
-      Map<Integer, Integer> mapOldToNewOutputPos =
-          mapNewRelToMapOldToNewOutputPos.get(newInputRel);
-      assert mapOldToNewOutputPos != null;
-
-      int newCorVarOffset = mapOldToNewOutputPos.get(oldCorVarOffset);
+      final int newCorVarOffset = frame.oldToNewOutputPos.get(oldCorVarOffset);
 
       // Add all unique positions referenced.
       if (!newLocalOutputPosList.contains(newCorVarOffset)) {
         newLocalOutputPosList.add(newCorVarOffset);
       }
-      mapNewInputRelToOutputPos.put(newInputRel, newLocalOutputPosList);
+      mapNewInputToOutputPos.put(newInput, newLocalOutputPosList);
     }
 
     int offset = 0;
@@ -759,33 +736,34 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // To make sure the plan does not change in terms of join order,
     // join these rels based on their occurrence in cor var list which
     // is sorted.
-    Set<RelNode> joinedInputRelSet = Sets.newHashSet();
+    final Set<RelNode> joinedInputRelSet = Sets.newHashSet();
 
+    RelNode r = null;
     for (Correlation corVar : correlations) {
-      oldInputRel = cm.mapCorVarToCorRel.get(corVar.corr).getInput(0);
-      assert oldInputRel != null;
-      newInputRel = mapOldToNewRel.get(oldInputRel);
-      assert newInputRel != null;
+      final RelNode oldInput = getCorRel(corVar);
+      assert oldInput != null;
+      final RelNode newInput = map.get(oldInput).r;
+      assert newInput != null;
 
-      if (!joinedInputRelSet.contains(newInputRel)) {
-        RelNode projectRel =
+      if (!joinedInputRelSet.contains(newInput)) {
+        RelNode project =
             RelOptUtil.createProject(
-                newInputRel,
-                mapNewInputRelToOutputPos.get(newInputRel));
-        RelNode distinctRel = RelOptUtil.createDistinctRel(projectRel);
-        RelOptCluster cluster = distinctRel.getCluster();
+                newInput,
+                mapNewInputToOutputPos.get(newInput));
+        RelNode distinct = RelOptUtil.createDistinctRel(project);
+        RelOptCluster cluster = distinct.getCluster();
 
-        joinedInputRelSet.add(newInputRel);
-        mapNewInputRelToNewOffset.put(newInputRel, offset);
-        offset += distinctRel.getRowType().getFieldCount();
+        joinedInputRelSet.add(newInput);
+        mapNewInputToNewOffset.put(newInput, offset);
+        offset += distinct.getRowType().getFieldCount();
 
-        if (resultRel == null) {
-          resultRel = distinctRel;
+        if (r == null) {
+          r = distinct;
         } else {
-          resultRel =
-              LogicalJoin.create(resultRel, distinctRel,
+          r =
+              LogicalJoin.create(r, distinct,
                   cluster.getRexBuilder().makeLiteral(true),
-                  JoinRelType.INNER, ImmutableSet.<String>of());
+                  ImmutableSet.<CorrelationId>of(), JoinRelType.INNER);
         }
       }
     }
@@ -794,27 +772,26 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // the join output, leaving room for valueGenFieldOffset because
     // valueGenerators are joined with the original left input of the rel
     // referencing correlated variables.
-    int newOutputPos;
-    int newLocalOutputPos;
     for (Correlation corVar : correlations) {
-      // The first child of a correlatorRel is always the rel defining
+      // The first input of a Correlator is always the rel defining
       // the correlated variables.
-      newInputRel =
-          mapOldToNewRel.get(cm.mapCorVarToCorRel.get(corVar.corr).getInput(0));
-      newLocalOutputPosList = mapNewInputRelToOutputPos.get(newInputRel);
+      final RelNode oldInput = getCorRel(corVar);
+      assert oldInput != null;
+      final Frame frame = map.get(oldInput);
+      final RelNode newInput = frame.r;
+      assert newInput != null;
 
-      Map<Integer, Integer> mapOldToNewOutputPos =
-          mapNewRelToMapOldToNewOutputPos.get(newInputRel);
-      assert mapOldToNewOutputPos != null;
+      final List<Integer> newLocalOutputPosList =
+          mapNewInputToOutputPos.get(newInput);
 
-      newLocalOutputPos = mapOldToNewOutputPos.get(corVar.field);
+      final int newLocalOutputPos = frame.oldToNewOutputPos.get(corVar.field);
 
       // newOutputPos is the index of the cor var in the referenced
       // position list plus the offset of referenced position list of
-      // each newInputRel.
-      newOutputPos =
+      // each newInput.
+      final int newOutputPos =
           newLocalOutputPosList.indexOf(newLocalOutputPos)
-              + mapNewInputRelToNewOffset.get(newInputRel)
+              + mapNewInputToNewOffset.get(newInput)
               + valueGenFieldOffset;
 
       if (mapCorVarToOutputPos.containsKey(corVar)) {
@@ -823,53 +800,47 @@ public class RelDecorrelator implements ReflectiveVisitor {
       mapCorVarToOutputPos.put(corVar, newOutputPos);
     }
 
-    return resultRel;
+    return r;
   }
 
-  private void decorrelateInputWithValueGenerator(
-      RelNode rel) {
-    // currently only handles one child input
-    assert rel.getInputs().size() == 1;
-    RelNode oldChildRel = rel.getInput(0);
-    RelNode newChildRel = mapOldToNewRel.get(oldChildRel);
-
-    Map<Integer, Integer> childMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newChildRel);
-    assert childMapOldToNewOutputPos != null;
+  private RelNode getCorRel(Correlation corVar) {
+    final RelNode r = cm.mapCorVarToCorRel.get(corVar.corr);
+    RelNode r2 = r.getInput(0);
+    if (r2 instanceof Join) {
+      r2 = r2.getInput(0);
+    }
+    return r2;
+  }
 
-    SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap();
+  private void decorrelateInputWithValueGenerator(RelNode rel) {
+    // currently only handles one input
+    assert rel.getInputs().size() == 1;
+    RelNode oldInput = rel.getInput(0);
+    final Frame frame = map.get(oldInput);
 
-    if (mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel)) {
-      mapCorVarToOutputPos.putAll(
-          mapNewRelToMapCorVarToOutputPos.get(newChildRel));
-    }
+    final SortedMap<Correlation, Integer> mapCorVarToOutputPos =
+        new TreeMap<>(frame.corVarOutputPos);
 
     final Collection<Correlation> corVarList = cm.mapRefRelToCorVar.get(rel);
 
-    RelNode newLeftChildRel = newChildRel;
-
-    int leftChildOutputCount = newLeftChildRel.getRowType().getFieldCount();
+    int leftInputOutputCount = frame.r.getRowType().getFieldCount();
 
     // can directly add positions into mapCorVarToOutputPos since join
-    // does not change the output ordering from the children.
-    RelNode valueGenRel =
+    // does not change the output ordering from the inputs.
+    RelNode valueGen =
         createValueGenerator(
             corVarList,
-            leftChildOutputCount,
+            leftInputOutputCount,
             mapCorVarToOutputPos);
 
-    final Set<String> variablesStopped = Collections.emptySet();
-    RelNode joinRel =
-        LogicalJoin.create(newLeftChildRel, valueGenRel,
-            rexBuilder.makeLiteral(true), JoinRelType.INNER, variablesStopped);
-
-    mapOldToNewRel.put(oldChildRel, joinRel);
-    mapNewRelToMapCorVarToOutputPos.put(joinRel, mapCorVarToOutputPos);
+    RelNode join =
+        LogicalJoin.create(frame.r, valueGen, rexBuilder.makeLiteral(true),
+            ImmutableSet.<CorrelationId>of(), JoinRelType.INNER);
 
     // LogicalJoin or LogicalFilter does not change the old input ordering. All
     // input fields from newLeftInput(i.e. the original input to the old
     // LogicalFilter) are in the output and in the same position.
-    mapNewRelToMapOldToNewOutputPos.put(joinRel, childMapOldToNewOutputPos);
+    register(oldInput, join, frame.oldToNewOutputPos, mapCorVarToOutputPos);
   }
 
   /**
@@ -877,7 +848,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel the filter rel to rewrite
    */
-  public void decorrelateRel(LogicalFilter rel) {
+  public Frame decorrelateRel(LogicalFilter rel) {
     //
     // Rewrite logic:
     //
@@ -894,53 +865,36 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // rewrite the filter condition using new input.
     //
 
-    RelNode oldChildRel = rel.getInput();
-
-    RelNode newChildRel = mapOldToNewRel.get(oldChildRel);
-    if (newChildRel == null) {
-      // If child has not been rewritten, do not rewrite this rel.
-      return;
+    final RelNode oldInput = rel.getInput();
+    Frame frame = getInvoke(oldInput, rel);
+    if (frame == null) {
+      // If input has not been rewritten, do not rewrite this rel.
+      return null;
     }
 
-    Map<Integer, Integer> childMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newChildRel);
-    assert childMapOldToNewOutputPos != null;
-
-    boolean produceCorVar =
-        mapNewRelToMapCorVarToOutputPos.containsKey(newChildRel);
-
     // If this LogicalFilter has correlated reference, create value generator
     // and produce the correlated variables in the new output.
     if (cm.mapRefRelToCorVar.containsKey(rel)) {
       decorrelateInputWithValueGenerator(rel);
 
-      // The old child should be mapped to the newly created LogicalJoin by
+      // The old input should be mapped to the newly created LogicalJoin by
       // rewriteInputWithValueGenerator().
-      newChildRel = mapOldToNewRel.get(oldChildRel);
-      produceCorVar = true;
+      frame = map.get(oldInput);
     }
 
     // Replace the filter expression to reference output of the join
     // Map filter to the new filter over join
-    RelNode newFilterRel =
+    RelNode newFilter =
         RelOptUtil.createFilter(
-            newChildRel,
+            frame.r,
             decorrelateExpr(rel.getCondition()));
 
-    mapOldToNewRel.put(rel, newFilterRel);
-
     // Filter does not change the input ordering.
-    mapNewRelToMapOldToNewOutputPos.put(
-        newFilterRel,
-        childMapOldToNewOutputPos);
-
-    if (produceCorVar) {
-      // filter rel does not permute the input all corvars produced by
-      // filter will have the same output positions in the child rel.
-      mapNewRelToMapCorVarToOutputPos.put(
-          newFilterRel,
-          mapNewRelToMapCorVarToOutputPos.get(newChildRel));
-    }
+    // Filter rel does not permute the input.
+    // All corvars produced by filter will have the same output positions in the
+    // input rel.
+    return register(rel, newFilter, frame.oldToNewOutputPos,
+        frame.corVarOutputPos);
   }
 
   /**
@@ -948,7 +902,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel Correlator
    */
-  public void decorrelateRel(LogicalCorrelate rel) {
+  public Frame decorrelateRel(LogicalCorrelate rel) {
     //
     // Rewrite logic:
     //
@@ -959,126 +913,93 @@ public class RelDecorrelator implements ReflectiveVisitor {
     //
 
     // the right input to Correlator should produce correlated variables
-    RelNode oldLeftRel = rel.getInputs().get(0);
-    RelNode oldRightRel = rel.getInputs().get(1);
+    final RelNode oldLeft = rel.getInput(0);
+    final RelNode oldRight = rel.getInput(1);
 
-    RelNode newLeftRel = mapOldToNewRel.get(oldLeftRel);
-    RelNode newRightRel = mapOldToNewRel.get(oldRightRel);
+    final Frame leftFrame = getInvoke(oldLeft, rel);
+    final Frame rightFrame = getInvoke(oldRight, rel);
 
-    if ((newLeftRel == null) || (newRightRel == null)) {
-      // If any child has not been rewritten, do not rewrite this rel.
-      return;
+    if (leftFrame == null || rightFrame == null) {
+      // If any input has not been rewritten, do not rewrite this rel.
+      return null;
     }
 
-    SortedMap<Correlation, Integer> rightChildMapCorVarToOutputPos =
-        mapNewRelToMapCorVarToOutputPos.get(newRightRel);
-
-    if (rightChildMapCorVarToOutputPos == null) {
-      return;
+    if (rightFrame.corVarOutputPos.isEmpty()) {
+      return null;
     }
 
-    Map<Integer, Integer> leftChildMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newLeftRel);
-    assert leftChildMapOldToNewOutputPos != null;
-
-    Map<Integer, Integer> rightChildMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newRightRel);
-
-    assert rightChildMapOldToNewOutputPos != null;
-
-    SortedMap<Correlation, Integer> mapCorVarToOutputPos =
-        rightChildMapCorVarToOutputPos;
-
     assert rel.getRequiredColumns().cardinality()
-        <= rightChildMapCorVarToOutputPos.keySet().size();
+        <= rightFrame.corVarOutputPos.keySet().size();
 
     // Change correlator rel into a join.
     // Join all the correlated variables produced by this correlator rel
     // with the values generated and propagated from the right input
-    RexNode condition = rexBuilder.makeLiteral(true);
+    final SortedMap<Correlation, Integer> corVarOutputPos =
+        new TreeMap<>(rightFrame.corVarOutputPos);
+    final List<RexNode> conditions = new ArrayList<>();
     final List<RelDataTypeField> newLeftOutput =
-        newLeftRel.getRowType().getFieldList();
+        leftFrame.r.getRowType().getFieldList();
     int newLeftFieldCount = newLeftOutput.size();
 
     final List<RelDataTypeField> newRightOutput =
-        newRightRel.getRowType().getFieldList();
+        rightFrame.r.getRowType().getFieldList();
 
-    int newLeftPos;
-    int newRightPos;
     for (Map.Entry<Correlation, Integer> rightOutputPos
-        : Lists.newArrayList(rightChildMapCorVarToOutputPos.entrySet())) {
-      Correlation corVar = rightOutputPos.getKey();
+        : Lists.newArrayList(corVarOutputPos.entrySet())) {
+      final Correlation corVar = rightOutputPos.getKey();
       if (!corVar.corr.equals(rel.getCorrelationId())) {
         continue;
       }
-      newLeftPos = leftChildMapOldToNewOutputPos.get(corVar.field);
-      newRightPos = rightChildMapCorVarToOutputPos.get(corVar);
-      RexNode equi =
-          rexBuilder.makeCall(
-              SqlStdOperatorTable.EQUALS,
+      final int newLeftPos = leftFrame.oldToNewOutputPos.get(corVar.field);
+      final int newRightPos = rightOutputPos.getValue();
+      conditions.add(
+          rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
               RexInputRef.of(newLeftPos, newLeftOutput),
-              new RexInputRef(
-                  newLeftFieldCount + newRightPos,
-                  newRightOutput.get(newRightPos).getType()));
-      if (condition == rexBuilder.makeLiteral(true)) {
-        condition = equi;
-      } else {
-        condition =
-            rexBuilder.makeCall(
-                SqlStdOperatorTable.AND,
-                condition,
-                equi);
-      }
+              new RexInputRef(newLeftFieldCount + newRightPos,
+                  newRightOutput.get(newRightPos).getType())));
 
       // remove this cor var from output position mapping
-      mapCorVarToOutputPos.remove(corVar);
+      corVarOutputPos.remove(corVar);
     }
 
     // Update the output position for the cor vars: only pass on the cor
     // vars that are not used in the join key.
-    for (Correlation corVar : mapCorVarToOutputPos.keySet()) {
-      int newPos = mapCorVarToOutputPos.get(corVar) + newLeftFieldCount;
-      mapCorVarToOutputPos.put(corVar, newPos);
+    for (Correlation corVar : corVarOutputPos.keySet()) {
+      int newPos = corVarOutputPos.get(corVar) + newLeftFieldCount;
+      corVarOutputPos.put(corVar, newPos);
     }
 
     // then add any cor var from the left input. Do not need to change
     // output positions.
-    if (mapNewRelToMapCorVarToOutputPos.containsKey(newLeftRel)) {
-      mapCorVarToOutputPos.putAll(
-          mapNewRelToMapCorVarToOutputPos.get(newLeftRel));
-    }
+    corVarOutputPos.putAll(leftFrame.corVarOutputPos);
 
     // Create the mapping between the output of the old correlation rel
     // and the new join rel
-    Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
+    final Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
 
-    int oldLeftFieldCount = oldLeftRel.getRowType().getFieldCount();
+    int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
 
-    int oldRightFieldCount = oldRightRel.getRowType().getFieldCount();
+    int oldRightFieldCount = oldRight.getRowType().getFieldCount();
     assert rel.getRowType().getFieldCount()
         == oldLeftFieldCount + oldRightFieldCount;
 
     // Left input positions are not changed.
-    mapOldToNewOutputPos.putAll(leftChildMapOldToNewOutputPos);
+    mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos);
 
     // Right input positions are shifted by newLeftFieldCount.
     for (int i = 0; i < oldRightFieldCount; i++) {
       mapOldToNewOutputPos.put(
           i + oldLeftFieldCount,
-          rightChildMapOldToNewOutputPos.get(i) + newLeftFieldCount);
+          rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount);
     }
 
-    final Set<String> variablesStopped = Collections.emptySet();
-    RelNode newRel =
-        LogicalJoin.create(newLeftRel, newRightRel, condition,
-            rel.getJoinType().toJoinType(), variablesStopped);
-
-    mapOldToNewRel.put(rel, newRel);
-    mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos);
+    final RexNode condition =
+        RexUtil.composeConjunction(rexBuilder, conditions, false);
+    RelNode newJoin =
+        LogicalJoin.create(leftFrame.r, rightFrame.r, condition,
+            ImmutableSet.<CorrelationId>of(), rel.getJoinType().toJoinType());
 
-    if (!mapCorVarToOutputPos.isEmpty()) {
-      mapNewRelToMapCorVarToOutputPos.put(newRel, mapCorVarToOutputPos);
-    }
+    return register(rel, newJoin, mapOldToNewOutputPos, corVarOutputPos);
   }
 
   /**
@@ -1086,7 +1007,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
    *
    * @param rel LogicalJoin
    */
-  public void decorrelateRel(LogicalJoin rel) {
+  public Frame decorrelateRel(LogicalJoin rel) {
     //
     // Rewrite logic:
     //
@@ -1094,77 +1015,52 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // 2. map output positions and produce cor vars if any.
     //
 
-    RelNode oldLeftRel = rel.getInputs().get(0);
-    RelNode oldRightRel = rel.getInputs().get(1);
+    final RelNode oldLeft = rel.getInput(0);
+    final RelNode oldRight = rel.getInput(1);
 
-    RelNode newLeftRel = mapOldToNewRel.get(oldLeftRel);
-    RelNode newRightRel = mapOldToNewRel.get(oldRightRel);
+    final Frame leftFrame = getInvoke(oldLeft, rel);
+    final Frame rightFrame = getInvoke(oldRight, rel);
 
-    if ((newLeftRel == null) || (newRightRel == null)) {
-      // If any child has not been rewritten, do not rewrite this rel.
-      return;
+    if (leftFrame == null || rightFrame == null) {
+      // If any input has not been rewritten, do not rewrite this rel.
+      return null;
     }
 
-    Map<Integer, Integer> leftChildMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newLeftRel);
-    assert leftChildMapOldToNewOutputPos != null;
-
-    Map<Integer, Integer> rightChildMapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newRightRel);
-    assert rightChildMapOldToNewOutputPos != null;
-
-    SortedMap<Correlation, Integer> mapCorVarToOutputPos = Maps.newTreeMap();
-
-    final Set<String> variablesStopped = Collections.emptySet();
-    RelNode newRel =
-        LogicalJoin.create(newLeftRel, newRightRel,
-            decorrelateExpr(rel.getCondition()), rel.getJoinType(),
-            variablesStopped);
+    final RelNode newJoin =
+        LogicalJoin.create(leftFrame.r, rightFrame.r,
+            decorrelateExpr(rel.getCondition()),
+            ImmutableSet.<CorrelationId>of(), rel.getJoinType());
 
     // Create the mapping between the output of the old correlation rel
     // and the new join rel
     Map<Integer, Integer> mapOldToNewOutputPos = Maps.newHashMap();
 
-    int oldLeftFieldCount = oldLeftRel.getRowType().getFieldCount();
-    int newLeftFieldCount = newLeftRel.getRowType().getFieldCount();
+    int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
+    int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
 
-    int oldRightFieldCount = oldRightRel.getRowType().getFieldCount();
+    int oldRightFieldCount = oldRight.getRowType().getFieldCount();
     assert rel.getRowType().getFieldCount()
         == oldLeftFieldCount + oldRightFieldCount;
 
     // Left input positions are not changed.
-    mapOldToNewOutputPos.putAll(leftChildMapOldToNewOutputPos);
+    mapOldToNewOutputPos.putAll(leftFrame.oldToNewOutputPos);
 
     // Right input positions are shifted by newLeftFieldCount.
     for (int i = 0; i < oldRightFieldCount; i++) {
-      mapOldToNewOutputPos.put(
-          i + oldLeftFieldCount,
-          rightChildMapOldToNewOutputPos.get(i) + newLeftFieldCount);
+      mapOldToNewOutputPos.put(i + oldLeftFieldCount,
+          rightFrame.oldToNewOutputPos.get(i) + newLeftFieldCount);
     }
 
-    if (mapNewRelToMapCorVarToOutputPos.containsKey(newLeftRel)) {
-      mapCorVarToOutputPos.putAll(
-          mapNewRelToMapCorVarToOutputPos.get(newLeftRel));
-    }
+    final SortedMap<Correlation, Integer> mapCorVarToOutputPos =
+        new TreeMap<>(leftFrame.corVarOutputPos);
 
     // Right input positions are shifted by newLeftFieldCount.
-    int oldRightPos;
-    if (mapNewRelToMapCorVarToOutputPos.containsKey(newRightRel)) {
-      SortedMap<Correlation, Integer> rightChildMapCorVarToOutputPos =
-          mapNewRelToMapCorVarToOutputPos.get(newRightRel);
-      for (Correlation corVar : rightChildMapCorVarToOutputPos.keySet()) {
-        oldRightPos = rightChildMapCorVarToOutputPos.get(corVar);
-        mapCorVarToOutputPos.put(
-            corVar,
-            oldRightPos + newLeftFieldCount);
-      }
-    }
-    mapOldToNewRel.put(rel, newRel);
-    mapNewRelToMapOldToNewOutputPos.put(newRel, mapOldToNewOutputPos);
-
-    if (!mapCorVarToOutputPos.isEmpty()) {
-      mapNewRelToMapCorVarToOutputPos.put(newRel, mapCorVarToOutputPos);
+    for (Map.Entry<Correlation, Integer> entry
+        : rightFrame.corVarOutputPos.entrySet()) {
+      mapCorVarToOutputPos.put(entry.getKey(),
+          entry.getValue() + newLeftFieldCount);
     }
+    return register(rel, newJoin, mapOldToNewOutputPos, mapCorVarToOutputPos);
   }
 
   private RexInputRef getNewForOldInputRef(RexInputRef oldInputRef) {
@@ -1175,61 +1071,57 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
     // determine which input rel oldOrdinal references, and adjust
     // oldOrdinal to be relative to that input rel
-    List<RelNode> oldInputRels = currentRel.getInputs();
-    RelNode oldInputRel = null;
+    RelNode oldInput = null;
 
-    for (RelNode oldInputRel0 : oldInputRels) {
-      RelDataType oldInputType = oldInputRel0.getRowType();
+    for (RelNode oldInput0 : currentRel.getInputs()) {
+      RelDataType oldInputType = oldInput0.getRowType();
       int n = oldInputType.getFieldCount();
       if (oldOrdinal < n) {
-        oldInputRel = oldInputRel0;
+        oldInput = oldInput0;
         break;
       }
-      RelNode newInput = mapOldToNewRel.get(oldInputRel0);
+      RelNode newInput = map.get(oldInput0).r;
       newOrdinal += newInput.getRowType().getFieldCount();
       oldOrdinal -= n;
     }
 
-    assert oldInputRel != null;
+    assert oldInput != null;
 
-    RelNode newInputRel = mapOldToNewRel.get(oldInputRel);
-    assert newInputRel != null;
+    final Frame frame = map.get(oldInput);
+    assert frame != null;
 
-    // now oldOrdinal is relative to oldInputRel
+    // now oldOrdinal is relative to oldInput
     int oldLocalOrdinal = oldOrdinal;
 
-    // figure out the newLocalOrdinal, relative to the newInputRel.
+    // figure out the newLocalOrdinal, relative to the newInput.
     int newLocalOrdinal = oldLocalOrdinal;
 
-    Map<Integer, Integer> mapOldToNewOutputPos =
-        mapNewRelToMapOldToNewOutputPos.get(newInputRel);
-
-    if (mapOldToNewOutputPos != null) {
-      newLocalOrdinal = mapOldToNewOutputPos.get(oldLocalOrdinal);
+    if (!frame.oldToNewOutputPos.isEmpty()) {
+      newLocalOrdinal = frame.oldToNewOutputPos.get(oldLocalOrdinal);
     }
 
     newOrdinal += newLocalOrdinal;
 
     return new RexInputRef(newOrdinal,
-        newInputRel.getRowType().getFieldList().get(newLocalOrdinal).getType());
+        frame.r.getRowType().getFieldList().get(newLocalOrdinal).getType());
   }
 
   /**
-   * Pull projRel above the join from its RHS input. Enforce nullability
+   * Pulls project above the join from its RHS input. Enforces nullability
    * for join output.
    *
    * @param join          Join
-   * @param projRel          the original projRel as the RHS input of the join.
+   * @param project       Original project as the right-hand input of the join
    * @param nullIndicatorPos Position of null indicator
    * @return the subtree with the new LogicalProject at the root
    */
   private RelNode projectJoinOutputWithNullability(
       LogicalJoin join,
-      LogicalProject projRel,
+      LogicalProject project,
       int nullIndicatorPos) {
-    RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();
-    RelNode leftInputRel = join.getLeft();
-    JoinRelType joinType = join.getJoinType();
+    final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();
+    final RelNode left = join.getLeft();
+    final JoinRelType joinType = join.getJoinType();
 
     RexInputRef nullIndicator =
         new RexInputRef(
@@ -1245,7 +1137,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
     // project everything from the LHS and then those from the original
     // projRel
     List<RelDataTypeField> leftInputFields =
-        leftInputRel.getRowType().getFieldList();
+        left.getRowType().getFieldList();
 
     for (int i = 0; i < leftInputFields.size(); i++) {
       newProjExprs.add(RexInputRef.of2(i, leftInputFields));
@@ -1257,7 +1149,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
     boolean projectPulledAboveLeftCorrelator =
         joinType.generatesNullsOnRight();
 
-    for (Pair<RexNode, String> pair : projRel.getNamedProjects()) {
+    for (Pair<RexNode, String> pair : project.getNamedProjects()) {
       RexNode newProjExpr =
           removeCorrelationExpr(
               pair.left,
@@ -1267,36 +1159,33 @@ public class RelDecorrelator implements ReflectiveVisitor {
       newProjExprs.add(Pair.of(newProjExpr, pair.right));
     }
 
-    RelNode newProjRel =
-        RelOptUtil.createProject(join, newProjExprs, false);
-
-    return newProjRel;
+    return RelOptUtil.createProject(join, newProjExprs, false);
   }
 
   /**
-   * Pulls projRel above the joinRel from its RHS input. Enforces nullability
-   * for join output.
+   * Pulls a {@link Project} above a {@link Correlate} from its RHS input.
+   * Enforces nullability for join output.
    *
-   * @param corRel  Correlator
-   * @param projRel the original LogicalProject as the RHS input of the join
+   * @param correlate  Correlate
+   * @param project the original project as the RHS input of the join
    * @param isCount Positions which are calls to the <code>COUNT</code>
    *                aggregation function
    * @return the subtree with the new LogicalProject at the root
    */
   private RelNode aggregateCorrelatorOutput(
-      LogicalCorrelate corRel,
-      LogicalProject projRel,
+      Correlate correlate,
+      LogicalProject project,
       Set<Integer> isCount) {
-    RelNode leftInputRel = corRel.getLeft();
-    JoinRelType joinType = corRel.getJoinType().toJoinType();
+    final RelNode left = correlate.getLeft();
+    final JoinRelType joinType = correlate.getJoinType().toJoinType();
 
     // now create the new project
-    List<Pair<RexNode, String>> newProjects = Lists.newArrayList();
+    final List<Pair<RexNode, String>> newProjects = Lists.newArrayList();
 
-    // project everything from the LHS and then those from the original
-    // projRel
-    List<RelDataTypeField> leftInputFields =
-        leftInputRel.getRowType().getFieldList();
+    // Project everything from the LHS and then those from the original
+    // project
+    final List<RelDataTypeField> leftInputFields =
+        left.getRowType().getFieldList();
 
     for (int i = 0; i < leftInputFields.size(); i++) {
       newProjects.add(RexInputRef.of2(i, leftInputFields));
@@ -1308,7 +1197,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
     boolean projectPulledAboveLeftCorrelator =
         joinType.generatesNullsOnRight();
 
-    for (Pair<RexNode, String> pair : projRel.getNamedProjects()) {
+    for (Pair<RexNode, String> pair : project.getNamedProjects()) {
       RexNode newProjExpr =
           removeCorrelationExpr(
               pair.left,
@@ -1317,22 +1206,22 @@ public class RelDecorrelator implements ReflectiveVisitor {
       newProjects.add(Pair.of(newProjExpr, pair.right));
     }
 
-    return RelOptUtil.createProject(corRel, newProjects, false);
+    return RelOptUtil.createProject(correlate, newProjects, false);
   }
 
   /**
    * Checks whether the correlations in projRel and filter are related to
    * the correlated variables provided by corRel.
    *
-   * @param corRel    Correlator
-   * @param projRel   The original Project as the RHS input of the join
+   * @param correlate    Correlate
+   * @param project   The original Project as the RHS input of the join
    * @param filter    Filter
    * @param correlatedJoinKeys Correlated join keys
    * @return true if filter and proj only references corVar provided by corRel
    */
   private boolean checkCorVars(
-      LogicalCorrelate corRel,
-      LogicalProject projRel,
+      LogicalCorrelate correlate,
+      LogicalProject project,
       LogicalFilter filter,
       List<RexFieldAccess> correlatedJoinKeys) {
     if (filter != null) {
@@ -1344,8 +1233,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
           Sets.newHashSet(cm.mapRefRelToCorVar.get(filter));
 
       for (RexFieldAccess correlatedJoinKey : correlatedJoinKeys) {
-        corVarInFilter.remove(
-            cm.mapFieldAccessToCorVar.get(correlatedJoinKey));
+        corVarInFilter.remove(cm.mapFieldAccessToCorVar.get(correlatedJoinKey));
       }
 
       if (!corVarInFilter.isEmpty()) {
@@ -1357,18 +1245,18 @@ public class RelDecorrelator implements ReflectiveVisitor {
       corVarInFilter.addAll(cm.mapRefRelToCorVar.get(filter));
 
       for (Correlation corVar : corVarInFilter) {
-        if (cm.mapCorVarToCorRel.get(corVar.corr) != corRel) {
+        if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) {
           return false;
         }
       }
     }
 
-    // if projRel has any correlated reference, make sure they are also
-    // provided by the current corRel. They will be projected out of the LHS
-    // of the corRel.
-    if ((projRel != null) && cm.mapRefRelToCorVar.containsKey(projRel)) {
-      for (Correlation corVar : cm.mapRefRelToCorVar.get(projRel)) {
-        if (cm.mapCorVarToCorRel.get(corVar.corr) != corRel) {
+    // if project has any correlated reference, make sure they are also
+    // provided by the current correlate. They will be projected out of the LHS
+    // of the correlate.
+    if ((project != null) && cm.mapRefRelToCorVar.containsKey(project)) {
+      for (Correlation corVar : cm.mapRefRelToCorVar.get(project)) {
+        if (cm.mapCorVarToCorRel.get(corVar.corr) != correlate) {
           return false;
         }
       }
@@ -1380,26 +1268,26 @@ public class RelDecorrelator implements ReflectiveVisitor {
   /**
    * Remove correlated variables from the tree at root corRel
    *
-   * @param corRel Correlator
+   * @param correlate Correlator
    */
-  private void removeCorVarFromTree(LogicalCorrelate corRel) {
-    if (cm.mapCorVarToCorRel.get(corRel.getCorrelationId()) == corRel) {
-      cm.mapCorVarToCorRel.remove(corRel.getCorrelationId());
+  private void removeCorVarFromTree(LogicalCorrelate correlate) {
+    if (cm.mapCorVarToCorRel.get(correlate.getCorrelationId()) == correlate) {
+      cm.mapCorVarToCorRel.remove(correlate.getCorrelationId());
     }
   }
 
   /**
-   * Project all childRel output fields plus the additional expressions.
+   * Projects all {@code input} output fields plus the additional expressions.
    *
-   * @param childRel        Child relational expression
+   * @param input        Input relational expression
    * @param additionalExprs Additional expressions and names
    * @return the new LogicalProject
    */
   private RelNode createProjectWithAdditionalExprs(
-      RelNode childRel,
+      RelNode input,
       List<Pair<RexNode, String>> additionalExprs) {
     final List<RelDataTypeField> fieldList =
-        childRel.getRowType().getFieldList();
+        input.getRowType().getFieldList();
     List<Pair<RexNode, String>> projects = Lists.newArrayList();
     for (Ord<RelDataTypeField> field : Ord.zip(fieldList)) {
       projects.add(
@@ -1409,140 +1297,93 @@ public class RelDecorrelator implements ReflectiveVisitor {
               field.e.getName()));
     }
     projects.addAll(additionalExprs);
-    return RelOptUtil.createProject(childRel, projects, false);
+    return RelOptUtil.createProject(input, projects, false);
   }
 
-  //~ Inner Classes ----------------------------------------------------------
+  /** Returns an immutable map with the identity [0: 0, .., count-1: count-1]. */
+  static Map<Integer, Integer> identityMap(int count) {
+    ImmutableMap.Builder<Integer, Integer> builder = ImmutableMap.builder();
+    for (int i = 0; i < count; i++) {
+      builder.put(i, i);
+    }
+    return builder.build();
+  }
+
+  /** Registers a relational expression and the relational expression it became
+   * after decorrelation. */
+  Frame register(RelNode rel, RelNode newRel,
+      Map<Integer, Integer> oldToNewOutputPos,
+      SortedMap<Correlation, Integer> corVarToOutputPos) {
+    assert allLessThan(oldToNewOutputPos.keySet(),
+        newRel.getRowType().getFieldCount(), Litmus.THROW);
+    final Frame frame = new Frame(newRel, corVarToOutputPos, oldToNewOutputPos);
+    map.put(rel, frame);
+    return frame;
+  }
 
-  /** Visitor that decorrelates. */
-  private class DecorrelateRelVisitor extends RelVisitor {
-    private final ReflectiveVisitDispatcher<RelDecorrelator, RelNode>
-    dispatcher =
-        ReflectUtil.createDispatcher(
-            RelDecorrelator.class,
-            RelNode.class);
-
-    // implement RelVisitor
-    public void visit(RelNode p, int ordinal, RelNode parent) {
-      // rewrite children first  (from left to right)
-      super.visit(p, ordinal, parent);
-
-      currentRel = p;
-
-      final String visitMethodName = "decorrelateRel";
-      boolean found =
-          dispatcher.invokeVisitor(
-              RelDecorrelator.this,
-              currentRel,
-              visitMethodName);
-      setCurrent(null, null);
-
-      if (!found) {
-        decorrelateRelGeneric(p);
+  static boolean allLessThan(Collection<Integer> integers, int limit,
+      Litmus ret) {
+    for (int value : integers) {
+      if (value >= limit) {
+        return ret.fail("out of range; value: " + value + ", limit: " + limit);
       }
-      // else no rewrite will occur. This will terminate the bottom-up
-      // rewrite. If root node of a RelNode tree is not rewritten, the
-      // original tree will be returned. See decorrelate() method.
     }
+    return ret.succeed();
   }
 
+  private static RelNode stripHep(RelNode rel) {
+    if (rel instanceof HepRelVertex) {
+      HepRelVertex hepRelVertex = (HepRelVertex) rel;
+      rel = hepRelVertex.getCurrentRel();
+    }
+    return rel;
+  }
+
+  //~ Inner Classes ----------------------------------------------------------
+
   /** Shuttle that decorrelates. */
   private class DecorrelateRexShuttle extends RexShuttle {
-    // override RexShuttle
-    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
-      int newInputRelOutputOffset = 0;
-      RelNode oldInputRel;
-      RelNode newInputRel;
-      Integer newInputPos;
-
-      List<RelNode> inputs = currentRel.getInputs();
-      for (int i = 0; i < inputs.size(); i++) {
-        oldInputRel = inputs.get(i);
-        newInputRel = mapOldToNewRel.get(oldInputRel);
-
-        if ((newInputRel != null)
-            && mapNewRelToMapCorVarToOutputPos.containsKey(newInputRel)) {
-          SortedMap<Correlation, Integer> childMapCorVarToOutputPos =
-              mapNewRelToMapCorVarToOutputPos.get(newInputRel);
-
-          if (childMapCorVarToOutputPos != null) {
-            // try to find in this input rel the position of cor var
-            Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess);
-
-            if (corVar != null) {
-              newInputPos = childMapCorVarToOutputPos.get(corVar);
-              if (newInputPos != null) {
-                // this input rel does produce the cor var
-                // referenced
-                newInputPos += newInputRelOutputOffset;
-
-                // fieldAccess is assumed to have the correct
-                // type info.
-                RexInputRef newInput =
-                    new RexInputRef(
-                        newInputPos,
-                        fieldAccess.getType());
-                return newInput;
-              }
+    @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+      int newInputOutputOffset = 0;
+      for (RelNode input : currentRel.getInputs()) {
+        final Frame frame = map.get(input);
+
+        if (frame != null) {
+          // try to find in this input rel the position of cor var
+          final Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess);
+
+          if (corVar != null) {
+            Integer newInputPos = frame.corVarOutputPos.get(corVar);
+            if (newInputPos != null) {
+              // This input rel does produce the cor var referenced.
+              // Assume fieldAccess has the correct type info.
+              return new RexInputRef(newInputPos + newInputOutputOffset,
+                  fieldAccess.getType());
             }
           }
 
           // this input rel does not produce the cor var needed
-          newInputRelOutputOffset +=
-              newInputRel.getRowType().getFieldCount();
+          newInputOutputOffset += frame.r.getRowType().getFieldCount();
         } else {
           // this input rel is not rewritten
-          newInputRelOutputOffset +=
-              oldInputRel.getRowType().getFieldCount();
+          newInputOutputOffset += input.getRowType().getFieldCount();
         }
       }
       return fieldAccess;
     }
 
-    // override RexShuttle
-    public RexNode visitInputRef(RexInputRef inputRef) {
-      RexInputRef newInputRef = getNewForOldInputRef(inputRef);
-      return newInputRef;
+    @Override public RexNode visitInputRef(RexInputRef inputRef) {
+      return getNewForOldInputRef(inputRef);
     }
   }
 
   /** Shuttle that removes correlations. */
   private class RemoveCorrelationRexShuttle extends RexShuttle {
-    RexBuilder rexBuilder;
-    RelDataTypeFactory typeFactory;
-    boolean projectPulledAboveLeftCorrelator;
-    RexInputRef nullIndicator;
-    Set<Integer> isCount;
-
-    public RemoveCorrelationRexShuttle(
-        RexBuilder rexBuilder,
-        boolean projectPulledAboveLeftCorrelator) {
-      this(
-          rexBuilder,
-          projectPulledAboveLeftCorrelator,
-          null, null);
-    }
-
-    public RemoveCorrelationRexShuttle(
-        RexBuilder rexBuilder,
-        boolean projectPulledAboveLeftCorrelator,
-        RexInputRef nullIndicator) {
-      this(
-          rexBuilder,
-          projectPulledAboveLeftCorrelator,
-          nullIndicator,
-          null);
-    }
-
-    public RemoveCorrelationRexShuttle(
-        RexBuilder rexBuilder,
-        boolean projectPulledAboveLeftCorrelator,
-        Set<Integer> isCount) {
-      this(
-          rexBuilder,
-          projectPulledAboveLeftCorrelator,
-          null, isCount);
-    }
+    final RexBuilder rexBuilder;
+    final RelDataTypeFactory typeFactory;
+    final boolean projectPulledAboveLeftCorrelator;
+    final RexInputRef nullIndicator;
+    final ImmutableSet<Integer> isCount;
 
     public RemoveCorrelationRexShuttle(
         RexBuilder rexBuilder,
@@ -1551,8 +1392,8 @@ public class RelDecorrelator implements ReflectiveVisitor {
         Set<Integer> isCount) {
       this.projectPulledAboveLeftCorrelator =
           projectPulledAboveLeftCorrelator;
-      this.nullIndicator = nullIndicator;
-      this.isCount = isCount;
+      this.nullIndicator = nullIndicator; // may be null
+      this.isCount = ImmutableSet.copyOf(isCount);
       this.rexBuilder = rexBuilder;
       this.typeFactory = rexBuilder.getTypeFactory();
     }
@@ -1603,8 +1444,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
           caseOperands);
     }
 
-    // override RexShuttle
-    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
+    @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
       if (cm.mapFieldAccessToCorVar.containsKey(fieldAccess)) {
         // if it is a corVar, change it to be input ref.
         Correlation corVar = cm.mapFieldAccessToCorVar.get(fieldAccess);
@@ -1629,15 +1469,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
       return fieldAccess;
     }
 
-    // override RexShuttle
-    public RexNode visitInputRef(RexInputRef inputRef) {
-      if ((currentRel != null) && (currentRel instanceof LogicalCorrelate)) {
+    @Override public RexNode visitInputRef(RexInputRef inputRef) {
+      if (currentRel instanceof LogicalCorrelate) {
         // if this rel references corVar
         // and now it needs to be rewritten
         // it must have been pulled above the Correlator
         // replace the input ref to account for the LHS of the
         // Correlator
-        int leftInputFieldCount =
+        final int leftInputFieldCount =
             ((LogicalCorrelate) currentRel).getLeft().getRowType()
                 .getFieldCount();
         RelDataType newType = inputRef.getType();
@@ -1663,8 +1502,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
       return inputRef;
     }
 
-    // override RexLiteral
-    public RexNode visitLiteral(RexLiteral literal) {
+    @Override public RexNode visitLiteral(RexLiteral literal) {
       // Use nullIndicator to decide whether to project null.
       // Do nothing if the literal is null.
       if (!RexUtil.isNull(literal)
@@ -1678,7 +1516,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
       return literal;
     }
 
-    public RexNode visitCall(final RexCall call) {
+    @Override public RexNode visitCall(final RexCall call) {
       RexNode newCall;
 
       boolean[] update = {false};
@@ -1752,14 +1590,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
     }
 
     public void onMatch(RelOptRuleCall call) {
-      LogicalAggregate singleAggRel = call.rel(0);
-      LogicalProject projRel = call.rel(1);
-      LogicalAggregate aggRel = call.rel(2);
+      LogicalAggregate singleAggregate = call.rel(0);
+      LogicalProject project = call.rel(1);
+      LogicalAggregate aggregate = call.rel(2);
 
       // check that singleAggregate is a single_value agg
-      if ((!singleAggRel.getGroupSet().isEmpty())
-          || (singleAggRel.getAggCallList().size() != 1)
-          || !(singleAggRel.getAggCallList().get(0).getAggregation()
+      if ((!singleAggregate.getGroupSet().isEmpty())
+          || (singleAggregate.getAggCallList().size() != 1)
+          || !(singleAggregate.getAggCallList().get(0).getAggregation()
           instanceof SqlSingleValueAggFunction)) {
         return;
       }
@@ -1767,21 +1605,21 @@ public class RelDecorrelator implements ReflectiveVisitor {
       // check that this project only projects one expression, i.e.
       // that it belongs to a scalar subquery; otherwise the rule
       // does not apply.
-      List<RexNode> projExprs = projRel.getProjects();
+      List<RexNode> projExprs = project.getProjects();
       if (projExprs.size() != 1) {
         return;
       }
 
       // check the input to project is an aggregate on the entire input
-      if (!aggRel.getGroupSet().isEmpty()) {
+      if (!aggregate.getGroupSet().isEmpty()) {
         return;
       }
 
       // singleAggregate produces a nullable type, so create the new
       // projection that casts the project expression to a nullable type.
-      final RelOptCluster cluster = projRel.getCluster();
-      RelNode newProjRel =
-          RelOptUtil.createProject(aggRel,
+      final RelOptCluster cluster = project.getCluster();
+      RelNode newProject =
+          RelOptUtil.createProject(aggregate,
               ImmutableList.of(
                   rexBuilder.makeCast(
                       cluster.getTypeFactory().createTypeWithNullability(
@@ -1789,7 +1627,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
                           true),
                       projExprs.get(0))),
               null);
-      call.transformTo(newProjRel);
+      call.transformTo(newProject);
     }
   }
 
@@ -1805,14 +1643,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
     }
 
     public void onMatch(RelOptRuleCall call) {
-      LogicalCorrelate corRel = call.rel(0);
-      RelNode leftInputRel = call.rel(1);
-      LogicalAggregate aggRel = call.rel(2);
-      LogicalProject projRel = call.rel(3);
-      RelNode rightInputRel = call.rel(4);
-      RelOptCluster cluster = corRel.getCluster();
+      final LogicalCorrelate correlate = call.rel(0);
+      final RelNode left = call.rel(1);
+      final LogicalAggregate aggregate = call.rel(2);
+      final LogicalProject project = call.rel(3);
+      RelNode right = call.rel(4);
+      final RelOptCluster cluster = correlate.getCluster();
 
-      setCurrent(call.getPlanner().getRoot(), corRel);
+      setCurrent(call.getPlanner().getRoot(), correlate);
 
       // Check for this pattern.
       // The pattern matching could be simplified if rules can be applied
@@ -1823,7 +1661,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
       //   LogicalAggregate (groupby (0) single_value())
       //     LogicalProject-A (may reference coVar)
       //       RightInputRel
-      JoinRelType joinType = corRel.getJoinType().toJoinType();
+      final JoinRelType joinType = correlate.getJoinType().toJoinType();
 
       // corRel.getCondition was here, however Correlate was updated so it
       // never includes a join condition. The code was not modified for brevity.
@@ -1835,23 +1673,23 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
       // check that the agg is of the following type:
       // doing a single_value() on the entire input
-      if ((!aggRel.getGroupSet().isEmpty())
-          || (aggRel.getAggCallList().size() != 1)
-          || !(aggRel.getAggCallList().get(0).getAggregation()
+      if ((!aggregate.getGroupSet().isEmpty())
+          || (aggregate.getAggCallList().size() != 1)
+          || !(aggregate.getAggCallList().get(0).getAggregation()
           instanceof SqlSingleValueAggFunction)) {
         return;
       }
 
       // check this project only projects one expression, i.e. scalar
       // subqueries.
-      if (projRel.getProjects().size() != 1) {
+      if (project.getProjects().size() != 1) {
         return;
       }
 
       int nullIndicatorPos;
 
-      if ((rightInputRel instanceof LogicalFilter)
-          && cm.mapRefRelToCorVar.containsKey(rightInputRel)) {
+      if ((right instanceof LogicalFilter)
+          && cm.mapRefRelToCorVar.containsKey(right)) {
         // rightInputRel has this shape:
         //
         //       LogicalFilter (references corvar)
@@ -1861,14 +1699,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
         // reference, make sure the correlated keys in the filter
         // condition forms a unique key of the RHS.
 
-        LogicalFilter filter = (LogicalFilter) rightInputRel;
-        rightInputRel = filter.getInput();
+        LogicalFilter filter = (LogicalFilter) right;
+        right = filter.getInput();
 
-        assert rightInputRel instanceof HepRelVertex;
-        rightInputRel = ((HepRelVertex) rightInputRel).getCurrentRel();
+        assert right instanceof HepRelVertex;
+        right = ((HepRelVertex) right).getCurrentRel();
 
         // check filter input contains no correlation
-        if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) {
+        if (RelOptUtil.getVariablesUsed(right).size() > 0) {
           return;
         }
 
@@ -1889,7 +1727,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
         // check that the columns referenced in these comparisons form
         // a unique key of the filterInputRel
-        List<RexInputRef> rightJoinKeys = new ArrayList<RexInputRef>();
+        final List<RexInputRef> rightJoinKeys = new ArrayList<>();
         for (RexNode key : tmpRightJoinKeys) {
           assert key instanceof RexInputRef;
           rightJoinKeys.add((RexInputRef) key);
@@ -1904,11 +1742,11 @@ public class RelDecorrelator implements ReflectiveVisitor {
         // The join filters out the nulls.  So, it's ok if there are
         // nulls in the join keys.
         if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(
-            rightInputRel,
+            right,
             rightJoinKeys)) {
           SQL2REL_LOGGER.fine(rightJoinKeys.toString()
               + "are not unique keys for "
-              + rightInputRel.toString());
+              + right.toString());
           return;
         }
 
@@ -1918,7 +1756,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
         List<RexFieldAccess> correlatedKeyList =
             visitor.getFieldAccessList();
 
-        if (!checkCorVars(corRel, projRel, filter, correlatedKeyList)) {
+        if (!checkCorVars(correlate, project, filter, correlatedKeyList)) {
           return;
         }
 
@@ -1935,15 +1773,15 @@ public class RelDecorrelator implements ReflectiveVisitor {
             removeCorrelationExpr(filter.getCondition(), false);
 
         nullIndicatorPos =
-            leftInputRel.getRowType().getFieldCount()
+            left.getRowType().getFieldCount()
                 + rightJoinKeys.get(0).getIndex();
-      } else if (cm.mapRefRelToCorVar.containsKey(projRel)) {
+      } else if (cm.mapRefRelToCorVar.containsKey(project)) {
         // check filter input contains no correlation
-        if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) {
+        if (RelOptUtil.getVariablesUsed(right).size() > 0) {
           return;
         }
 
-        if (!checkCorVars(corRel, projRel, null, null)) {
+        if (!checkCorVars(correlate, project, null, null)) {
           return;
         }
 
@@ -1957,37 +1795,37 @@ public class RelDecorrelator implements ReflectiveVisitor {
         //         ProjInputRel
 
         // make the new projRel to provide a null indicator
-        rightInputRel =
-            createProjectWithAdditionalExprs(rightInputRel,
+        right =
+            createProjectWithAdditionalExprs(right,
                 ImmutableList.of(
                     Pair.<RexNode, String>of(
                         rexBuilder.makeLiteral(true), "nullIndicator")));
 
         // make the new aggRel
-        rightInputRel =
-            RelOptUtil.createSingleValueAggRel(cluster, rightInputRel);
+        right =
+            RelOptUtil.createSingleValueAggRel(cluster, right);
 
         // The last field:
         //     single_value(true)
         // is the nullIndicator
         nullIndicatorPos =
-            leftInputRel.getRowType().getFieldCount()
-                + rightInputRel.getRowType().getFieldCount() - 1;
+            left.getRowType().getFieldCount()
+                + right.getRowType().getFieldCount() - 1;
       } else {
         return;
       }
 
       // make the new join rel
       LogicalJoin join =
-          LogicalJoin.create(leftInputRel, rightInputRel, joinCond, joinType,
-              ImmutableSet.<String>of());
+          LogicalJoin.create(left, right, joinCond,
+              ImmutableSet.<CorrelationId>of(), joinType);
 
-      RelNode newProjRel =
-          projectJoinOutputWithNullability(join, projRel, nullIndicatorPos);
+      RelNode newProject =
+          projectJoinOutputWithNullability(join, project, nullIndicatorPos);
 
-      call.transformTo(newProjRel);
+      call.transformTo(newProject);
 
-      removeCorVarFromTree(corRel);
+      removeCorVarFromTree(correlate);
     }
   }
 
@@ -2005,15 +1843,15 @@ public class RelDecorrelator implements ReflectiveVisitor {
     }
 
     public void onMatch(RelOptRuleCall call) {
-      LogicalCorrelate corRel = call.rel(0);
-      RelNode leftInputRel = call.rel(1);
-      LogicalProject aggOutputProjRel = call.rel(2);
-      LogicalAggregate aggRel = call.rel(3);
-      LogicalProject aggInputProjRel = call.rel(4);
-      RelNode rightInputRel = call.rel(5);
-      RelOptCluster cluster = corRel.getCluster();
+      final LogicalCorrelate correlate = call.rel(0);
+      final RelNode left = call.rel(1);
+      final LogicalProject aggOutputProject = call.rel(2);
+      final LogicalAggregate aggregate = call.rel(3);
+      final LogicalProject aggInputProject = call.rel(4);
+      RelNode right = call.rel(5);
+      final RelOptCluster cluster = correlate.getCluster();
 
-      setCurrent(call.getPlanner().getRoot(), corRel);
+      setCurrent(call.getPlanner().getRoot(), correlate);
 
       // check for this pattern
       // The pattern matching could be simplified if rules can be applied
@@ -2026,13 +1864,13 @@ public class RelDecorrelator implements ReflectiveVisitor {
       //       LogicalProject-B (references coVar)
       //         rightInputRel
 
-      // check aggOutputProj projects only one expression
-      List<RexNode> aggOutputProjExprs = aggOutputProjRel.getProjects();
-      if (aggOutputProjExprs.size() != 1) {
+      // check aggOutputProject projects only one expression
+      final List<RexNode> aggOutputProjects = aggOutputProject.getProjects();
+      if (aggOutputProjects.size() != 1) {
         return;
       }
 
-      JoinRelType joinType = corRel.getJoinType().toJoinType();
+      final JoinRelType joinType = correlate.getJoinType().toJoinType();
       // corRel.getCondition was here, however Correlate was updated so it
       // never includes a join condition. The code was not modified for brevity.
       RexNode joinCond = rexBuilder.makeLiteral(true);
@@ -2042,14 +1880,14 @@ public class RelDecorrelator implements ReflectiveVisitor {
       }
 
       // check that the agg is on the entire input
-      if (!aggRel.getGroupSet().isEmpty()) {
+      if (!aggregate.getGroupSet().isEmpty()) {
         return;
       }
 
-      List<RexNode> aggInputProjExprs = aggInputProjRel.getProjects();
+      final List<RexNode> aggInputProjects = aggInputProject.getProjects();
 
-      List<AggregateCall> aggCalls = aggRel.getAggCallList();
-      Set<Integer> isCountStar = Sets.newHashSet();
+      final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+      final Set<Integer> isCountStar = Sets.newHashSet();
 
       // mark if agg produces count(*) which needs to reference the
       // nullIndicator after the transformation.
@@ -2062,20 +1900,20 @@ public class RelDecorrelator implements ReflectiveVisitor {
         }
       }
 
-      if ((rightInputRel instanceof LogicalFilter)
-          && cm.mapRefRelToCorVar.containsKey(rightInputRel)) {
+      if ((right instanceof LogicalFilter)
+          && cm.mapRefRelToCorVar.containsKey(right)) {
         // rightInputRel has this shape:
         //
         //       LogicalFilter (references corvar)
         //         FilterInputRel
-        LogicalFilter filter = (LogicalFilter) rightInputRel;
-        rightInputRel = filter.getInput();
+        LogicalFilter filter = (LogicalFilter) right;
+        right = filter.getInput();
 
-        assert rightInputRel instanceof HepRelVertex;
-        rightInputRel = ((HepRelVertex) rightInputRel).getCurrentRel();
+        assert right instanceof HepRelVertex;
+        right = ((HepRelVertex) right).getCurrentRel();
 
         // check filter input contains no correlation
-        if (RelOptUtil.getVariablesUsed(rightInputRel).size() > 0) {
+        if (RelOptUtil.getVariablesUsed(right).size() > 0) {
           return;
         }
 
@@ -2119,17 +1957,17 @@ public class RelDecorrelator implements ReflectiveVisitor {
         // The join filters out the nulls.  So, it's ok if there are
         // nulls in the join keys.
         if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(
-            leftInputRel,
+            left,
             correlatedInputRefJoinKeys)) {
           SQL2REL_LOGGER.fine(correlatedJoinKeys.toString()
               + "are not unique keys for "
-              + leftInputRel.toString());
+              + left.toString());
           return;
         }
 
         // check cor var references are valid
-        if (!checkCorVars(corRel,
-            aggInputProjRel,
+        if (!checkCorVars(correlate,
+            aggInputProject,
             filter,
             correlatedJoinKeys)) {
           return;
@@ -2180,27 +2018,27 

<TRUNCATED>