You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2021/03/04 01:56:11 UTC

[calcite] branch master updated: [CALCITE-4276] MaterializedViewOnlyAggregateRule performs invalid rewrite on query that contains join and time-rollup function (FLOOR) (Justin Swett)

This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fb14d5  [CALCITE-4276] MaterializedViewOnlyAggregateRule performs invalid rewrite on query that contains join and time-rollup function (FLOOR) (Justin Swett)
0fb14d5 is described below

commit 0fb14d553764f2a993ec56db4a36de2713ac1206
Author: Justin Swett <js...@google.com>
AuthorDate: Tue Sep 29 21:48:42 2020 -0700

    [CALCITE-4276] MaterializedViewOnlyAggregateRule performs invalid rewrite on query that contains join and time-rollup function (FLOOR) (Justin Swett)
    
    Without the fix, MaterializedViewOnlyAggregateRule gets a
    field ordinal wrong, which manifests as a type mismatch
    something like this:
    
      java.lang.AssertionError: type mismatch:
          ref: TIMESTAMP(3)
          input: INTEGER NOT NULL
        at org.apache.calcite.util.Litmus$1.fail(Litmus.java:32)
        at org.apache.calcite.plan.RelOptUtil.eq(RelOptUtil.java:2207)
        at org.apache.calcite.rex.RexChecker.visitInputRef(RexChecker.java:129)
    
    It's difficult to see among all the refactoring, but the fix
    is just two lines; in MaterializedViewAggregateRule, change
    
      final int k = find(topViewProject, r);
    
    to
    
      final int j = find(viewNode, r);
      final int k = find(topViewProject, j);
    
    and the problem goes away.
---
 .../materialize/MaterializedViewAggregateRule.java | 316 ++++++++++-----------
 .../java/org/apache/calcite/tools/RelBuilder.java  |  15 +
 .../java/org/apache/calcite/test/JdbcTest.java     |  10 +
 .../apache/calcite/test/MaterializationTest.java   |   9 +
 .../test/MaterializedViewRelOptRulesTest.java      |  38 ++-
 5 files changed, 225 insertions(+), 163 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/materialize/MaterializedViewAggregateRule.java b/core/src/main/java/org/apache/calcite/rel/rules/materialize/MaterializedViewAggregateRule.java
index 34ae639..02c1a98 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/materialize/MaterializedViewAggregateRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/materialize/MaterializedViewAggregateRule.java
@@ -17,6 +17,7 @@
 package org.apache.calcite.rel.rules.materialize;
 
 import org.apache.calcite.avatica.util.TimeUnitRange;
+import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptUtil;
 import org.apache.calcite.plan.hep.HepPlanner;
@@ -47,7 +48,6 @@ import org.apache.calcite.rex.RexTableInputRef.RelTableRef;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.SqlTypeName;
@@ -380,15 +380,15 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
       RelNode input,
       @Nullable Project topProject,
       RelNode node,
-      @Nullable Project topViewProject,
+      @Nullable Project topViewProject0,
       RelNode viewNode,
       BiMap<RelTableRef, RelTableRef> queryToViewTableMapping,
       EquivalenceClasses queryEC) {
     final Aggregate queryAggregate = (Aggregate) node;
     final Aggregate viewAggregate = (Aggregate) viewNode;
     // Get group by references and aggregate call input references needed
-    ImmutableBitSet.Builder indexes = ImmutableBitSet.builder();
-    ImmutableBitSet references = null;
+    final ImmutableBitSet.Builder indexes = ImmutableBitSet.builder();
+    final ImmutableBitSet references;
     if (topProject != null && !unionRewriting) {
       // We have a Project on top, gather only what is needed
       final RelOptUtil.InputFinder inputFinder =
@@ -415,11 +415,12 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
           indexes.set(inputIdx);
         }
       }
+      references = null;
     }
 
     // Create mapping from query columns to view columns
-    List<RexNode> rollupNodes = new ArrayList<>();
-    Multimap<Integer, Integer> m = generateMapping(rexBuilder, simplify, mq,
+    final List<RexNode> rollupNodes = new ArrayList<>();
+    final Multimap<Integer, Integer> m = generateMapping(rexBuilder, simplify, mq,
         queryAggregate.getInput(), viewAggregate.getInput(), indexes.build(),
         queryToViewTableMapping, queryEC, rollupNodes);
     if (m == null) {
@@ -460,12 +461,13 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
       }
     }
     boolean containsDistinctAgg = false;
-    for (int idx = 0; idx < queryAggregate.getAggCallList().size(); idx++) {
-      if (references != null && !references.get(queryAggregate.getGroupCount() + idx)) {
+    for (Ord<AggregateCall> ord : Ord.zip(queryAggregate.getAggCallList())) {
+      if (references != null
+          && !references.get(queryAggregate.getGroupCount() + ord.i)) {
         // Ignore
         continue;
       }
-      AggregateCall queryAggCall = queryAggregate.getAggCallList().get(idx);
+      final AggregateCall queryAggCall = ord.e;
       if (queryAggCall.filterArg >= 0) {
         // Not supported currently
         return null;
@@ -488,7 +490,7 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
           // Continue
           continue;
         }
-        aggregateMapping.set(queryAggregate.getGroupCount() + idx,
+        aggregateMapping.set(queryAggregate.getGroupCount() + ord.i,
             viewAggregate.getGroupCount() + j);
         if (queryAggCall.isDistinct()) {
           containsDistinctAgg = true;
@@ -497,26 +499,25 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
       }
     }
 
-    // If we reach here, to simplify things, we create an identity topViewProject
-    // if not present
-    if (topViewProject == null) {
-      topViewProject = (Project) relBuilder.push(viewNode)
-          .project(relBuilder.fields(), ImmutableList.of(), true).build();
-    }
+    // To simplify things, create an identity topViewProject if not present.
+    final Project topViewProject = topViewProject0 != null
+        ? topViewProject0
+        : (Project) relBuilder.push(viewNode)
+            .project(relBuilder.fields(), ImmutableList.of(), true)
+            .build();
 
     // Generate result rewriting
     final List<RexNode> additionalViewExprs = new ArrayList<>();
 
     // Multimap is required since a column in the materialized view's project
     // could map to multiple columns in the target query
-    ImmutableMultimap<Integer, Integer> rewritingMapping = null;
-    RelNode result = relBuilder.push(input).build();
+    final ImmutableMultimap<Integer, Integer> rewritingMapping;
+    relBuilder.push(input);
     // We create view expressions that will be used in a Project on top of the
     // view in case we need to rollup the expression
-    final List<RexNode> inputViewExprs = new ArrayList<>();
-    inputViewExprs.addAll(relBuilder.push(result).fields());
-    relBuilder.clear();
-    if (forceRollup || queryAggregate.getGroupCount() != viewAggregate.getGroupCount()
+    final List<RexNode> inputViewExprs = new ArrayList<>(relBuilder.fields());
+    if (forceRollup
+        || queryAggregate.getGroupCount() != viewAggregate.getGroupCount()
         || matchModality == MatchModality.VIEW_PARTIAL) {
       if (containsDistinctAgg) {
         // Cannot rollup DISTINCT aggregate
@@ -527,36 +528,27 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
           ImmutableMultimap.builder();
       final ImmutableBitSet.Builder groupSetB = ImmutableBitSet.builder();
       for (int i = 0; i < queryAggregate.getGroupCount(); i++) {
-        int targetIdx = aggregateMapping.getTargetOpt(i);
+        final int targetIdx = aggregateMapping.getTargetOpt(i);
         if (targetIdx == -1) {
           // No matching group by column, we bail out
           return null;
         }
-        boolean added = false;
         if (targetIdx >= viewAggregate.getRowType().getFieldCount()) {
-          RexNode targetNode = rollupNodes.get(
-              targetIdx - viewInputFieldCount - viewInputDifferenceViewFieldCount);
+          RexNode targetNode =
+              rollupNodes.get(targetIdx - viewInputFieldCount
+                  - viewInputDifferenceViewFieldCount);
           // We need to rollup this expression
           final Multimap<RexNode, Integer> exprsLineage = ArrayListMultimap.create();
-          final ImmutableBitSet refs = RelOptUtil.InputFinder.bits(targetNode);
-          for (int childTargetIdx : refs) {
-            added = false;
-            for (int k = 0; k < topViewProject.getProjects().size() && !added; k++) {
-              RexNode n = topViewProject.getProjects().get(k);
-              if (!n.isA(SqlKind.INPUT_REF)) {
-                continue;
-              }
-              final int ref = ((RexInputRef) n).getIndex();
-              if (ref == childTargetIdx) {
-                exprsLineage.put(
-                    new RexInputRef(ref, targetNode.getType()), k);
-                added = true;
-              }
-            }
-            if (!added) {
+          for (int r : RelOptUtil.InputFinder.bits(targetNode)) {
+            final int j = find(viewNode, r);
+            final int k = find(topViewProject, j);
+            if (k < 0) {
               // No matching column needed for computed expression, bail out
               return null;
             }
+            final RexInputRef ref =
+                relBuilder.with(viewNode.getInput(0), b -> b.field(r));
+            exprsLineage.put(ref, k);
           }
           // We create the new node pointing to the index
           groupSetB.set(inputViewExprs.size());
@@ -566,115 +558,92 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
           // We need to create the rollup expression
           RexNode rollupExpression = requireNonNull(
               shuttleReferences(rexBuilder, targetNode, exprsLineage),
-              () -> "shuttleReferences produced null for targetNode=" + targetNode
-                  + ", exprsLineage=" + exprsLineage);
+              () -> "shuttleReferences produced null for targetNode="
+                  + targetNode + ", exprsLineage=" + exprsLineage);
           inputViewExprs.add(rollupExpression);
-          added = true;
         } else {
           // This expression should be referenced directly
-          for (int k = 0; k < topViewProject.getProjects().size() && !added; k++) {
-            RexNode n = topViewProject.getProjects().get(k);
-            if (!n.isA(SqlKind.INPUT_REF)) {
-              continue;
-            }
-            int ref = ((RexInputRef) n).getIndex();
-            if (ref == targetIdx) {
-              groupSetB.set(k);
-              rewritingMappingB.put(k, i);
-              added = true;
-            }
+          final int k = find(topViewProject, targetIdx);
+          if (k < 0) {
+            // No matching group by column, we bail out
+            return null;
           }
-        }
-        if (!added) {
-          // No matching group by column, we bail out
-          return null;
+          groupSetB.set(k);
+          rewritingMappingB.put(k, i);
         }
       }
       final ImmutableBitSet groupSet = groupSetB.build();
       final List<AggCall> aggregateCalls = new ArrayList<>();
-      for (int i = 0; i < queryAggregate.getAggCallList().size(); i++) {
-        if (references != null && !references.get(queryAggregate.getGroupCount() + i)) {
+      for (Ord<AggregateCall> ord : Ord.zip(queryAggregate.getAggCallList())) {
+        final int sourceIdx = queryAggregate.getGroupCount() + ord.i;
+        if (references != null && !references.get(sourceIdx)) {
           // Ignore
           continue;
         }
-        int sourceIdx = queryAggregate.getGroupCount() + i;
-        int targetIdx =
+        final int targetIdx =
             aggregateMapping.getTargetOpt(sourceIdx);
         if (targetIdx < 0) {
           // No matching aggregation column, we bail out
           return null;
         }
-        AggregateCall queryAggCall = queryAggregate.getAggCallList().get(i);
-        boolean added = false;
-        for (int k = 0; k < topViewProject.getProjects().size() && !added; k++) {
-          RexNode n = topViewProject.getProjects().get(k);
-          if (!n.isA(SqlKind.INPUT_REF)) {
-            continue;
-          }
-          int ref = ((RexInputRef) n).getIndex();
-          if (ref == targetIdx) {
-            SqlAggFunction rollupAgg =
-                getRollup(queryAggCall.getAggregation());
-            if (rollupAgg == null) {
-              // Cannot rollup this aggregate, bail out
-              return null;
-            }
-            rewritingMappingB.put(k, queryAggregate.getGroupCount() + aggregateCalls.size());
-            final RexInputRef operand = rexBuilder.makeInputRef(input, k);
-            aggregateCalls.add(
-                relBuilder.aggregateCall(rollupAgg, operand)
-                    .approximate(queryAggCall.isApproximate())
-                    .distinct(queryAggCall.isDistinct())
-                    .as(queryAggCall.name));
-            added = true;
-          }
-        }
-        if (!added) {
+        final int k = find(topViewProject, targetIdx);
+        if (k < 0) {
           // No matching aggregation column, we bail out
           return null;
         }
+        final AggregateCall queryAggCall = ord.e;
+        SqlAggFunction rollupAgg =
+            getRollup(queryAggCall.getAggregation());
+        if (rollupAgg == null) {
+          // Cannot rollup this aggregate, bail out
+          return null;
+        }
+        rewritingMappingB.put(k,
+            queryAggregate.getGroupCount() + aggregateCalls.size());
+        final RexInputRef operand = rexBuilder.makeInputRef(input, k);
+        aggregateCalls.add(
+            relBuilder.aggregateCall(rollupAgg, operand)
+                .approximate(queryAggCall.isApproximate())
+                .distinct(queryAggCall.isDistinct())
+                .as(queryAggCall.name));
       }
       // Create aggregate on top of input
-      RelNode prevNode = result;
-      relBuilder.push(result);
-      if (inputViewExprs.size() != result.getRowType().getFieldCount()) {
+      final RelNode prevNode = relBuilder.peek();
+      if (inputViewExprs.size() > prevNode.getRowType().getFieldCount()) {
         relBuilder.project(inputViewExprs);
       }
-      result = relBuilder
-          .aggregate(relBuilder.groupKey(groupSet), aggregateCalls)
-          .build();
-      if (prevNode == result && groupSet.cardinality() != result.getRowType().getFieldCount()) {
+      relBuilder
+          .aggregate(relBuilder.groupKey(groupSet), aggregateCalls);
+      if (prevNode == relBuilder.peek()
+          && groupSet.cardinality() != relBuilder.peek().getRowType().getFieldCount()) {
         // Aggregate was not inserted but we need to prune columns
-        result = relBuilder
-            .push(result)
-            .project(relBuilder.fields(groupSet))
-            .build();
+        relBuilder.project(relBuilder.fields(groupSet));
       }
-      // We introduce a project on top, as group by columns order is lost
+      // We introduce a project on top, as group by columns order is lost.
+      // Multimap is required since a column in the materialized view's project
+      // could map to multiple columns in the target query.
       rewritingMapping = rewritingMappingB.build();
       final ImmutableMultimap<Integer, Integer> inverseMapping = rewritingMapping.inverse();
       final List<RexNode> projects = new ArrayList<>();
 
       final ImmutableBitSet.Builder addedProjects = ImmutableBitSet.builder();
       for (int i = 0; i < queryAggregate.getGroupCount(); i++) {
-        int pos = groupSet.indexOf(inverseMapping.get(i).iterator().next());
+        final int pos = groupSet.indexOf(inverseMapping.get(i).iterator().next());
         addedProjects.set(pos);
-        projects.add(
-            rexBuilder.makeInputRef(result, pos));
+        projects.add(relBuilder.field(pos));
       }
 
-      ImmutableBitSet projectedCols = addedProjects.build();
+      final ImmutableBitSet projectedCols = addedProjects.build();
       // We add aggregate functions that are present in result to projection list
-      for (int i = 0; i < result.getRowType().getFieldCount(); i++) {
+      for (int i = 0; i < relBuilder.peek().getRowType().getFieldCount(); i++) {
         if (!projectedCols.get(i)) {
-          projects.add(rexBuilder.makeInputRef(result, i));
+          projects.add(relBuilder.field(i));
         }
       }
-      result = relBuilder
-          .push(result)
-          .project(projects)
-          .build();
-    } // end if queryAggregate.getGroupCount() != viewAggregate.getGroupCount()
+      relBuilder.project(projects);
+    } else {
+      rewritingMapping = null;
+    }
 
     // Add query expressions on top. We first map query expressions to view
     // expressions. Once we have done that, if the expression is contained
@@ -695,37 +664,75 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
     }
     // Available in view.
     final Multimap<RexNode, Integer> viewExprs = ArrayListMultimap.create();
-    int numberViewExprs = 0;
-    for (RexNode viewExpr : topViewProject.getProjects()) {
-      viewExprs.put(viewExpr, numberViewExprs++);
-    }
-    for (RexNode additionalViewExpr : additionalViewExprs) {
-      viewExprs.put(additionalViewExpr, numberViewExprs++);
-    }
+    addAllIndexed(viewExprs, topViewProject.getProjects());
+    addAllIndexed(viewExprs, additionalViewExprs);
     final List<RexNode> rewrittenExprs = new ArrayList<>(topExprs.size());
     for (RexNode expr : topExprs) {
       // First map through the aggregate
-      RexNode rewrittenExpr = shuttleReferences(rexBuilder, expr, aggregateMapping);
-      if (rewrittenExpr == null) {
+      final RexNode e2 = shuttleReferences(rexBuilder, expr, aggregateMapping);
+      if (e2 == null) {
         // Cannot map expression
         return null;
       }
       // Next map through the last project
-      rewrittenExpr =
-          shuttleReferences(rexBuilder, rewrittenExpr, viewExprs, result, rewritingMapping);
-      if (rewrittenExpr == null) {
+      final RexNode e3 =
+          shuttleReferences(rexBuilder, e2, viewExprs,
+              relBuilder.peek(), rewritingMapping);
+      if (e3 == null) {
         // Cannot map expression
         return null;
       }
-      rewrittenExprs.add(rewrittenExpr);
+      rewrittenExprs.add(e3);
     }
     return relBuilder
-        .push(result)
         .project(rewrittenExprs)
         .convert(topRowType, false)
         .build();
   }
 
+  private static <K> void addAllIndexed(Multimap<K, Integer> multimap,
+      Iterable<? extends K> list) {
+    for (K k : list) {
+      multimap.put(k, multimap.size());
+    }
+  }
+
+  /** Given a relational expression with a single input (such as a Project or
+   * Aggregate) and the ordinal of an input field, returns the ordinal of the
+   * output field that references the input field. Or -1 if the field is not
+   * propagated.
+   *
+   * <p>For example, if {@code rel} is {@code Project(c0, c2)} (on input with
+   * columns (c0, c1, c2)), then {@code find(rel, 2)} returns 1 (c2);
+   * {@code find(rel, 1)} returns -1 (because c1 is not projected).
+   *
+   * <p>If {@code rel} is {@code Aggregate([0, 2], sum(1))}, then
+   * {@code find(rel, 2)} returns 1, and {@code find(rel, 1)} returns -1.
+   *
+   * @param rel Relational expression
+   * @param ref Ordinal of input field
+   * @return Ordinal of the output field that references it, or -1
+   */
+  private static int find(RelNode rel, int ref) {
+    if (rel instanceof Project) {
+      Project project = (Project) rel;
+      for (Ord<RexNode> p : Ord.zip(project.getProjects())) {
+        if (p.e instanceof RexInputRef
+            && ((RexInputRef) p.e).getIndex() == ref) {
+          return p.i;
+        }
+      }
+    }
+    if (rel instanceof Aggregate) {
+      Aggregate aggregate = (Aggregate) rel;
+      int k = aggregate.getGroupSet().indexOf(ref);
+      if (k >= 0) {
+        return k;
+      }
+    }
+    return -1;
+  }
+
   /**
    * Mapping from node expressions to target expressions.
    *
@@ -746,7 +753,7 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
     Map<RexTableInputRef, Set<RexTableInputRef>> equivalenceClassesMap =
         sourceEC.getEquivalenceClassesMap();
     Multimap<RexNode, Integer> exprsLineage = ArrayListMultimap.create();
-    List<RexNode> timestampExprs = new ArrayList<>();
+    final List<RexNode> timestampExprs = new ArrayList<>();
     for (int i = 0; i < target.getRowType().getFieldCount(); i++) {
       Set<RexNode> s = mq.getExpressionLineage(target, rexBuilder.makeInputRef(target, i));
       if (s == null) {
@@ -758,7 +765,7 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
       final RexNode e = Iterables.getOnlyElement(s);
       // Rewrite expr to be expressed on query tables
       final RexNode simplified = simplify.simplifyUnknownAsFalse(e);
-      RexNode expr = RexUtil.swapTableColumnReferences(rexBuilder,
+      final RexNode expr = RexUtil.swapTableColumnReferences(rexBuilder,
           simplified,
           tableMapping.inverse(),
           equivalenceClassesMap);
@@ -776,39 +783,26 @@ public abstract class MaterializedViewAggregateRule<C extends MaterializedViewAg
     // FLOOR(ts to DAY) via FLOOR(FLOOR(ts to HOUR) to DAY)
     for (RexNode timestampExpr : timestampExprs) {
       for (TimeUnitRange value : SUPPORTED_DATE_TIME_ROLLUP_UNITS) {
-        // CEIL
-        RexNode ceilExpr =
-            rexBuilder.makeCall(getCeilSqlFunction(value),
-                timestampExpr, rexBuilder.makeFlag(value));
-        // References self-row
-        RexNode rewrittenCeilExpr =
-            shuttleReferences(rexBuilder, ceilExpr, exprsLineage);
-        if (rewrittenCeilExpr != null) {
-          // We add the CEIL expression to the additional expressions, replacing the child
-          // expression by the position that it references
-          additionalExprs.add(rewrittenCeilExpr);
-          // Then we simplify the expression and we add it to the expressions lineage so we
-          // can try to find a match
-          final RexNode simplified =
-              simplify.simplifyUnknownAsFalse(ceilExpr);
-          exprsLineage.put(simplified,
-              target.getRowType().getFieldCount() + additionalExprs.size() - 1);
-        }
-        // FLOOR
-        RexNode floorExpr =
-            rexBuilder.makeCall(getFloorSqlFunction(value),
-                timestampExpr, rexBuilder.makeFlag(value));
-        // References self-row
-        RexNode rewrittenFloorExpr =
-            shuttleReferences(rexBuilder, floorExpr, exprsLineage);
-        if (rewrittenFloorExpr != null) {
-          // We add the FLOOR expression to the additional expressions, replacing the child
-          // expression by the position that it references
-          additionalExprs.add(rewrittenFloorExpr);
-          // Then we simplify the expression and we add it to the expressions lineage so we
-          // can try to find a match
+        final SqlFunction[] functions = {getCeilSqlFunction(value),
+            getFloorSqlFunction(value)};
+        for (SqlFunction function : functions) {
+          final RexNode call =
+              rexBuilder.makeCall(function,
+                  timestampExpr, rexBuilder.makeFlag(value));
+          // References self-row
+          final RexNode rewrittenCall =
+              shuttleReferences(rexBuilder, call, exprsLineage);
+          if (rewrittenCall == null) {
+            continue;
+          }
+          // We add the CEIL or FLOOR expression to the additional
+          // expressions, replacing the child expression by the position that
+          // it references
+          additionalExprs.add(rewrittenCall);
+          // Then we simplify the expression and we add it to the expressions
+          // lineage so we can try to find a match.
           final RexNode simplified =
-              simplify.simplifyUnknownAsFalse(floorExpr);
+              simplify.simplifyUnknownAsFalse(call);
           exprsLineage.put(simplified,
               target.getRowType().getFieldCount() + additionalExprs.size() - 1);
         }
diff --git a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
index a838e08..04b7df6 100644
--- a/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
+++ b/core/src/main/java/org/apache/calcite/tools/RelBuilder.java
@@ -372,6 +372,10 @@ public class RelBuilder {
   }
 
   private Frame peek_(int n) {
+    if (n == 0) {
+      // more efficient than starting an iterator
+      return Objects.requireNonNull(stack.peek(), "stack.peek");
+    }
     return Iterables.get(stack, n);
   }
 
@@ -399,6 +403,17 @@ public class RelBuilder {
     return offset;
   }
 
+  /** Evaluates {@code fn} with {@code r} temporarily on the stack, popping it
+   * afterwards; {@code r} is pushed before the try so a failed push pops nothing. */
+  public <E> E with(RelNode r, Function<RelBuilder, E> fn) {
+    push(r);
+    try {
+      return fn.apply(this);
+    } finally {
+      stack.pop();
+    }
+  }
+
   // Methods that return scalar expressions
 
   /** Creates a literal (constant expression). */
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index f2222aa..45b0ab0 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -7911,6 +7911,16 @@ public class JdbcTest {
     }
   }
 
+  public static class DepartmentPlus extends Department {
+    public final Timestamp inceptionDate;
+
+    public DepartmentPlus(int deptno, String name, List<Employee> employees,
+        Location location, Timestamp inceptionDate) {
+      super(deptno, name, employees, location);
+      this.inceptionDate = inceptionDate;
+    }
+  }
+
   public static class Location {
     public final int x;
     public final int y;
diff --git a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
index 8b16d14..9bf37db 100644
--- a/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
+++ b/core/src/test/java/org/apache/calcite/test/MaterializationTest.java
@@ -29,6 +29,7 @@ import org.apache.calcite.runtime.Hook;
 import org.apache.calcite.schema.QueryableTable;
 import org.apache.calcite.schema.TranslatableTable;
 import org.apache.calcite.test.JdbcTest.Department;
+import org.apache.calcite.test.JdbcTest.DepartmentPlus;
 import org.apache.calcite.test.JdbcTest.Dependent;
 import org.apache.calcite.test.JdbcTest.Employee;
 import org.apache.calcite.test.JdbcTest.Event;
@@ -402,6 +403,14 @@ public class MaterializationTest {
             new Location(0, 52)),
         new Department(20, "HR", Collections.singletonList(emps[1]), null),
     };
+    public final DepartmentPlus[] depts2 = {
+        new DepartmentPlus(10, "Sales", Arrays.asList(emps[0], emps[2], emps[3]),
+            new Location(-122, 38), new Timestamp(0)),
+        new DepartmentPlus(30, "Marketing", ImmutableList.of(),
+            new Location(0, 52), new Timestamp(0)),
+        new DepartmentPlus(20, "HR", Collections.singletonList(emps[1]),
+            null, new Timestamp(0)),
+    };
     public final Dependent[] dependents = {
         new Dependent(10, "Michael"),
         new Dependent(10, "Jane"),
diff --git a/core/src/test/java/org/apache/calcite/test/MaterializedViewRelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/MaterializedViewRelOptRulesTest.java
index d0e2aa5..707b76d 100644
--- a/core/src/test/java/org/apache/calcite/test/MaterializedViewRelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/MaterializedViewRelOptRulesTest.java
@@ -31,8 +31,10 @@ import org.junit.jupiter.api.Test;
 import java.util.List;
 
 /**
- * Unit test for extensions of AbstractMaterializedViewRule,
- * in which materialized view gets matched by using structual information of plan.
+ * Unit test for
+ * {@link org.apache.calcite.rel.rules.materialize.MaterializedViewRule} and its
+ * sub-classes, in which materialized views are matched to the structure of a
+ * plan.
  */
 public class MaterializedViewRelOptRulesTest extends AbstractMaterializedViewTest {
 
@@ -759,6 +761,38 @@ public class MaterializedViewRelOptRulesTest extends AbstractMaterializedViewTes
         .ok();
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-4276">[CALCITE-4276]
+   * If query contains join and rollup function (FLOOR), rewrite to materialized
+   * view contains bad field offset</a>. */
+  @Test void testJoinAggregateMaterializationAggregateFuncs15() {
+    final String m = ""
+        + "SELECT \"deptno\",\n"
+        + "  COUNT(*) AS \"dept_size\",\n"
+        + "  SUM(\"salary\") AS \"dept_budget\"\n"
+        + "FROM \"emps\"\n"
+        + "GROUP BY \"deptno\"";
+    final String q = ""
+        + "SELECT FLOOR(\"CREATED_AT\" TO YEAR) AS by_year,\n"
+        + "  COUNT(*) AS \"num_emps\"\n"
+        + "FROM (SELECT\"deptno\"\n"
+        + "    FROM \"emps\") AS \"t\"\n"
+        + "JOIN (SELECT \"deptno\",\n"
+        + "        \"inceptionDate\" as \"CREATED_AT\"\n"
+        + "    FROM \"depts2\") using (\"deptno\")\n"
+        + "GROUP BY FLOOR(\"CREATED_AT\" TO YEAR)";
+    String plan = ""
+        + "EnumerableAggregate(group=[{8}], num_emps=[$SUM0($1)])\n"
+        + "  EnumerableCalc(expr#0..7=[{inputs}], expr#8=[FLAG(YEAR)], "
+        + "expr#9=[FLOOR($t3, $t8)], proj#0..7=[{exprs}], $f8=[$t9])\n"
+        + "    EnumerableHashJoin(condition=[=($0, $4)], joinType=[inner])\n"
+        + "      EnumerableTableScan(table=[[hr, MV0]])\n"
+        + "      EnumerableTableScan(table=[[hr, depts2]])\n";
+    sql(m, q)
+        .withChecker(resultContains(plan))
+        .ok();
+  }
+
   @Test void testJoinMaterialization1() {
     String q = "select *\n"
         + "from (select * from \"emps\" where \"empid\" < 300)\n"