You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2020/02/13 02:12:09 UTC

[calcite] branch master updated (a09923f -> 555da95)

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git.


    from a09923f  [CALCITE-3785] HepPlanner.belongToDag() doesn't have to use mapDigestToVertex (Xiening Dai)
     new 051b691  Add RelBuilder.transform, which allows you to clone a RelBuilder with slightly different Config
     new ceb9729  [CALCITE-3763] RelBuilder.aggregate should prune unused fields from the input, if the input is a Project
     new 555da95  [CALCITE-3774] In RelBuilder and ProjectMergeRule, prevent merges when it would increase expression complexity

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../java/org/apache/calcite/plan/RelOptUtil.java   |  36 ++-
 .../org/apache/calcite/rel/core/RelFactories.java  | 132 ++++++++-
 .../rel/rules/AbstractMaterializedViewRule.java    |   6 +-
 .../apache/calcite/rel/rules/ProjectMergeRule.java |  24 +-
 .../main/java/org/apache/calcite/rex/RexCall.java  |   6 +
 .../main/java/org/apache/calcite/rex/RexNode.java  |  12 +
 .../main/java/org/apache/calcite/rex/RexOver.java  |   4 +
 .../main/java/org/apache/calcite/rex/RexUtil.java  |  15 +
 .../java/org/apache/calcite/rex/RexWindow.java     |  18 +-
 .../org/apache/calcite/rex/RexWindowBound.java     |  13 +
 .../java/org/apache/calcite/tools/RelBuilder.java  | 312 +++++++++++++--------
 .../org/apache/calcite/rex/RexProgramTest.java     |   2 +-
 .../org/apache/calcite/rex/RexProgramTestBase.java |  11 -
 .../java/org/apache/calcite/test/JdbcTest.java     |  11 +-
 .../org/apache/calcite/test/PigRelBuilderTest.java |  34 ++-
 .../org/apache/calcite/test/RelBuilderTest.java    | 240 ++++++++++++++--
 .../java/org/apache/calcite/tools/PlannerTest.java |   2 +-
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 116 ++++----
 .../org/apache/calcite/piglet/PigRelBuilder.java   |  17 +-
 .../java/org/apache/calcite/test/PigRelOpTest.java |   4 +-
 .../java/org/apache/calcite/test/PigletTest.java   |   4 +-
 21 files changed, 761 insertions(+), 258 deletions(-)


[calcite] 03/03: [CALCITE-3774] In RelBuilder and ProjectMergeRule, prevent merges when it would increase expression complexity

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 555da953fe758a7d310aeb3aed463f3f2f3cdc3b
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Wed Feb 5 17:02:22 2020 -0800

    [CALCITE-3774] In RelBuilder and ProjectMergeRule, prevent merges when it would increase expression complexity
    
    Add an option RelBuilder.Config.bloat(), default 100.
    Set it, using RelBuilder.Config.withBloat(int),
    to a negative value (such as -1) to never merge,
    0 to merge only if complexity does not increase,
    b to merge if complexity increases by no more than b.
    
    Deprecate RelBuilder.shouldMergeProject().
    
    Cache the nodeCount value in RexCall and RexWindow. Compute nodeCount
    each time for RexOver (a sub-class of RexCall with an extra window),
    because caching it would increase the complexity of RexCall's
    constructor.
---
 .../java/org/apache/calcite/plan/RelOptUtil.java   | 23 +++++++
 .../rel/rules/AbstractMaterializedViewRule.java    |  4 +-
 .../apache/calcite/rel/rules/ProjectMergeRule.java | 24 ++++++-
 .../main/java/org/apache/calcite/rex/RexCall.java  |  6 ++
 .../main/java/org/apache/calcite/rex/RexNode.java  | 12 ++++
 .../main/java/org/apache/calcite/rex/RexOver.java  |  4 ++
 .../main/java/org/apache/calcite/rex/RexUtil.java  | 15 ++++
 .../java/org/apache/calcite/rex/RexWindow.java     | 18 +++--
 .../org/apache/calcite/rex/RexWindowBound.java     | 13 ++++
 .../java/org/apache/calcite/tools/RelBuilder.java  | 56 ++++++++++++++-
 .../org/apache/calcite/rex/RexProgramTest.java     |  2 +-
 .../org/apache/calcite/rex/RexProgramTestBase.java | 11 ---
 .../org/apache/calcite/test/RelBuilderTest.java    | 80 ++++++++++++++++++++++
 .../java/org/apache/calcite/tools/PlannerTest.java |  2 +-
 .../org/apache/calcite/piglet/PigRelBuilder.java   | 17 +++--
 15 files changed, 259 insertions(+), 28 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
index 99469d6..9b73fd1 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
@@ -135,6 +135,7 @@ import java.util.TreeSet;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 /**
  * <code>RelOptUtil</code> defines static utility methods for use in optimizing
@@ -2967,6 +2968,28 @@ public abstract class RelOptUtil {
     return list;
   }
 
+  /** As {@link #pushPastProject}, but returns null if the resulting expressions
+   * exceed the combined complexity of the inputs by more than {@code bloat}.
+   *
+   * @param bloat Maximum allowable increase in complexity */
+  public static @Nullable List<RexNode> pushPastProjectUnlessBloat(
+      List<? extends RexNode> nodes, Project project, int bloat) {
+    if (bloat < 0) {
+      // If bloat is negative never merge.
+      return null;
+    }
+    final List<RexNode> list = pushPastProject(nodes, project);
+    final int bottomCount = RexUtil.nodeCount(project.getProjects());
+    final int topCount = RexUtil.nodeCount(nodes);
+    final int mergedCount = RexUtil.nodeCount(list);
+    if (mergedCount > bottomCount + topCount + bloat) {
+      // The merged expression is more complex than the input expressions.
+      // Do not merge.
+      return null;
+    }
+    return list;
+  }
+
   private static RexShuttle pushShuttle(final Project project) {
     return new RexShuttle() {
       @Override public RexNode visitInputRef(RexInputRef ref) {
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
index 80eb1ff..d1338a9 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
@@ -960,7 +960,9 @@ public abstract class AbstractMaterializedViewRule extends RelOptRule {
           Filter.class, relBuilderFactory, Aggregate.class);
       this.aggregateProjectPullUpConstantsRule = new AggregateProjectPullUpConstantsRule(
           Aggregate.class, Filter.class, relBuilderFactory, "AggFilterPullUpConstants");
-      this.projectMergeRule = new ProjectMergeRule(true, relBuilderFactory);
+      this.projectMergeRule =
+          new ProjectMergeRule(true, ProjectMergeRule.DEFAULT_BLOAT,
+              relBuilderFactory);
     }
 
     @Override protected boolean isValidPlan(Project topProject, RelNode node,
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/ProjectMergeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/ProjectMergeRule.java
index 2d550f9..818f7b5 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/ProjectMergeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/ProjectMergeRule.java
@@ -37,14 +37,20 @@ import java.util.List;
  * provided the projects aren't projecting identical sets of input references.
  */
 public class ProjectMergeRule extends RelOptRule {
+  /** Default amount by which complexity is allowed to increase. */
+  public static final int DEFAULT_BLOAT = 100;
+
   public static final ProjectMergeRule INSTANCE =
-      new ProjectMergeRule(true, RelFactories.LOGICAL_BUILDER);
+      new ProjectMergeRule(true, DEFAULT_BLOAT, RelFactories.LOGICAL_BUILDER);
 
   //~ Instance fields --------------------------------------------------------
 
   /** Whether to always merge projects. */
   private final boolean force;
 
+  /** Limit how much complexity can increase during merging. */
+  private final int bloat;
+
   //~ Constructors -----------------------------------------------------------
 
   /**
@@ -52,13 +58,20 @@ public class ProjectMergeRule extends RelOptRule {
    *
    * @param force Whether to always merge projects
    */
-  public ProjectMergeRule(boolean force, RelBuilderFactory relBuilderFactory) {
+  public ProjectMergeRule(boolean force, int bloat,
+      RelBuilderFactory relBuilderFactory) {
     super(
         operand(Project.class,
             operand(Project.class, any())),
         relBuilderFactory,
         "ProjectMergeRule" + (force ? ":force_mode" : ""));
     this.force = force;
+    this.bloat = bloat;
+  }
+
+  @Deprecated // to be removed before 2.0
+  public ProjectMergeRule(boolean force, RelBuilderFactory relBuilderFactory) {
+    this(force, DEFAULT_BLOAT, relBuilderFactory);
   }
 
   @Deprecated // to be removed before 2.0
@@ -106,7 +119,12 @@ public class ProjectMergeRule extends RelOptRule {
     }
 
     final List<RexNode> newProjects =
-        RelOptUtil.pushPastProject(topProject.getProjects(), bottomProject);
+        RelOptUtil.pushPastProjectUnlessBloat(topProject.getProjects(),
+            bottomProject, bloat);
+    if (newProjects == null) {
+      // Merged projects are significantly more complex. Do not merge.
+      return;
+    }
     final RelNode input = bottomProject.getInput();
     if (RexUtil.isIdentity(newProjects, input.getRowType())) {
       if (force
diff --git a/core/src/main/java/org/apache/calcite/rex/RexCall.java b/core/src/main/java/org/apache/calcite/rex/RexCall.java
index 8c38a7f..d5c2cea 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexCall.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexCall.java
@@ -65,6 +65,7 @@ public class RexCall extends RexNode {
   public final SqlOperator op;
   public final ImmutableList<RexNode> operands;
   public final RelDataType type;
+  public final int nodeCount;
 
   /**
    * Simple binary operators are those operators which expects operands from the same Domain.
@@ -91,6 +92,7 @@ public class RexCall extends RexNode {
     this.type = Objects.requireNonNull(type, "type");
     this.op = Objects.requireNonNull(op, "operator");
     this.operands = ImmutableList.copyOf(operands);
+    this.nodeCount = RexUtil.nodeCount(1, this.operands);
     assert op.getKind() != null : op;
     assert op.validRexOperands(operands.size(), Litmus.THROW) : this;
   }
@@ -342,6 +344,10 @@ public class RexCall extends RexNode {
     return op;
   }
 
+  @Override public int nodeCount() {
+    return nodeCount;
+  }
+
   /**
    * Creates a new call to the same operator with different operands.
    *
diff --git a/core/src/main/java/org/apache/calcite/rex/RexNode.java b/core/src/main/java/org/apache/calcite/rex/RexNode.java
index d9f51c3..7a8f058 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexNode.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexNode.java
@@ -158,6 +158,18 @@ public abstract class RexNode {
     }
   }
 
+  /** Returns the number of nodes in this expression.
+   *
+   * <p>Leaf nodes, such as {@link RexInputRef} or {@link RexLiteral}, have
+   * a count of 1. Calls have a count of 1 plus the sum of their operands.
+   *
+   * <p>Node count is a measure of expression complexity that is used by some
+   * planner rules to prevent deeply nested expressions.
+   */
+  public int nodeCount() {
+    return 1;
+  }
+
   /**
    * Accepts a visitor, dispatching to the right overloaded
    * {@link RexVisitor#visitInputRef visitXxx} method.
diff --git a/core/src/main/java/org/apache/calcite/rex/RexOver.java b/core/src/main/java/org/apache/calcite/rex/RexOver.java
index 9e6e931..d197cc2 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexOver.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexOver.java
@@ -125,6 +125,10 @@ public class RexOver extends RexCall {
     return visitor.visitOver(this, arg);
   }
 
+  @Override public int nodeCount() {
+    return super.nodeCount() + window.nodeCount;
+  }
+
   /**
    * Returns whether an expression contains an OVER clause.
    */
diff --git a/core/src/main/java/org/apache/calcite/rex/RexUtil.java b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
index 27b2b22..7dd6a9d 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexUtil.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexUtil.java
@@ -469,6 +469,21 @@ public class RexUtil {
     return false;
   }
 
+  /** Returns the number of nodes (including leaves) in a list of
+   * expressions.
+   *
+   * @see RexNode#nodeCount() */
+  public static int nodeCount(List<? extends RexNode> nodes) {
+    return nodeCount(0, nodes);
+  }
+
+  static int nodeCount(int n, List<? extends RexNode> nodes) {
+    for (RexNode operand : nodes) {
+      n += operand.nodeCount();
+    }
+    return n;
+  }
+
   /**
    * Walks over an expression and determines whether it is constant.
    */
diff --git a/core/src/main/java/org/apache/calcite/rex/RexWindow.java b/core/src/main/java/org/apache/calcite/rex/RexWindow.java
index da9a33f..4e47f68 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexWindow.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexWindow.java
@@ -16,11 +16,14 @@
  */
 package org.apache.calcite.rex;
 
+import org.apache.calcite.util.Pair;
+
 import com.google.common.collect.ImmutableList;
 
 import java.io.PrintWriter;
 import java.io.StringWriter;
 import java.util.List;
+import javax.annotation.Nullable;
 
 /**
  * Specification of the window of rows over which a {@link RexOver} windowed
@@ -37,6 +40,7 @@ public class RexWindow {
   private final RexWindowBound upperBound;
   private final boolean isRows;
   private final String digest;
+  public final int nodeCount;
 
   //~ Constructors -----------------------------------------------------------
 
@@ -49,16 +53,15 @@ public class RexWindow {
   RexWindow(
       List<RexNode> partitionKeys,
       List<RexFieldCollation> orderKeys,
-      RexWindowBound lowerBound,
-      RexWindowBound upperBound,
+      @Nullable RexWindowBound lowerBound,
+      @Nullable RexWindowBound upperBound,
       boolean isRows) {
-    assert partitionKeys != null;
-    assert orderKeys != null;
     this.partitionKeys = ImmutableList.copyOf(partitionKeys);
     this.orderKeys = ImmutableList.copyOf(orderKeys);
     this.lowerBound = lowerBound;
     this.upperBound = upperBound;
     this.isRows = isRows;
+    this.nodeCount = computeCodeCount();
     this.digest = computeDigest();
   }
 
@@ -149,4 +152,11 @@ public class RexWindow {
   public boolean isRows() {
     return isRows;
   }
+
+  private int computeCodeCount() {
+    return RexUtil.nodeCount(partitionKeys)
+        + RexUtil.nodeCount(Pair.left(orderKeys))
+        + (lowerBound == null ? 0 : lowerBound.nodeCount())
+        + (upperBound == null ? 0 : upperBound.nodeCount());
+  }
 }
diff --git a/core/src/main/java/org/apache/calcite/rex/RexWindowBound.java b/core/src/main/java/org/apache/calcite/rex/RexWindowBound.java
index bfcf92d..547bec6 100644
--- a/core/src/main/java/org/apache/calcite/rex/RexWindowBound.java
+++ b/core/src/main/java/org/apache/calcite/rex/RexWindowBound.java
@@ -106,6 +106,15 @@ public abstract class RexWindowBound {
   }
 
   /**
+   * Returns the number of nodes in this bound.
+   *
+   * @see RexNode#nodeCount()
+   */
+  public int nodeCount() {
+    return 1;
+  }
+
+  /**
    * Implements UNBOUNDED PRECEDING/FOLLOWING bound.
    */
   private static class RexWindowBoundUnbounded extends RexWindowBound {
@@ -217,6 +226,10 @@ public abstract class RexWindowBound {
       return offset;
     }
 
+    @Override public int nodeCount() {
+      return super.nodeCount() + offset.nodeCount();
+    }
+
     @Override public <R> RexWindowBound accept(RexVisitor<R> visitor) {
       R r = offset.accept(visitor);
       if (r instanceof RexNode && r != offset) {
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index 800f6d9..eec1af1 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -1324,8 +1324,9 @@ public class RelBuilder {
       fieldNameList.add(null);
     }
 
+    bloat:
     if (frame.rel instanceof Project
-        && shouldMergeProject()) {
+        && config.bloat() >= 0) {
       final Project project = (Project) frame.rel;
       // Populate field names. If the upper expression is an input ref and does
       // not have a recommended name, use the name of the underlying field.
@@ -1340,7 +1341,13 @@ public class RelBuilder {
         }
       }
       final List<RexNode> newNodes =
-          RelOptUtil.pushPastProject(nodeList, project);
+          RelOptUtil.pushPastProjectUnlessBloat(nodeList, project,
+              config.bloat());
+      if (newNodes == null) {
+        // The merged expression is more complex than the input expressions.
+        // Do not merge.
+        break bloat;
+      }
 
       // Carefully build a list of fields, so that table aliases from the input
       // can be seen for fields that are based on a RexInputRef.
@@ -1445,6 +1452,7 @@ public class RelBuilder {
    * <p>The default implementation returns {@code true};
    * sub-classes may disable merge by overriding to return {@code false}. */
   @Experimental
+  @Deprecated
   protected boolean shouldMergeProject() {
     return true;
   }
@@ -3053,6 +3061,50 @@ public class RelBuilder {
       return new ConfigBuilder(this);
     }
 
+    /** Controls whether to merge two {@link Project} operators when inlining
+     * expressions causes complexity to increase.
+     *
+     * <p>Usually merging projects is beneficial, but occasionally the
+     * result is more complex than the original projects. Consider:
+     *
+     * <pre>
+     * P: Project(a+b+c AS x, d+e+f AS y, g+h+i AS z)  # complexity 15
+     * Q: Project(x*y*z AS p, x-y-z AS q)              # complexity 10
+     * R: Project((a+b+c)*(d+e+f)*(g+h+i) AS s,
+     *            (a+b+c)-(d+e+f)-(g+h+i) AS t)        # complexity 34
+     * </pre>
+     *
+     * <p>The complexity of an expression is the number of nodes (leaves and
+     * operators). For example, {@code a+b+c} has complexity 5 (3 field
+     * references and 2 calls):
+     *
+     * <pre>
+     *       +
+     *      /  \
+     *     +    c
+     *    / \
+     *   a   b
+     * </pre>
+     *
+     * <p>A negative value never allows merges.
+     *
+     * <p>A zero or positive value, {@code bloat}, allows a merge if complexity
+     * of the result is less than or equal to the sum of the complexity of the
+     * originals plus {@code bloat}.
+     *
+     * <p>The default value, 100, allows a moderate increase in complexity but
+     * prevents cases where complexity would run away into the millions and run
+     * out of memory. Moderate complexity is OK; the implementation, say via
+     * {@link org.apache.calcite.adapter.enumerable.EnumerableCalc}, will often
+     * gather common sub-expressions and compute them only once.
+     */
+    @ImmutableBeans.Property
+    @ImmutableBeans.IntDefault(100)
+    int bloat();
+
+    /** Sets {@link #bloat}. */
+    Config withBloat(int bloat);
+
     /** Whether {@link RelBuilder#aggregate} should eliminate duplicate
      * aggregate calls; default true. */
     @ImmutableBeans.Property
diff --git a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
index 6f4f77e..09989c1 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexProgramTest.java
@@ -738,7 +738,7 @@ public class RexProgramTest extends RexProgramTestBase {
               rexBuilder.makeFieldAccess(range3, i * 2 + 1)));
     }
     final RexNode cnf = RexUtil.toCnf(rexBuilder, or(list));
-    final int nodeCount = nodeCount(cnf);
+    final int nodeCount = cnf.nodeCount();
     assertThat((n + 1) * (int) Math.pow(2, n) + 1, equalTo(nodeCount));
     if (n == 3) {
       assertThat(cnf.toStringRaw(),
diff --git a/core/src/test/java/org/apache/calcite/rex/RexProgramTestBase.java b/core/src/test/java/org/apache/calcite/rex/RexProgramTestBase.java
index dc4ae80..9a542ed 100644
--- a/core/src/test/java/org/apache/calcite/rex/RexProgramTestBase.java
+++ b/core/src/test/java/org/apache/calcite/rex/RexProgramTestBase.java
@@ -171,17 +171,6 @@ public class RexProgramTestBase extends RexProgramBuilderBase {
         is(expected ? "true" : "false"));
   }
 
-  /** Returns the number of nodes (including leaves) in a Rex tree. */
-  protected static int nodeCount(RexNode node) {
-    int n = 1;
-    if (node instanceof RexCall) {
-      for (RexNode operand : ((RexCall) node).getOperands()) {
-        n += nodeCount(operand);
-      }
-    }
-    return n;
-  }
-
   protected Comparable eval(RexNode e) {
     return RexInterpreter.evaluate(e, ImmutableMap.of());
   }
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 75cde59..aa9e09d 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -75,6 +75,7 @@ import java.sql.DriverManager;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -757,6 +758,85 @@ public class RelBuilderTest {
             + "  LogicalValues(tuples=[[]])\n");
   }
 
+  @Test public void testProjectBloat() {
+    final Function<RelBuilder, RelNode> f = b ->
+        b.scan("EMP")
+            .project(
+                b.alias(
+                    caseCall(b, b.field("DEPTNO"),
+                        b.literal(0), b.literal("zero"),
+                        b.literal(1), b.literal("one"),
+                        b.literal(2), b.literal("two"),
+                        b.literal("other")),
+                    "v"))
+            .project(
+                b.call(SqlStdOperatorTable.PLUS, b.field("v"), b.field("v")))
+        .build();
+    // Complexity of bottom is 14; top is 3; merged is 29; difference is 12.
+    // So, we merge if bloat is 20 or 100 (the default),
+    // but not if it is -1, 0 or 10.
+    final String expected = "LogicalProject($f0=[+"
+        + "(CASE(=($7, 0), 'zero', =($7, 1), 'one', =($7, 2), 'two', 'other'),"
+        + " CASE(=($7, 0), 'zero', =($7, 1), 'one', =($7, 2), 'two', 'other'))])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    final String expectedNeg = "LogicalProject($f0=[+($0, $0)])\n"
+        + "  LogicalProject(v=[CASE(=($7, 0), 'zero', =($7, 1), "
+        + "'one', =($7, 2), 'two', 'other')])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(0))),
+        hasTree(expectedNeg));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(-1))),
+        hasTree(expectedNeg));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(10))),
+        hasTree(expectedNeg));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(20))),
+        hasTree(expected));
+  }
+
+  @Test public void testProjectBloat2() {
+    final Function<RelBuilder, RelNode> f = b ->
+        b.scan("EMP")
+            .project(
+                b.field("DEPTNO"),
+                b.field("SAL"),
+                b.alias(
+                    b.call(SqlStdOperatorTable.PLUS, b.field("DEPTNO"),
+                        b.field("EMPNO")), "PLUS"))
+            .project(
+                b.call(SqlStdOperatorTable.MULTIPLY, b.field("SAL"),
+                    b.field("PLUS")),
+                b.field("SAL"))
+        .build();
+    // Complexity of bottom is 5; top is 4; merged is 6; difference is -3.
+    // So, we merge except when bloat is -1.
+    final String expected = "LogicalProject($f0=[*($5, +($7, $0))], SAL=[$5])\n"
+        + "  LogicalTableScan(table=[[scott, EMP]])\n";
+    final String expectedNeg = "LogicalProject($f0=[*($1, $2)], SAL=[$1])\n"
+        + "  LogicalProject(DEPTNO=[$7], SAL=[$5], PLUS=[+($7, $0)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(0))),
+        hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(-1))),
+        hasTree(expectedNeg));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(10))),
+        hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withBloat(20))),
+        hasTree(expected));
+  }
+
+  private RexNode caseCall(RelBuilder b, RexNode ref, RexNode... nodes) {
+    final List<RexNode> list = new ArrayList<>();
+    for (int i = 0; i + 1 < nodes.length; i += 2) {
+      list.add(b.equals(ref, nodes[i]));
+      list.add(nodes[i + 1]);
+    }
+    list.add(nodes.length % 2 == 1 ? nodes[nodes.length - 1]
+        : b.literal(null));
+    return b.call(SqlStdOperatorTable.CASE, list);
+  }
+
   @Test public void testRename() {
     final RelBuilder builder = RelBuilder.create(config().build());
 
diff --git a/core/src/test/java/org/apache/calcite/tools/PlannerTest.java b/core/src/test/java/org/apache/calcite/tools/PlannerTest.java
index c310594..f78d61e 100644
--- a/core/src/test/java/org/apache/calcite/tools/PlannerTest.java
+++ b/core/src/test/java/org/apache/calcite/tools/PlannerTest.java
@@ -1377,7 +1377,7 @@ public class PlannerTest {
   @Test public void testMergeProjectForceMode() throws Exception {
     RuleSet ruleSet =
         RuleSets.ofList(
-            new ProjectMergeRule(true,
+            new ProjectMergeRule(true, ProjectMergeRule.DEFAULT_BLOAT,
                 RelBuilder.proto(RelFactories.DEFAULT_PROJECT_FACTORY)));
     Planner planner = getPlanner(null, Programs.of(ruleSet));
     planner.close();
diff --git a/piglet/src/main/java/org/apache/calcite/piglet/PigRelBuilder.java b/piglet/src/main/java/org/apache/calcite/piglet/PigRelBuilder.java
index 96150af..b8337eb 100644
--- a/piglet/src/main/java/org/apache/calcite/piglet/PigRelBuilder.java
+++ b/piglet/src/main/java/org/apache/calcite/piglet/PigRelBuilder.java
@@ -17,6 +17,7 @@
 package org.apache.calcite.piglet;
 
 import org.apache.calcite.plan.Context;
+import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.Convention;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptSchema;
@@ -56,6 +57,7 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.function.UnaryOperator;
 
 /**
  * Extension to {@link RelBuilder} for Pig logical operators.
@@ -79,10 +81,19 @@ public class PigRelBuilder extends RelBuilder {
   public static PigRelBuilder create(FrameworkConfig config) {
     final RelBuilder relBuilder = RelBuilder.create(config);
     Hook.REL_BUILDER_SIMPLIFY.addThread(Hook.propertyJ(false));
-    return new PigRelBuilder(config.getContext(), relBuilder.getCluster(),
+    return new PigRelBuilder(
+        transform(config.getContext(), c -> c.withBloat(-1)),
+        relBuilder.getCluster(),
         relBuilder.getRelOptSchema());
   }
 
+  private static Context transform(Context context,
+      UnaryOperator<RelBuilder.Config> transform) {
+    final Config config =
+        Util.first(context.unwrap(Config.class), Config.DEFAULT);
+    return Contexts.of(transform.apply(config), context);
+  }
+
   public RelNode getRel(String alias) {
     return aliasMap.get(alias);
   }
@@ -108,10 +119,6 @@ public class PigRelBuilder extends RelBuilder {
     return new CorrelationId(nextCorrelId++);
   }
 
-  @Override protected boolean shouldMergeProject() {
-    return false;
-  }
-
   public String getAlias() {
     final RelNode input = peek();
     if (reverseAliasMap.containsKey(input)) {


[calcite] 02/03: [CALCITE-3763] RelBuilder.aggregate should prune unused fields from the input, if the input is a Project

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit ceb972952739929c175dfd0895407e8e17e0b502
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Fri Jan 31 16:57:31 2020 -0800

    [CALCITE-3763] RelBuilder.aggregate should prune unused fields from the input, if the input is a Project
---
 .../java/org/apache/calcite/plan/RelOptUtil.java   |  13 +-
 .../rel/rules/AbstractMaterializedViewRule.java    |   2 +-
 .../java/org/apache/calcite/tools/RelBuilder.java  |  72 +++++++++-
 .../java/org/apache/calcite/test/JdbcTest.java     |  11 +-
 .../org/apache/calcite/test/PigRelBuilderTest.java |  34 +++--
 .../org/apache/calcite/test/RelBuilderTest.java    | 160 ++++++++++++++++++---
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 116 ++++++++-------
 .../java/org/apache/calcite/test/PigRelOpTest.java |   4 +-
 .../java/org/apache/calcite/test/PigletTest.java   |   4 +-
 9 files changed, 309 insertions(+), 107 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
index 0183486..99469d6 100644
--- a/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
+++ b/core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
@@ -627,7 +627,8 @@ public abstract class RelOptUtil {
    * Creates a plan suitable for use in <code>EXISTS</code> or <code>IN</code>
    * statements.
    *
-   * @see org.apache.calcite.sql2rel.SqlToRelConverter#convertExists
+   * @see org.apache.calcite.sql2rel.SqlToRelConverter
+   * SqlToRelConverter#convertExists
    *
    * @param seekRel    A query rel, for example the resulting rel from 'select *
    *                   from emp' or 'values (1,2,3)' or '('Foo', 34)'.
@@ -898,9 +899,15 @@ public abstract class RelOptUtil {
 
   /** Gets all fields in an aggregate. */
   public static Set<Integer> getAllFields(Aggregate aggregate) {
+    return getAllFields2(aggregate.getGroupSet(), aggregate.getAggCallList());
+  }
+
+  /** Gets all fields in an aggregate. */
+  public static Set<Integer> getAllFields2(ImmutableBitSet groupSet,
+      List<AggregateCall> aggCallList) {
     final Set<Integer> allFields = new TreeSet<>();
-    allFields.addAll(aggregate.getGroupSet().asList());
-    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+    allFields.addAll(groupSet.asList());
+    for (AggregateCall aggregateCall : aggCallList) {
       allFields.addAll(aggregateCall.getArgList());
       if (aggregateCall.filterArg >= 0) {
         allFields.add(aggregateCall.filterArg);
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
index cd24090..80eb1ff 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AbstractMaterializedViewRule.java
@@ -506,7 +506,7 @@ public abstract class AbstractMaterializedViewRule extends RelOptRule {
             // Then, we trigger the unifying method. This method will either create a
             // Project or an Aggregate operator on top of the view. It will also compute
             // the output expressions for the query.
-            RelBuilder builder = call.builder();
+            RelBuilder builder = call.builder().transform(c -> c.withPruneInputOfAggregate(false));
             RelNode viewWithFilter;
             if (!viewCompensationPred.isAlwaysTrue()) {
               RexNode newPred =
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index e4ef056..800f6d9 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -110,6 +110,7 @@ import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Deque;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
@@ -1584,7 +1585,7 @@ public class RelBuilder {
     final Registrar registrar =
         new Registrar(fields(), peek().getRowType().getFieldNames());
     final GroupKeyImpl groupKey_ = (GroupKeyImpl) groupKey;
-    final ImmutableBitSet groupSet =
+    ImmutableBitSet groupSet =
         ImmutableBitSet.of(registrar.registerExpressions(groupKey_.nodes));
   label:
     if (Iterables.isEmpty(aggCalls)) {
@@ -1610,7 +1611,7 @@ public class RelBuilder {
         return project(fields(groupSet));
       }
     }
-    final ImmutableList<ImmutableBitSet> groupSets;
+    ImmutableList<ImmutableBitSet> groupSets;
     if (groupKey_.nodeLists != null) {
       final int sizeBefore = registrar.extraNodes.size();
       final SortedSet<ImmutableBitSet> groupSetSet =
@@ -1646,7 +1647,7 @@ public class RelBuilder {
     project(registrar.extraNodes);
     rename(registrar.names);
     final Frame frame = stack.pop();
-    final RelNode r = frame.rel;
+    RelNode r = frame.rel;
     final List<AggregateCall> aggregateCalls = new ArrayList<>();
     for (AggCall aggCall : aggCalls) {
       final AggregateCall aggregateCall;
@@ -1685,6 +1686,49 @@ public class RelBuilder {
       assert groupSet.contains(set);
     }
 
+    if (config.pruneInputOfAggregate()
+        && r instanceof Project) {
+      final Set<Integer> fieldsUsed =
+          RelOptUtil.getAllFields2(groupSet, aggregateCalls);
+      // If no fields are used, remove the Project and aggregate its input
+      // directly; some parts of the system can't handle zero-field rows.
+      if (fieldsUsed.isEmpty()) {
+        r = ((Project) r).getInput();
+      } else if (fieldsUsed.size() < r.getRowType().getFieldCount()) {
+        // Some fields are computed but not used. Prune them.
+        final Map<Integer, Integer> map = new HashMap<>();
+        for (int source : fieldsUsed) {
+          map.put(source, map.size());
+        }
+
+        groupSet = groupSet.permute(map);
+        groupSets =
+            ImmutableBitSet.ORDERING.immutableSortedCopy(
+                ImmutableBitSet.permute(groupSets, map));
+
+        final Mappings.TargetMapping targetMapping =
+            Mappings.target(map, r.getRowType().getFieldCount(),
+                fieldsUsed.size());
+        final List<AggregateCall> oldAggregateCalls =
+            new ArrayList<>(aggregateCalls);
+        aggregateCalls.clear();
+        for (AggregateCall aggregateCall : oldAggregateCalls) {
+          aggregateCalls.add(aggregateCall.transform(targetMapping));
+        }
+
+        final Project project = (Project) r;
+        final List<RexNode> newProjects = new ArrayList<>();
+        final RelDataTypeFactory.Builder builder =
+            cluster.getTypeFactory().builder();
+        for (int i : fieldsUsed) {
+          newProjects.add(project.getProjects().get(i));
+          builder.add(project.getRowType().getFieldList().get(i));
+        }
+        r = project.copy(r.getTraitSet(), project.getInput(), newProjects,
+            builder.build());
+      }
+    }
+
     if (!config.dedupAggregateCalls() || Util.isDistinct(aggregateCalls)) {
       return aggregate_(groupSet, groupSets, r, aggregateCalls,
           registrar.extraNodes, frame.fields);
@@ -2724,10 +2768,17 @@ public class RelBuilder {
       if (distinct) {
         b.append("DISTINCT ");
       }
-      b.append(operands)
-          .append(')');
+      final int iMax = operands.size() - 1;
+      for (int i = 0; ; i++) {
+        b.append(operands.get(i));
+        if (i == iMax) {
+          break;
+        }
+        b.append(", ");
+      }
+      b.append(')');
       if (filter != null) {
-        b.append(" FILTER (WHERE" + filter + ")");
+        b.append(" FILTER (WHERE ").append(filter).append(')');
       }
       return b.toString();
     }
@@ -3011,6 +3062,15 @@ public class RelBuilder {
     /** Sets {@link #dedupAggregateCalls}. */
     Config withDedupAggregateCalls(boolean dedupAggregateCalls);
 
+    /** Whether {@link RelBuilder#aggregate} should prune unused input
+     * columns when its input is a {@link Project}; default true. */
+    @ImmutableBeans.Property
+    @ImmutableBeans.BooleanDefault(true)
+    boolean pruneInputOfAggregate();
+
+    /** Sets {@link #pruneInputOfAggregate}. */
+    Config withPruneInputOfAggregate(boolean pruneInputOfAggregate);
+
     /** Whether to simplify expressions; default true. */
     @ImmutableBeans.Property
     @ImmutableBeans.BooleanDefault(true)
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index c00fe13..7605703 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -3428,12 +3428,11 @@ public class JdbcTest {
       CalciteAssert.hr()
           .query("select count(*) c from \"hr\".\"emps\", \"hr\".\"depts\"")
           .convertContains("LogicalAggregate(group=[{}], C=[COUNT()])\n"
-              + "  LogicalProject(DUMMY=[0])\n"
-              + "    LogicalJoin(condition=[true], joinType=[inner])\n"
-              + "      LogicalProject(DUMMY=[0])\n"
-              + "        EnumerableTableScan(table=[[hr, emps]])\n"
-              + "      LogicalProject(DUMMY=[0])\n"
-              + "        EnumerableTableScan(table=[[hr, depts]])");
+              + "  LogicalJoin(condition=[true], joinType=[inner])\n"
+              + "    LogicalProject(DUMMY=[0])\n"
+              + "      EnumerableTableScan(table=[[hr, emps]])\n"
+              + "    LogicalProject(DUMMY=[0])\n"
+              + "      EnumerableTableScan(table=[[hr, depts]])");
     }
   }
 
diff --git a/core/src/test/java/org/apache/calcite/test/PigRelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/PigRelBuilderTest.java
index 43b6bbe..784c45c 100644
--- a/core/src/test/java/org/apache/calcite/test/PigRelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/PigRelBuilderTest.java
@@ -26,6 +26,7 @@ import org.apache.calcite.util.Util;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.function.Function;
 import java.util.function.UnaryOperator;
 
 import static org.hamcrest.CoreMatchers.is;
@@ -108,16 +109,26 @@ public class PigRelBuilderTest {
     //     [PARTITION BY partitioner] [PARALLEL n];
     // Equivalent to Pig Latin:
     //   r = GROUP e BY (deptno, job);
-    final PigRelBuilder builder = PigRelBuilder.create(config().build());
-    final RelNode root = builder
-        .scan("EMP")
-        .group(null, null, -1, builder.groupKey("DEPTNO", "JOB").alias("e"))
-        .build();
+    final Function<PigRelBuilder, RelNode> f = builder ->
+        builder.scan("EMP")
+            .group(null, null, -1, builder.groupKey("DEPTNO", "JOB").alias("e"))
+            .build();
     final String plan = ""
+        + "LogicalAggregate(group=[{0, 1}], EMP=[COLLECT($2)])\n"
+        + "  LogicalProject(JOB=[$2], DEPTNO=[$7], "
+        + "$f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(str(f.apply(createBuilder(b -> b))), is(plan));
+
+    // now without pruning
+    final String plan2 = ""
         + "LogicalAggregate(group=[{2, 7}], EMP=[COLLECT($8)])\n"
-        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
+        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], "
+        + "HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(str(root), is(plan));
+    assertThat(
+        str(f.apply(createBuilder(b -> b.withPruneInputOfAggregate(false)))),
+        is(plan2));
   }
 
   @Test public void testGroup2() {
@@ -132,10 +143,11 @@ public class PigRelBuilderTest {
             builder.groupKey("DEPTNO").alias("d"))
         .build();
     final String plan = "LogicalJoin(condition=[=($0, $2)], joinType=[inner])\n"
-        + "  LogicalAggregate(group=[{0}], EMP=[COLLECT($8)])\n"
-        + "    LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
-        + "      LogicalTableScan(table=[[scott, EMP]])\n  LogicalAggregate(group=[{0}], DEPT=[COLLECT($3)])\n"
-        + "    LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], $f3=[ROW($0, $1, $2)])\n"
+        + "  LogicalAggregate(group=[{0}], EMP=[COLLECT($1)])\n"
+        + "    LogicalProject(EMPNO=[$0], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n"
+        + "  LogicalAggregate(group=[{0}], DEPT=[COLLECT($1)])\n"
+        + "    LogicalProject(DEPTNO=[$0], $f3=[ROW($0, $1, $2)])\n"
         + "      LogicalTableScan(table=[[scott, DEPT]])\n";
     assertThat(str(root), is(plan));
   }
diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
index 4e546d3..75cde59 100644
--- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java
@@ -81,6 +81,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.NoSuchElementException;
 import java.util.TreeSet;
+import java.util.function.Function;
 import java.util.function.UnaryOperator;
 
 import static org.apache.calcite.test.Matchers.hasHints;
@@ -901,8 +902,7 @@ public class RelBuilderTest {
     //   SELECT COUNT(*) AS c, SUM(mgr + 1) AS s
     //   FROM emp
     //   GROUP BY ename, hiredate + mgr
-    final RelBuilder builder = RelBuilder.create(config().build());
-    RelNode root =
+    final Function<RelBuilder, RelNode> f = builder ->
         builder.scan("EMP")
             .aggregate(
                 builder.groupKey(builder.field(1),
@@ -916,10 +916,20 @@ public class RelBuilderTest {
                         builder.literal(1))).as("S"))
             .build();
     final String expected = ""
+        + "LogicalAggregate(group=[{0, 1}], C=[COUNT()], S=[SUM($2)])\n"
+        + "  LogicalProject(ENAME=[$1], $f8=[+($4, $3)], $f9=[+($3, 1)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now without pruning
+    final String expected2 = ""
         + "LogicalAggregate(group=[{1, 8}], C=[COUNT()], S=[SUM($9)])\n"
-        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($4, $3)], $f9=[+($3, 1)])\n"
+        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], "
+        + "HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[+($4, $3)], "
+        + "$f9=[+($3, 1)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(root, hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
   }
 
   /** Test case for
@@ -1080,8 +1090,7 @@ public class RelBuilderTest {
     //   SELECT deptno, COUNT(*) FILTER (WHERE empno > 100) AS c
     //   FROM emp
     //   GROUP BY ROLLUP(deptno)
-    final RelBuilder builder = RelBuilder.create(config().build());
-    RelNode root =
+    final Function<RelBuilder, RelNode> f = builder ->
         builder.scan("EMP")
             .aggregate(
                 builder.groupKey(ImmutableBitSet.of(7),
@@ -1095,10 +1104,19 @@ public class RelBuilderTest {
                     .as("C"))
             .build();
     final String expected = ""
+        + "LogicalAggregate(group=[{0}], groups=[[{0}, {}]], C=[COUNT() FILTER $1])\n"
+        + "  LogicalProject(DEPTNO=[$7], $f8=[>($0, 100)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now without pruning
+    final String expected2 = ""
         + "LogicalAggregate(group=[{7}], groups=[[{7}, {}]], C=[COUNT() FILTER $8])\n"
-        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[>($0, 100)])\n"
+        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], "
+        + "HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[>($0, 100)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(root, hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
   }
 
   @Test public void testAggregateFilterFails() {
@@ -1128,8 +1146,7 @@ public class RelBuilderTest {
     //   SELECT deptno, SUM(sal) FILTER (WHERE comm < 100) AS c
     //   FROM emp
     //   GROUP BY deptno
-    final RelBuilder builder = RelBuilder.create(config().build());
-    RelNode root =
+    final Function<RelBuilder, RelNode> f = builder ->
         builder.scan("EMP")
             .aggregate(
                 builder.groupKey(builder.field("DEPTNO")),
@@ -1140,10 +1157,18 @@ public class RelBuilderTest {
                     .as("C"))
             .build();
     final String expected = ""
+        + "LogicalAggregate(group=[{1}], C=[SUM($0) FILTER $2])\n"
+        + "  LogicalProject(SAL=[$5], DEPTNO=[$7], $f8=[IS TRUE(<($6, 100))])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now without pruning
+    final String expected2 = ""
         + "LogicalAggregate(group=[{7}], C=[SUM($5) FILTER $8])\n"
         + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[IS TRUE(<($6, 100))])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(root, hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
   }
 
   /** Test case for
@@ -1169,8 +1194,7 @@ public class RelBuilderTest {
   }
 
   @Test public void testAggregateProjectWithExpression() {
-    final RelBuilder builder = RelBuilder.create(config().build());
-    RelNode root =
+    final Function<RelBuilder, RelNode> f = builder ->
         builder.scan("EMP")
             .project(builder.field("DEPTNO"))
             .aggregate(
@@ -1181,10 +1205,105 @@ public class RelBuilderTest {
                         "d3")))
             .build();
     final String expected = ""
+        + "LogicalAggregate(group=[{0}])\n"
+        + "  LogicalProject(d3=[+($7, 3)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now without pruning
+    final String expected2 = ""
         + "LogicalAggregate(group=[{1}])\n"
         + "  LogicalProject(DEPTNO=[$7], d3=[+($7, 3)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(root, hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
+  }
+
+  /** Tests that {@link RelBuilder#aggregate} on top of a {@link Project} prunes
+   * away expressions that are not used.
+   *
+   * @see RelBuilder.Config#pruneInputOfAggregate */
+  @Test public void testAggregateProjectPrune() {
+    // SELECT deptno, SUM(sal) FILTER (WHERE job = 'CLERK')
+    // FROM (
+    //   SELECT deptno, empno + 10, sal, job
+    //   FROM emp)
+    // GROUP BY deptno
+    //   -->
+    // SELECT deptno, SUM(sal) FILTER (WHERE b)
+    // FROM (
+    //   SELECT deptno, sal, job = 'CLERK' AS b
+    //   FROM emp)
+    // GROUP BY deptno
+    final Function<RelBuilder, RelNode> f = builder ->
+        builder.scan("EMP")
+            .project(builder.field("DEPTNO"),
+                builder.call(SqlStdOperatorTable.PLUS,
+                    builder.field("EMPNO"), builder.literal(10)),
+                builder.field("SAL"),
+                builder.field("JOB"))
+            .aggregate(
+                builder.groupKey(builder.field("DEPTNO")),
+                    builder.sum(builder.field("SAL"))
+                .filter(
+                    builder.call(SqlStdOperatorTable.EQUALS,
+                        builder.field("JOB"), builder.literal("CLERK"))))
+            .build();
+    final String expected = ""
+        + "LogicalAggregate(group=[{0}], agg#0=[SUM($1) FILTER $2])\n"
+        + "  LogicalProject(DEPTNO=[$7], SAL=[$5], $f4=[IS TRUE(=($2, 'CLERK'))])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)),
+        hasTree(expected));
+
+    // now with pruning disabled
+    final String expected2 = ""
+        + "LogicalAggregate(group=[{0}], agg#0=[SUM($2) FILTER $4])\n"
+        + "  LogicalProject(DEPTNO=[$7], $f1=[+($0, 10)], SAL=[$5], JOB=[$2], "
+        + "$f4=[IS TRUE(=($2, 'CLERK'))])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
+  }
+
+  /** Tests that (a) if the input is a project and no fields are used
+   * we remove the project (rather than projecting zero fields, which
+   * would be wrong), and (b) if the same aggregate function is used
+   * twice, we add a project on top. */
+  @Test public void testAggregateProjectPruneEmpty() {
+    // SELECT COUNT(*) AS C, COUNT(*) AS C2 FROM (
+    //   SELECT deptno, empno + 10, sal, job
+    //  FROM emp)
+    //   -->
+    // SELECT C, C AS C2 FROM (
+    //   SELECT COUNT(*) AS c
+    //   FROM emp)
+    final Function<RelBuilder, RelNode> f = builder ->
+        builder.scan("EMP")
+            .project(builder.field("DEPTNO"),
+                builder.call(SqlStdOperatorTable.PLUS,
+                    builder.field("EMPNO"), builder.literal(10)),
+                builder.field("SAL"),
+                builder.field("JOB"))
+            .aggregate(
+                builder.groupKey(),
+                    builder.countStar("C"),
+                    builder.countStar("C2"))
+            .build();
+    final String expected = ""
+        + "LogicalProject(C=[$0], C2=[$0])\n"
+        + "  LogicalAggregate(group=[{}], C=[COUNT()])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now with pruning disabled
+    final String expected2 = ""
+        + "LogicalProject(C=[$0], C2=[$0])\n"
+        + "  LogicalAggregate(group=[{}], C=[COUNT()])\n"
+        + "    LogicalProject(DEPTNO=[$7], $f1=[+($0, 10)], SAL=[$5], JOB=[$2])\n"
+        + "      LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
   }
 
   @Test public void testAggregateGroupingKeyOutOfRangeFails() {
@@ -1317,8 +1436,7 @@ public class RelBuilderTest {
     // but applying "select ... group by ()" to it would change the result.
     // In theory, we could omit the distinct if we know there is precisely one
     // row, but we don't currently.
-    final RelBuilder builder = RelBuilder.create(config().build());
-    RelNode root =
+    final Function<RelBuilder, RelNode> f = builder ->
         builder.scan("EMP")
             .filter(
                 builder.call(SqlStdOperatorTable.IS_NULL,
@@ -1327,10 +1445,18 @@ public class RelBuilderTest {
             .distinct()
             .build();
     final String expected = "LogicalAggregate(group=[{}])\n"
+        + "  LogicalFilter(condition=[IS NULL($6)])\n"
+        + "    LogicalTableScan(table=[[scott, EMP]])\n";
+    assertThat(f.apply(createBuilder(c -> c)), hasTree(expected));
+
+    // now without pruning
+    // (The empty LogicalProject is dubious, but it's what we've always done)
+    final String expected2 = "LogicalAggregate(group=[{}])\n"
         + "  LogicalProject\n"
         + "    LogicalFilter(condition=[IS NULL($6)])\n"
         + "      LogicalTableScan(table=[[scott, EMP]])\n";
-    assertThat(root, hasTree(expected));
+    assertThat(f.apply(createBuilder(c -> c.withPruneInputOfAggregate(false))),
+        hasTree(expected2));
   }
 
   @Test public void testUnion() {
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index d1cefda..1661066 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -40,8 +40,8 @@ LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DIST
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(SUM_SAL=[$0], COUNT_DISTINCT_CLERK=[$1], SUM_SAL_D10=[$2], SUM_SAL_D20=[$3], COUNT_D30=[CAST($4):INTEGER], COUNT_D40=[$5], COUNT_D20=[$6])
-  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $7) FILTER $8], SUM_SAL_D10=[SUM($9) FILTER $10], SUM_SAL_D20=[SUM($11) FILTER $12], COUNT_D30=[COUNT() FILTER $13], COUNT_D40=[COUNT() FILTER $14], COUNT_D20=[COUNT() FILTER $15])
-    LogicalProject(SAL=[$5], $f1=[CASE(=($2, 'CLERK'), $7, null:INTEGER)], $f2=[CASE(=($7, 10), $5, null:INTEGER)], $f3=[CASE(=($7, 20), $5, 0)], $f4=[CASE(=($7, 30), 1, 0)], $f5=[CASE(=($7, 40), 'x', null:CHAR(1))], $f6=[CASE(=($7, 20), 1, null:INTEGER)], DEPTNO=[$7], $f8=[=($2, 'CLERK')], SAL0=[$5], $f10=[=($7, 10)], SAL1=[$5], $f12=[=($7, 20)], $f13=[=($7, 30)], $f14=[=($7, 40)], $f15=[=($7, 20)])
+  LogicalAggregate(group=[{}], SUM_SAL=[SUM($0)], COUNT_DISTINCT_CLERK=[COUNT(DISTINCT $1) FILTER $2], SUM_SAL_D10=[SUM($3) FILTER $4], SUM_SAL_D20=[SUM($5) FILTER $6], COUNT_D30=[COUNT() FILTER $7], COUNT_D40=[COUNT() FILTER $8], COUNT_D20=[COUNT() FILTER $9])
+    LogicalProject(SAL=[$5], DEPTNO=[$7], $f8=[=($2, 'CLERK')], SAL0=[$5], $f10=[=($7, 10)], SAL1=[$5], $f12=[=($7, 20)], $f13=[=($7, 30)], $f14=[=($7, 40)], $f15=[=($7, 20)])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -717,7 +717,7 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
     LogicalProject(DEPTNO=[$0], $f1=[true])
       LogicalAggregate(group=[{0}])
-        LogicalProject(DEPTNO=[$7], i=[true])
+        LogicalProject(DEPTNO=[$7])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -782,11 +782,11 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
         LogicalProject(DEPTNO=[$0], $f1=[true])
           LogicalAggregate(group=[{0}])
-            LogicalProject(DEPTNO=[$7], i=[true])
+            LogicalProject(DEPTNO=[$7])
               LogicalTableScan(table=[[CATALOG, SALES, EMP]])
       LogicalProject(JOB=[$0], $f1=[true])
         LogicalAggregate(group=[{0}])
-          LogicalProject(JOB=[$2], i=[true])
+          LogicalProject(JOB=[$2])
             LogicalFilter(condition=[=($5, 34)])
               LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -843,7 +843,7 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
     LogicalProject(DEPTNO=[$0], $f1=[true])
       LogicalAggregate(group=[{0}])
-        LogicalProject(DEPTNO=[$7], i=[true])
+        LogicalProject(DEPTNO=[$7])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -1167,8 +1167,8 @@ LogicalAggregate(group=[{}], EXPR$0=[SUM($0)], EXPR$1=[COUNT(DISTINCT $1) FILTER
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[MIN($2) FILTER $4], EXPR$1=[COUNT($0) FILTER $3])
-  LogicalProject(SAL=[$0], $f2=[$1], EXPR$0=[$2], $g_0_f_1=[AND(=($3, 0), IS TRUE($1))], $g_3=[=($3, 3)])
+LogicalAggregate(group=[{}], EXPR$0=[MIN($1) FILTER $3], EXPR$1=[COUNT($0) FILTER $2])
+  LogicalProject(SAL=[$0], EXPR$0=[$2], $g_0_f_1=[AND(=($3, 0), IS TRUE($1))], $g_3=[=($3, 3)])
     LogicalAggregate(group=[{1, 2}], groups=[[{1, 2}, {}]], EXPR$0=[SUM($0)], $g=[GROUPING($1, $2)])
       LogicalProject(COMM=[$6], SAL=[$5], $f2=[>($5, 1000)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -1179,7 +1179,7 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($2) FILTER $4], EXPR$1=[COUNT($0) FILTE
         <Resource name="sql">
             <![CDATA[SELECT COUNT(DISTINCT c) FILTER (WHERE d),
 COUNT(DISTINCT d) FILTER (WHERE c)
-FROM (select sal > 1000 is true as c, sal < 500 is true as d from emp)]]>
+FROM (select sal > 1000 is true as c, sal < 500 is true as d, comm from emp)]]>
         </Resource>
         <Resource name="planBefore">
             <![CDATA[
@@ -1214,8 +1214,8 @@ LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)], EXPR$2=[COUNT(DISTINCT $2) FILTE
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(DEPTNO=[$0], EXPR$1=[CAST($1):INTEGER NOT NULL], EXPR$2=[$2])
-  LogicalAggregate(group=[{0}], EXPR$1=[MIN($3) FILTER $5], EXPR$2=[COUNT($1) FILTER $4])
-    LogicalProject(DEPTNO=[$0], SAL=[$1], $f3=[$2], EXPR$1=[$3], $g_0_f_2=[AND(=($4, 0), IS TRUE($2))], $g_3=[=($4, 3)])
+  LogicalAggregate(group=[{0}], EXPR$1=[MIN($2) FILTER $4], EXPR$2=[COUNT($1) FILTER $3])
+    LogicalProject(DEPTNO=[$0], SAL=[$1], EXPR$1=[$3], $g_0_f_2=[AND(=($4, 0), IS TRUE($2))], $g_3=[=($4, 3)])
       LogicalAggregate(group=[{0, 2, 3}], groups=[[{0, 2, 3}, {0}]], EXPR$1=[SUM($1)], $g=[GROUPING($0, $2, $3)])
         LogicalProject(DEPTNO=[$7], COMM=[$6], SAL=[$5], $f3=[>($5, 1000)])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -1489,29 +1489,27 @@ LogicalProject(EXPR$0=[1])
     </TestCase>
     <TestCase name="testPushAboveFiltersIntoInnerJoinCondition">
         <Resource name="sql">
-            <![CDATA[
-select * from sales.dept d inner join sales.emp e
+            <![CDATA[select * from sales.dept d inner join sales.emp e
 on d.deptno = e.deptno and d.deptno > e.mgr
-where d.deptno > e.mgr
-]]>
+where d.deptno > e.mgr]]>
         </Resource>
-    <Resource name="planBefore">
-        <![CDATA[
+        <Resource name="planBefore">
+            <![CDATA[
 LogicalProject(DEPTNO=[$0], NAME=[$1], EMPNO=[$2], ENAME=[$3], JOB=[$4], MGR=[$5], HIREDATE=[$6], SAL=[$7], COMM=[$8], DEPTNO0=[$9], SLACKER=[$10])
   LogicalFilter(condition=[>($0, $5)])
     LogicalJoin(condition=[AND(=($0, $9), >($0, $5))], joinType=[inner])
       LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
-    </Resource>
-    <Resource name="planAfter">
-        <![CDATA[
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
 LogicalProject(DEPTNO=[$0], NAME=[$1], EMPNO=[$2], ENAME=[$3], JOB=[$4], MGR=[$5], HIREDATE=[$6], SAL=[$7], COMM=[$8], DEPTNO0=[$9], SLACKER=[$10])
   LogicalJoin(condition=[AND(=($0, $9), >($0, $5))], joinType=[inner])
     LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
-    </Resource>
+        </Resource>
     </TestCase>
     <TestCase name="testPushFilterThroughSemiJoin">
         <Resource name="sql">
@@ -3331,8 +3329,8 @@ LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
     LogicalAggregate(group=[{0, 1}])
       LogicalProject(EXPR$0=[$2], EMPNO=[$1])
-        LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($3)])
-          LogicalProject(DEPTNO=[$7], EMPNO=[$0], $f1=['abc'], SAL=[$5])
+        LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)])
+          LogicalProject(DEPTNO=[$7], EMPNO=[$0], SAL=[$5])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4805,8 +4803,8 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$2])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{0}], EXPR$1=[MAX($2)])
-  LogicalProject(DEPTNO=[$7], FOUR=[4], MGR=[$3])
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], MGR=[$3])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4827,8 +4825,8 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$2])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{0}], EXPR$1=[MAX($2)])
-  LogicalProject(DEPTNO=[$7], FOUR=[4], ENAME=[$1])
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], ENAME=[$1])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4850,8 +4848,8 @@ LogicalProject(DEPTNO=[$0], EXPR$1=[$4])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(DEPTNO=[$0], EXPR$1=[$2])
-  LogicalAggregate(group=[{0, 3}], EXPR$1=[MAX($4)])
-    LogicalProject(DEPTNO=[$7], FOUR=[4], TWO_PLUS_THREE=[+(2, 3)], DEPTNO42=[+($7, 42)], MGR=[$3])
+  LogicalAggregate(group=[{0, 1}], EXPR$1=[MAX($2)])
+    LogicalProject(DEPTNO=[$7], DEPTNO42=[+($7, 42)], MGR=[$3])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4872,8 +4870,8 @@ LogicalProject(DEPTNO=[$1], EXPR$1=[$2])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{1}], EXPR$1=[MAX($2)])
-  LogicalProject(FOUR=[4], DEPTNO=[$7], MGR=[$3])
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], MGR=[$3])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4894,8 +4892,8 @@ LogicalProject(DEPTNO=[$1], EXPR$1=[$2])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{1}], EXPR$1=[MAX($2)])
-  LogicalProject($f0=[+(42, 24)], DEPTNO=[$7], MGR=[$3])
+LogicalAggregate(group=[{0}], EXPR$1=[MAX($1)])
+  LogicalProject(DEPTNO=[$7], MGR=[$3])
     LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4916,8 +4914,8 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], MGR=[$3])
+  LogicalAggregate(group=[{0}], EXPR$2=[MAX($1)])
+    LogicalProject(EXPR$0=[4], MGR=[$3])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4938,8 +4936,8 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], FIVE=[5])
+  LogicalAggregate(group=[{0}], EXPR$2=[MAX($1)])
+    LogicalProject(EXPR$0=[4], FIVE=[5])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -4960,8 +4958,8 @@ LogicalAggregate(group=[{0, 1}], EXPR$2=[MAX($2)])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(EXPR$0=[$0], EXPR$1=[+(2, 3)], EXPR$2=[$1])
-  LogicalAggregate(group=[{0}], EXPR$2=[MAX($2)])
-    LogicalProject(EXPR$0=[4], EXPR$1=[+(2, 3)], $f2=[5])
+  LogicalAggregate(group=[{0}], EXPR$2=[MAX($1)])
+    LogicalProject(EXPR$0=[4], $f2=[5])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
         </Resource>
@@ -6715,7 +6713,7 @@ LogicalProject(DEPTNO=[$0], NAME=[$1])
   LogicalJoin(condition=[=($0, $2)], joinType=[inner])
     LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
     LogicalAggregate(group=[{0}])
-      LogicalProject(DEPTNO=[$7], $f0=[true])
+      LogicalProject(DEPTNO=[$7])
         LogicalFilter(condition=[>($5, 100)])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -6724,7 +6722,7 @@ LogicalProject(DEPTNO=[$0], NAME=[$1])
             <![CDATA[
 LogicalJoin(condition=[=($0, $2)], joinType=[semi])
   LogicalTableScan(table=[[CATALOG, SALES, DEPT]])
-  LogicalProject(DEPTNO=[$7], $f0=[true])
+  LogicalProject(DEPTNO=[$7])
     LogicalFilter(condition=[>($5, 100)])
       LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -7837,8 +7835,8 @@ LogicalAggregate(group=[{}], EXPR$0=[SUM($9)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($3)])
-  LogicalProject(SAL=[$0], $f1=[$1], SAL0=[$2], $f3=[CAST(*($1, $2)):INTEGER])
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
+  LogicalProject($f3=[CAST(*($1, $2)):INTEGER])
     LogicalJoin(condition=[=($0, $2)], joinType=[inner])
       LogicalAggregate(group=[{5}], agg#0=[COUNT()])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -7870,8 +7868,8 @@ LogicalProject(JOB=[$0], MGR0=[$2], DEPTNO=[$1], HIREDATE1=[$3], COMM1=[$4])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(JOB=[$0], MGR0=[$2], DEPTNO=[$1], HIREDATE1=[$3], COMM1=[$4])
-  LogicalAggregate(group=[{0, 2, 4}], HIREDATE1=[MAX($6)], COMM1=[SUM($8)])
-    LogicalProject(JOB=[$0], SAL=[$1], DEPTNO=[$2], $f3=[$3], MGR=[$4], SAL0=[$5], HIREDATE1=[$6], COMM1=[$7], $f8=[CAST(*($3, $7)):INTEGER NOT NULL])
+  LogicalAggregate(group=[{0, 1, 2}], HIREDATE1=[MAX($3)], COMM1=[SUM($4)])
+    LogicalProject(JOB=[$0], DEPTNO=[$2], MGR=[$4], HIREDATE1=[$6], $f8=[CAST(*($3, $7)):INTEGER NOT NULL])
       LogicalJoin(condition=[=($1, $5)], joinType=[inner])
         LogicalAggregate(group=[{2, 5, 7}], agg#0=[COUNT()])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -7930,8 +7928,8 @@ LogicalAggregate(group=[{}], EXPR$0=[SUM($5)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[SUM($4)])
-  LogicalProject(JOB=[$0], EXPR$0=[$1], NAME=[$2], $f1=[$3], $f4=[CAST(*($1, $3)):INTEGER])
+LogicalAggregate(group=[{}], EXPR$0=[SUM($0)])
+  LogicalProject($f4=[CAST(*($1, $3)):INTEGER])
     LogicalJoin(condition=[=($0, $2)], joinType=[inner])
       LogicalAggregate(group=[{2}], EXPR$0=[SUM($5)])
         LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
@@ -7961,8 +7959,8 @@ LogicalAggregate(group=[{}], EXPR$0=[SUM($5)])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(EXPR$0=[CASE(=($1, 0), null:INTEGER, $0)])
-  LogicalAggregate(group=[{}], EXPR$0=[$SUM0($5)], agg#1=[$SUM0($6)])
-    LogicalProject(JOB=[$0], EXPR$0=[$1], $f2=[$2], NAME=[$3], $f1=[$4], $f5=[CAST(*($1, $4)):INTEGER NOT NULL], $f6=[*($2, $4)])
+  LogicalAggregate(group=[{}], EXPR$0=[$SUM0($0)], agg#1=[$SUM0($1)])
+    LogicalProject($f5=[CAST(*($1, $4)):INTEGER NOT NULL], $f6=[*($2, $4)])
       LogicalJoin(condition=[=($0, $3)], joinType=[inner])
         LogicalAggregate(group=[{2}], EXPR$0=[$SUM0($5)], agg#1=[COUNT()])
           LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
@@ -8579,8 +8577,8 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{}], EXPR$0=[$SUM0($4)])
-  LogicalProject(JOB=[$0], EXPR$0=[$1], NAME=[$2], EXPR$00=[$3], $f4=[*($1, $3)])
+LogicalAggregate(group=[{}], EXPR$0=[$SUM0($0)])
+  LogicalProject($f4=[*($1, $3)])
     LogicalJoin(condition=[=($0, $2)], joinType=[inner])
       LogicalAggregate(group=[{2}], EXPR$0=[COUNT()])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -8605,8 +8603,8 @@ LogicalAggregate(group=[{}], VOLUME=[COUNT()], C1_SUM_SAL=[SUM($0)])
         </Resource>
         <Resource name="planAfter">
             <![CDATA[
-LogicalAggregate(group=[{}], VOLUME=[$SUM0($3)], C1_SUM_SAL=[SUM($4)])
-  LogicalProject(ENAME=[$0], SAL=[$1], ENAME0=[$2], VOLUME=[$3], $f4=[CAST(*($1, $3)):INTEGER])
+LogicalAggregate(group=[{}], VOLUME=[$SUM0($0)], C1_SUM_SAL=[SUM($1)])
+  LogicalProject(VOLUME=[$3], $f4=[CAST(*($1, $3)):INTEGER])
     LogicalJoin(condition=[=($0, $2)], joinType=[inner])
       LogicalProject(ENAME=[$1], SAL=[$0])
         LogicalProject(SAL=[$5], ENAME=[$1])
@@ -10887,8 +10885,8 @@ LogicalProject(C=[$2])
             <![CDATA[
 LogicalProject(C=[$2])
   LogicalProject(DEPTNO=[10], SAL=[$0], C=[$1])
-    LogicalAggregate(group=[{1}], C=[COUNT()])
-      LogicalProject(DEPTNO=[$7], SAL=[$5])
+    LogicalAggregate(group=[{0}], C=[COUNT()])
+      LogicalProject(SAL=[$5])
         LogicalFilter(condition=[=($7, 10)])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -10944,7 +10942,7 @@ LogicalProject(JOB=[$1])
   LogicalFilter(condition=[>($2, 3)])
     LogicalProject(SAL=[$0], JOB=['Clerk':VARCHAR(10)], $f2=[$1])
       LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-        LogicalProject(SAL=[$5], JOB=[$2])
+        LogicalProject(SAL=[$5])
           LogicalFilter(condition=[AND(IS NULL($5), =($2, 'Clerk'))])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -10974,7 +10972,7 @@ LogicalProject(HIREDATE=[$1])
   LogicalFilter(condition=[>($2, 3)])
     LogicalProject(SAL=[$0], HIREDATE=[CURRENT_TIMESTAMP], $f2=[$1])
       LogicalAggregate(group=[{0}], agg#0=[COUNT()])
-        LogicalProject(SAL=[$5], HIREDATE=[$4])
+        LogicalProject(SAL=[$5])
           LogicalFilter(condition=[AND(IS NULL($5), =($4, CURRENT_TIMESTAMP))])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
@@ -11333,8 +11331,8 @@ LogicalAggregate(group=[{0, 1, 2}], S=[SUM($2)])
         <Resource name="planAfter">
             <![CDATA[
 LogicalProject(JOB=[$0], EMPNO=[10], SAL=[$1], S=[$2])
-  LogicalAggregate(group=[{0, 2}], S=[SUM($2)])
-    LogicalProject(JOB=[$2], EMPNO=[$0], SAL=[$5])
+  LogicalAggregate(group=[{0, 1}], S=[SUM($1)])
+    LogicalProject(JOB=[$2], SAL=[$5])
       LogicalFilter(condition=[=($0, 10)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
 ]]>
diff --git a/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java b/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java
index c29ff62..25ea76d 100644
--- a/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java
+++ b/piglet/src/test/java/org/apache/calcite/test/PigRelOpTest.java
@@ -1050,8 +1050,8 @@ public class PigRelOpTest extends PigRelTestBase {
         + "        LogicalTableScan(table=[[scott, DEPT]])\n";
     final String optimizedPlan = ""
         + "LogicalProject($f0=[$1])\n"
-        + "  LogicalAggregate(group=[{0}], agg#0=[COLLECT($2)])\n"
-        + "    LogicalProject(DEPTNO=[$0], DNAME=[$1], $f2=[ROW($0, $1)])\n"
+        + "  LogicalAggregate(group=[{0}], agg#0=[COLLECT($1)])\n"
+        + "    LogicalProject(DEPTNO=[$0], $f2=[ROW($0, $1)])\n"
         + "      LogicalTableScan(table=[[scott, DEPT]])\n";
     final String result = ""
         + "({(20,RESEARCH)})\n"
diff --git a/piglet/src/test/java/org/apache/calcite/test/PigletTest.java b/piglet/src/test/java/org/apache/calcite/test/PigletTest.java
index 19f53fb..83efe3b 100644
--- a/piglet/src/test/java/org/apache/calcite/test/PigletTest.java
+++ b/piglet/src/test/java/org/apache/calcite/test/PigletTest.java
@@ -135,8 +135,8 @@ public class PigletTest {
     final String s = "A = LOAD 'EMP';\n"
         + "B = GROUP A BY DEPTNO;";
     final String expected = ""
-        + "LogicalAggregate(group=[{7}], A=[COLLECT($8)])\n"
-        + "  LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
+        + "LogicalAggregate(group=[{0}], A=[COLLECT($1)])\n"
+        + "  LogicalProject(DEPTNO=[$7], $f8=[ROW($0, $1, $2, $3, $4, $5, $6, $7)])\n"
         + "    LogicalTableScan(table=[[scott, EMP]])\n";
     pig(s).explainContains(expected);
   }


[calcite] 01/03: Add RelBuilder.transform, which allows you to clone a RelBuilder with slightly different Config

Posted by jh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit 051b6919dfc5b60406e81d5e8b8d5efb263def87
Author: Julian Hyde <jh...@apache.org>
AuthorDate: Fri Feb 7 18:10:42 2020 -0800

    Add RelBuilder.transform, which allows you to clone a RelBuilder with slightly different Config
    
    Add class RelFactories.Struct, which contains an instance of
    each RelNode factory. This allows more efficient initialization
    of RelBuilder.
---
 .../org/apache/calcite/rel/core/RelFactories.java  | 132 +++++++++++++--
 .../java/org/apache/calcite/tools/RelBuilder.java  | 184 +++++++++------------
 2 files changed, 193 insertions(+), 123 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java b/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
index 13b2cfd..e077008 100644
--- a/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
+++ b/core/src/main/java/org/apache/calcite/rel/core/RelFactories.java
@@ -17,6 +17,7 @@
 package org.apache.calcite.rel.core;
 
 import org.apache.calcite.linq4j.function.Experimental;
+import org.apache.calcite.plan.Context;
 import org.apache.calcite.plan.Contexts;
 import org.apache.calcite.plan.RelOptCluster;
 import org.apache.calcite.plan.RelOptTable;
@@ -53,6 +54,7 @@ import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.Util;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
@@ -60,6 +62,7 @@ import com.google.common.collect.ImmutableSet;
 import java.lang.reflect.Type;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.SortedSet;
 import javax.annotation.Nonnull;
@@ -116,24 +119,28 @@ public class RelFactories {
   public static final RepeatUnionFactory DEFAULT_REPEAT_UNION_FACTORY =
       new RepeatUnionFactoryImpl();
 
+  public static final Struct DEFAULT_STRUCT =
+      new Struct(DEFAULT_FILTER_FACTORY,
+          DEFAULT_PROJECT_FACTORY,
+          DEFAULT_AGGREGATE_FACTORY,
+          DEFAULT_SORT_FACTORY,
+          DEFAULT_EXCHANGE_FACTORY,
+          DEFAULT_SORT_EXCHANGE_FACTORY,
+          DEFAULT_SET_OP_FACTORY,
+          DEFAULT_JOIN_FACTORY,
+          DEFAULT_CORRELATE_FACTORY,
+          DEFAULT_VALUES_FACTORY,
+          DEFAULT_TABLE_SCAN_FACTORY,
+          DEFAULT_TABLE_FUNCTION_SCAN_FACTORY,
+          DEFAULT_SNAPSHOT_FACTORY,
+          DEFAULT_MATCH_FACTORY,
+          DEFAULT_SPOOL_FACTORY,
+          DEFAULT_REPEAT_UNION_FACTORY);
+
   /** A {@link RelBuilderFactory} that creates a {@link RelBuilder} that will
    * create logical relational expressions for everything. */
   public static final RelBuilderFactory LOGICAL_BUILDER =
-      RelBuilder.proto(
-          Contexts.of(DEFAULT_PROJECT_FACTORY,
-              DEFAULT_FILTER_FACTORY,
-              DEFAULT_JOIN_FACTORY,
-              DEFAULT_SORT_FACTORY,
-              DEFAULT_EXCHANGE_FACTORY,
-              DEFAULT_SORT_EXCHANGE_FACTORY,
-              DEFAULT_AGGREGATE_FACTORY,
-              DEFAULT_MATCH_FACTORY,
-              DEFAULT_SET_OP_FACTORY,
-              DEFAULT_VALUES_FACTORY,
-              DEFAULT_TABLE_SCAN_FACTORY,
-              DEFAULT_SNAPSHOT_FACTORY,
-              DEFAULT_SPOOL_FACTORY,
-              DEFAULT_REPEAT_UNION_FACTORY));
+      RelBuilder.proto(Contexts.of(DEFAULT_STRUCT));
 
   private RelFactories() {
   }
@@ -673,4 +680,99 @@ public class RelFactories {
       return LogicalRepeatUnion.create(seed, iterative, all, iterationLimit);
     }
   }
+
+  /** Immutable record that contains an instance of each factory. */
+  public static class Struct {
+    public final FilterFactory filterFactory;
+    public final ProjectFactory projectFactory;
+    public final AggregateFactory aggregateFactory;
+    public final SortFactory sortFactory;
+    public final ExchangeFactory exchangeFactory;
+    public final SortExchangeFactory sortExchangeFactory;
+    public final SetOpFactory setOpFactory;
+    public final JoinFactory joinFactory;
+    public final CorrelateFactory correlateFactory;
+    public final ValuesFactory valuesFactory;
+    public final TableScanFactory scanFactory;
+    public final TableFunctionScanFactory tableFunctionScanFactory;
+    public final SnapshotFactory snapshotFactory;
+    public final MatchFactory matchFactory;
+    public final SpoolFactory spoolFactory;
+    public final RepeatUnionFactory repeatUnionFactory;
+
+    private Struct(FilterFactory filterFactory,
+        ProjectFactory projectFactory,
+        AggregateFactory aggregateFactory,
+        SortFactory sortFactory,
+        ExchangeFactory exchangeFactory,
+        SortExchangeFactory sortExchangeFactory,
+        SetOpFactory setOpFactory,
+        JoinFactory joinFactory,
+        CorrelateFactory correlateFactory,
+        ValuesFactory valuesFactory,
+        TableScanFactory scanFactory,
+        TableFunctionScanFactory tableFunctionScanFactory,
+        SnapshotFactory snapshotFactory,
+        MatchFactory matchFactory,
+        SpoolFactory spoolFactory,
+        RepeatUnionFactory repeatUnionFactory) {
+      this.filterFactory = Objects.requireNonNull(filterFactory);
+      this.projectFactory = Objects.requireNonNull(projectFactory);
+      this.aggregateFactory = Objects.requireNonNull(aggregateFactory);
+      this.sortFactory = Objects.requireNonNull(sortFactory);
+      this.exchangeFactory = Objects.requireNonNull(exchangeFactory);
+      this.sortExchangeFactory = Objects.requireNonNull(sortExchangeFactory);
+      this.setOpFactory = Objects.requireNonNull(setOpFactory);
+      this.joinFactory = Objects.requireNonNull(joinFactory);
+      this.correlateFactory = Objects.requireNonNull(correlateFactory);
+      this.valuesFactory = Objects.requireNonNull(valuesFactory);
+      this.scanFactory = Objects.requireNonNull(scanFactory);
+      this.tableFunctionScanFactory =
+          Objects.requireNonNull(tableFunctionScanFactory);
+      this.snapshotFactory = Objects.requireNonNull(snapshotFactory);
+      this.matchFactory = Objects.requireNonNull(matchFactory);
+      this.spoolFactory = Objects.requireNonNull(spoolFactory);
+      this.repeatUnionFactory = Objects.requireNonNull(repeatUnionFactory);
+    }
+
+    public static @Nonnull Struct fromContext(Context context) {
+      Struct struct = context.unwrap(Struct.class);
+      if (struct != null) {
+        return struct;
+      }
+      return new Struct(
+          Util.first(context.unwrap(FilterFactory.class),
+              DEFAULT_FILTER_FACTORY),
+          Util.first(context.unwrap(ProjectFactory.class),
+              DEFAULT_PROJECT_FACTORY),
+          Util.first(context.unwrap(AggregateFactory.class),
+              DEFAULT_AGGREGATE_FACTORY),
+          Util.first(context.unwrap(SortFactory.class),
+              DEFAULT_SORT_FACTORY),
+          Util.first(context.unwrap(ExchangeFactory.class),
+              DEFAULT_EXCHANGE_FACTORY),
+          Util.first(context.unwrap(SortExchangeFactory.class),
+              DEFAULT_SORT_EXCHANGE_FACTORY),
+          Util.first(context.unwrap(SetOpFactory.class),
+              DEFAULT_SET_OP_FACTORY),
+          Util.first(context.unwrap(JoinFactory.class),
+              DEFAULT_JOIN_FACTORY),
+          Util.first(context.unwrap(CorrelateFactory.class),
+              DEFAULT_CORRELATE_FACTORY),
+          Util.first(context.unwrap(ValuesFactory.class),
+              DEFAULT_VALUES_FACTORY),
+          Util.first(context.unwrap(TableScanFactory.class),
+              DEFAULT_TABLE_SCAN_FACTORY),
+          Util.first(context.unwrap(TableFunctionScanFactory.class),
+              DEFAULT_TABLE_FUNCTION_SCAN_FACTORY),
+          Util.first(context.unwrap(SnapshotFactory.class),
+              DEFAULT_SNAPSHOT_FACTORY),
+          Util.first(context.unwrap(MatchFactory.class),
+              DEFAULT_MATCH_FACTORY),
+          Util.first(context.unwrap(SpoolFactory.class),
+              DEFAULT_SPOOL_FACTORY),
+          Util.first(context.unwrap(RepeatUnionFactory.class),
+              DEFAULT_REPEAT_UNION_FACTORY));
+    }
+  }
 }
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index b4455a0..e4ef056 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -119,6 +119,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
+import java.util.function.UnaryOperator;
 import java.util.stream.Collectors;
 import javax.annotation.Nonnull;
 
@@ -143,25 +144,10 @@ import static org.apache.calcite.util.Static.RESOURCE;
 public class RelBuilder {
   protected final RelOptCluster cluster;
   protected final RelOptSchema relOptSchema;
-  private final RelFactories.FilterFactory filterFactory;
-  private final RelFactories.ProjectFactory projectFactory;
-  private final RelFactories.AggregateFactory aggregateFactory;
-  private final RelFactories.SortFactory sortFactory;
-  private final RelFactories.ExchangeFactory exchangeFactory;
-  private final RelFactories.SortExchangeFactory sortExchangeFactory;
-  private final RelFactories.SetOpFactory setOpFactory;
-  private final RelFactories.JoinFactory joinFactory;
-  private final RelFactories.CorrelateFactory correlateFactory;
-  private final RelFactories.ValuesFactory valuesFactory;
-  private final RelFactories.TableScanFactory scanFactory;
-  private final RelFactories.TableFunctionScanFactory tableFunctionScanFactory;
-  private final RelFactories.SnapshotFactory snapshotFactory;
-  private final RelFactories.MatchFactory matchFactory;
-  private final RelFactories.SpoolFactory spoolFactory;
-  private final RelFactories.RepeatUnionFactory repeatUnionFactory;
   private final Deque<Frame> stack = new ArrayDeque<>();
   private final RexSimplify simplifier;
   private final Config config;
+  private final RelFactories.Struct struct;
 
   protected RelBuilder(Context context, RelOptCluster cluster,
       RelOptSchema relOptSchema) {
@@ -171,54 +157,8 @@ public class RelBuilder {
       context = Contexts.EMPTY_CONTEXT;
     }
     this.config = getConfig(context);
-    this.aggregateFactory =
-        Util.first(context.unwrap(RelFactories.AggregateFactory.class),
-            RelFactories.DEFAULT_AGGREGATE_FACTORY);
-    this.filterFactory =
-        Util.first(context.unwrap(RelFactories.FilterFactory.class),
-            RelFactories.DEFAULT_FILTER_FACTORY);
-    this.projectFactory =
-        Util.first(context.unwrap(RelFactories.ProjectFactory.class),
-            RelFactories.DEFAULT_PROJECT_FACTORY);
-    this.sortFactory =
-        Util.first(context.unwrap(RelFactories.SortFactory.class),
-            RelFactories.DEFAULT_SORT_FACTORY);
-    this.exchangeFactory =
-        Util.first(context.unwrap(RelFactories.ExchangeFactory.class),
-            RelFactories.DEFAULT_EXCHANGE_FACTORY);
-    this.sortExchangeFactory =
-        Util.first(context.unwrap(RelFactories.SortExchangeFactory.class),
-            RelFactories.DEFAULT_SORT_EXCHANGE_FACTORY);
-    this.setOpFactory =
-        Util.first(context.unwrap(RelFactories.SetOpFactory.class),
-            RelFactories.DEFAULT_SET_OP_FACTORY);
-    this.joinFactory =
-        Util.first(context.unwrap(RelFactories.JoinFactory.class),
-            RelFactories.DEFAULT_JOIN_FACTORY);
-    this.correlateFactory =
-        Util.first(context.unwrap(RelFactories.CorrelateFactory.class),
-            RelFactories.DEFAULT_CORRELATE_FACTORY);
-    this.valuesFactory =
-        Util.first(context.unwrap(RelFactories.ValuesFactory.class),
-            RelFactories.DEFAULT_VALUES_FACTORY);
-    this.scanFactory =
-        Util.first(context.unwrap(RelFactories.TableScanFactory.class),
-            RelFactories.DEFAULT_TABLE_SCAN_FACTORY);
-    this.tableFunctionScanFactory =
-        Util.first(context.unwrap(RelFactories.TableFunctionScanFactory.class),
-            RelFactories.DEFAULT_TABLE_FUNCTION_SCAN_FACTORY);
-    this.snapshotFactory =
-        Util.first(context.unwrap(RelFactories.SnapshotFactory.class),
-            RelFactories.DEFAULT_SNAPSHOT_FACTORY);
-    this.matchFactory =
-        Util.first(context.unwrap(RelFactories.MatchFactory.class),
-            RelFactories.DEFAULT_MATCH_FACTORY);
-    this.spoolFactory =
-        Util.first(context.unwrap(RelFactories.SpoolFactory.class),
-            RelFactories.DEFAULT_SPOOL_FACTORY);
-    this.repeatUnionFactory =
-        Util.first(context.unwrap(RelFactories.RepeatUnionFactory.class),
-            RelFactories.DEFAULT_REPEAT_UNION_FACTORY);
+    this.struct =
+        Objects.requireNonNull(RelFactories.Struct.fromContext(context));
     final RexExecutor executor =
         Util.first(context.unwrap(RexExecutor.class),
             Util.first(cluster.getPlanner().getExecutor(), RexUtil.EXECUTOR));
@@ -246,6 +186,14 @@ public class RelBuilder {
             new RelBuilder(config.getContext(), cluster, relOptSchema));
   }
 
+  /** Creates a copy of this RelBuilder, with the same state as this, applying
+   * a transform to the config. */
+  public RelBuilder transform(UnaryOperator<Config> transform) {
+    final Context context =
+        Contexts.of(struct, transform.apply(config));
+    return new RelBuilder(context, cluster, relOptSchema);
+  }
+
   /** Converts this RelBuilder to a string.
    * The string is the string representation of all of the RelNodes on the stack. */
   @Override public String toString() {
@@ -284,7 +232,7 @@ public class RelBuilder {
   }
 
   public RelFactories.TableScanFactory getScanFactory() {
-    return scanFactory;
+    return struct.scanFactory;
   }
 
   // Methods for manipulating the stack
@@ -1071,7 +1019,9 @@ public class RelBuilder {
     if (relOptTable == null) {
       throw RESOURCE.tableNotFound(String.join(".", names)).ex();
     }
-    final RelNode scan = scanFactory.createScan(cluster, relOptTable, ImmutableList.of());
+    final RelNode scan =
+        struct.scanFactory.createScan(cluster, relOptTable,
+            ImmutableList.of());
     push(scan);
     rename(relOptTable.getRowType().getFieldNames());
 
@@ -1104,7 +1054,8 @@ public class RelBuilder {
    */
   public RelBuilder snapshot(RexNode period) {
     final Frame frame = stack.pop();
-    final RelNode snapshot = snapshotFactory.createSnapshot(frame.rel, period);
+    final RelNode snapshot =
+        struct.snapshotFactory.createSnapshot(frame.rel, period);
     stack.push(new Frame(snapshot, frame.fields));
     return this;
   }
@@ -1164,8 +1115,8 @@ public class RelBuilder {
 
     final RexNode call = call(operator, ImmutableList.copyOf(operands));
     final RelNode functionScan =
-        tableFunctionScanFactory.createTableFunctionScan(cluster, inputs,
-            call, null, getColumnMappings(operator));
+        struct.tableFunctionScanFactory.createTableFunctionScan(cluster,
+            inputs, call, null, getColumnMappings(operator));
     push(functionScan);
     return this;
   }
@@ -1218,8 +1169,9 @@ public class RelBuilder {
 
     if (!simplifiedPredicates.isAlwaysTrue()) {
       final Frame frame = stack.pop();
-      final RelNode filter = filterFactory.createFilter(frame.rel,
-          simplifiedPredicates, ImmutableSet.copyOf(variablesSet));
+      final RelNode filter =
+          struct.filterFactory.createFilter(frame.rel,
+              simplifiedPredicates, ImmutableSet.copyOf(variablesSet));
       stack.push(new Frame(filter, frame.fields));
     }
     return this;
@@ -1478,7 +1430,7 @@ public class RelBuilder {
       return this;
     }
     final RelNode project =
-        projectFactory.createProject(frame.rel,
+        struct.projectFactory.createProject(frame.rel,
             ImmutableList.copyOf(hints),
             ImmutableList.copyOf(nodeList),
             fieldNameList);
@@ -1771,8 +1723,9 @@ public class RelBuilder {
       ImmutableList<ImmutableBitSet> groupSets, RelNode input,
       List<AggregateCall> aggregateCalls, List<RexNode> extraNodes,
       ImmutableList<Field> inFields) {
-    final RelNode aggregate = aggregateFactory.createAggregate(input,
-        ImmutableList.of(), groupSet, groupSets, aggregateCalls);
+    final RelNode aggregate =
+        struct.aggregateFactory.createAggregate(input,
+            ImmutableList.of(), groupSet, groupSets, aggregateCalls);
 
     // build field list
     final ImmutableList.Builder<Field> fields = ImmutableList.builder();
@@ -1829,7 +1782,7 @@ public class RelBuilder {
     case 1:
       return push(inputs.get(0));
     default:
-      return push(setOpFactory.createSetOp(kind, inputs, all));
+      return push(struct.setOpFactory.createSetOp(kind, inputs, all));
     }
   }
 
@@ -1914,7 +1867,9 @@ public class RelBuilder {
         rowType,
         transientTable,
         ImmutableList.of(tableName));
-    RelNode scan = scanFactory.createScan(cluster, relOptTable, ImmutableList.of());
+    RelNode scan =
+        struct.scanFactory.createScan(cluster, relOptTable,
+            ImmutableList.of());
     push(scan);
     rename(rowType.getFieldNames());
     return this;
@@ -1927,8 +1882,11 @@ public class RelBuilder {
    * @param writeType Spool's write type (as described in {@link Spool.Type})
    * @param table Table to write into
    */
-  private RelBuilder tableSpool(Spool.Type readType, Spool.Type writeType, RelOptTable table) {
-    RelNode spool =  spoolFactory.createTableSpool(peek(), readType, writeType, table);
+  private RelBuilder tableSpool(Spool.Type readType, Spool.Type writeType,
+      RelOptTable table) {
+    RelNode spool =
+        struct.spoolFactory.createTableSpool(peek(), readType, writeType,
+            table);
     replaceTop(spool);
     return this;
   }
@@ -1984,8 +1942,10 @@ public class RelBuilder {
 
     RelNode iterative = tableSpool(Spool.Type.LAZY, Spool.Type.LAZY, finder.relOptTable).build();
     RelNode seed = tableSpool(Spool.Type.LAZY, Spool.Type.LAZY, finder.relOptTable).build();
-    RelNode repUnion = repeatUnionFactory.createRepeatUnion(seed, iterative, all, iterationLimit);
-    return push(repUnion);
+    RelNode repeatUnion =
+        struct.repeatUnionFactory.createRepeatUnion(seed, iterative, all,
+            iterationLimit);
+    return push(repeatUnion);
   }
 
   /**
@@ -2065,11 +2025,13 @@ public class RelBuilder {
       default:
         postCondition = condition;
       }
-      join = correlateFactory.createCorrelate(left.rel, right.rel, id,
-          requiredColumns, joinType);
+      join =
+          struct.correlateFactory.createCorrelate(left.rel, right.rel, id,
+              requiredColumns, joinType);
     } else {
-      join = joinFactory.createJoin(left.rel, right.rel, ImmutableList.of(), condition,
-          variablesSet, joinType, false);
+      join =
+          struct.joinFactory.createJoin(left.rel, right.rel,
+              ImmutableList.of(), condition, variablesSet, joinType, false);
     }
     final ImmutableList.Builder<Field> fields = ImmutableList.builder();
     fields.addAll(left.fields);
@@ -2102,9 +2064,9 @@ public class RelBuilder {
     rename(registrar.names);
     Frame left = stack.pop();
 
-    final RelNode correlate = correlateFactory
-        .createCorrelate(left.rel, right.rel, correlationId,
-            ImmutableBitSet.of(requiredOrdinals), joinType);
+    final RelNode correlate =
+        struct.correlateFactory.createCorrelate(left.rel, right.rel,
+            correlationId, ImmutableBitSet.of(requiredOrdinals), joinType);
 
     final ImmutableList.Builder<Field> fields = ImmutableList.builder();
     fields.addAll(left.fields);
@@ -2154,7 +2116,7 @@ public class RelBuilder {
   public RelBuilder semiJoin(Iterable<? extends RexNode> conditions) {
     final Frame right = stack.pop();
     final RelNode semiJoin =
-        joinFactory.createJoin(peek(),
+        struct.joinFactory.createJoin(peek(),
             right.rel,
             ImmutableList.of(),
             and(conditions),
@@ -2191,7 +2153,7 @@ public class RelBuilder {
   public RelBuilder antiJoin(Iterable<? extends RexNode> conditions) {
     final Frame right = stack.pop();
     final RelNode antiJoin =
-        joinFactory.createJoin(peek(),
+        struct.joinFactory.createJoin(peek(),
             right.rel,
             ImmutableList.of(),
             and(conditions),
@@ -2311,7 +2273,8 @@ public class RelBuilder {
   public RelBuilder empty() {
     final Frame frame = stack.pop();
     final RelNode values =
-        valuesFactory.createValues(cluster, frame.rel.getRowType(), ImmutableList.of());
+        struct.valuesFactory.createValues(cluster, frame.rel.getRowType(),
+            ImmutableList.of());
     stack.push(new Frame(values, frame.fields));
     return this;
   }
@@ -2328,8 +2291,9 @@ public class RelBuilder {
   public RelBuilder values(RelDataType rowType, Object... columnValues) {
     final ImmutableList<ImmutableList<RexLiteral>> tupleList =
         tupleList(rowType.getFieldCount(), columnValues);
-    RelNode values = valuesFactory.createValues(cluster, rowType,
-        ImmutableList.copyOf(tupleList));
+    RelNode values =
+        struct.valuesFactory.createValues(cluster, rowType,
+            ImmutableList.copyOf(tupleList));
     push(values);
     return this;
   }
@@ -2346,7 +2310,8 @@ public class RelBuilder {
   public RelBuilder values(Iterable<? extends List<RexLiteral>> tupleList,
       RelDataType rowType) {
     RelNode values =
-        valuesFactory.createValues(cluster, rowType, copy(tupleList));
+        struct.valuesFactory.createValues(cluster, rowType,
+            copy(tupleList));
     push(values);
     return this;
   }
@@ -2390,7 +2355,8 @@ public class RelBuilder {
 
   /** Creates an Exchange by distribution. */
   public RelBuilder exchange(RelDistribution distribution) {
-    RelNode exchange = exchangeFactory.createExchange(peek(), distribution);
+    RelNode exchange =
+        struct.exchangeFactory.createExchange(peek(), distribution);
     replaceTop(exchange);
     return this;
   }
@@ -2398,8 +2364,9 @@ public class RelBuilder {
   /** Creates a SortExchange by distribution and collation. */
   public RelBuilder sortExchange(RelDistribution distribution,
       RelCollation collation) {
-    RelNode exchange = sortExchangeFactory
-        .createSortExchange(peek(), distribution, collation);
+    RelNode exchange =
+        struct.sortExchangeFactory.createSortExchange(peek(), distribution,
+            collation);
     replaceTop(exchange);
     return this;
   }
@@ -2461,7 +2428,7 @@ public class RelBuilder {
         if (sort2.offset == null && sort2.fetch == null) {
           replaceTop(sort2.getInput());
           final RelNode sort =
-              sortFactory.createSort(peek(), sort2.collation,
+              struct.sortFactory.createSort(peek(), sort2.collation,
                   offsetNode, fetchNode);
           replaceTop(sort);
           return this;
@@ -2473,10 +2440,10 @@ public class RelBuilder {
           final Sort sort2 = (Sort) project.getInput();
           if (sort2.offset == null && sort2.fetch == null) {
             final RelNode sort =
-                sortFactory.createSort(sort2.getInput(), sort2.collation,
-                    offsetNode, fetchNode);
+                struct.sortFactory.createSort(sort2.getInput(),
+                    sort2.collation, offsetNode, fetchNode);
             replaceTop(
-                projectFactory.createProject(sort,
+                struct.projectFactory.createProject(sort,
                     project.getHints(),
                     project.getProjects(),
                     Pair.right(project.getNamedProjects())));
@@ -2489,8 +2456,8 @@ public class RelBuilder {
       project(registrar.extraNodes);
     }
     final RelNode sort =
-        sortFactory.createSort(peek(), RelCollations.of(fieldCollations),
-            offsetNode, fetchNode);
+        struct.sortFactory.createSort(peek(),
+            RelCollations.of(fieldCollations), offsetNode, fetchNode);
     replaceTop(sort);
     if (registrar.addedFieldCount() > 0) {
       project(registrar.originalExtraNodes);
@@ -2534,7 +2501,8 @@ public class RelBuilder {
   public RelBuilder convert(RelDataType castRowType, boolean rename) {
     final RelNode r = build();
     final RelNode r2 =
-        RelOptUtil.createCastRel(r, castRowType, rename, projectFactory);
+        RelOptUtil.createCastRel(r, castRowType, rename,
+            struct.projectFactory);
     push(r2);
     return this;
   }
@@ -2594,11 +2562,11 @@ public class RelBuilder {
       measures.put(alias, operands.get(0));
     }
 
-    final RelNode match = matchFactory.createMatch(peek(), pattern,
-        typeBuilder.build(), strictStart, strictEnd, patternDefinitions,
-        measures.build(), after, subsets, allRows,
-        partitionBitSet, RelCollations.of(fieldCollations),
-        interval);
+    final RelNode match =
+        struct.matchFactory.createMatch(peek(), pattern,
+            typeBuilder.build(), strictStart, strictEnd, patternDefinitions,
+            measures.build(), after, subsets, allRows,
+            partitionBitSet, RelCollations.of(fieldCollations), interval);
     stack.push(new Frame(match));
     return this;
   }