You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by da...@apache.org on 2019/06/06 02:34:02 UTC

[calcite] 01/01: [CALCITE-2744] RelDecorrelator use wrong output map for LogicalAggregate decorrelate (godfreyhe and Danny Chan)

This is an automated email from the ASF dual-hosted git repository.

danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit d8f4cc4aa4f6c8b49bd76d5043e30270b7ff2546
Author: godfreyhe <go...@163.com>
AuthorDate: Tue Jun 4 16:37:54 2019 +0800

    [CALCITE-2744] RelDecorrelator use wrong output map for LogicalAggregate decorrelate (godfreyhe and Danny Chan)
    
    godfreyhe started the work by applying a new output map when
    registering the decorrelated LogicalAggregate, and Danny Chan added
    shifts to the mapping for constant group keys.
    
    Also fix the test case name and comments.
    
    close apache/calcite#1254
---
 .../apache/calcite/sql2rel/RelDecorrelator.java    |  38 ++++++--
 .../apache/calcite/test/MockSqlOperatorTable.java  |  24 +++++
 .../org/apache/calcite/test/RelOptRulesTest.java   |  38 ++++++++
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 100 +++++++++++++++++++++
 4 files changed, 195 insertions(+), 5 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 5e9a1c4..a11cb9e 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -467,8 +467,11 @@ public class RelDecorrelator implements ReflectiveVisitor {
     }
     final RelNode newInput = frame.r;
 
+    // aggregate outputs mapping: group keys and aggregates
+    final Map<Integer, Integer> outputMap = new HashMap<>();
+
     // map from newInput
-    Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
+    final Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
     final int oldGroupKeyCount = rel.getGroupSet().cardinality();
 
     // Project projects the original expressions,
@@ -490,6 +493,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
         omittedConstants.put(i, constant);
         continue;
       }
+
+      // add mapping of group keys.
+      outputMap.put(i, newPos);
       int newInputPos = frame.oldToNewOutputs.get(i);
       projects.add(RexInputRef.of2(newInputPos, newInputOutput));
       mapNewInputToProjOutputs.put(newInputPos, newPos);
@@ -593,7 +599,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
 
       // The old to new output position mapping will be the same as that
       // of newProject, plus any aggregates that the oldAgg produces.
-      combinedMap.put(
+      outputMap.put(
           oldInputOutputFieldCount + i,
           newInputOutputFieldCount + i);
     }
@@ -605,15 +611,37 @@ public class RelDecorrelator implements ReflectiveVisitor {
       final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
       for (Map.Entry<Integer, RexLiteral> entry
           : omittedConstants.descendingMap().entrySet()) {
-        postProjects.add(entry.getKey() + frame.corDefOutputs.size(),
-            entry.getValue());
+        int index = entry.getKey() + frame.corDefOutputs.size();
+        postProjects.add(index, entry.getValue());
+        // Shift the outputs whose index equals with or bigger than the added index
+        // with 1 offset.
+        shiftMapping(outputMap, index, 1);
+        // Then add the constant key mapping.
+        outputMap.put(entry.getKey(), index);
       }
       relBuilder.project(postProjects);
     }
 
     // Aggregate does not change input ordering so corVars will be
     // located at the same position as the input newProject.
-    return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
+    return register(rel, relBuilder.build(), outputMap, corDefOutputs);
+  }
+
+  /**
+   * Shifts, in place, every value of {@code mapping} that is greater than or
+   * equal to {@code startIndex} by {@code offset}; smaller values are kept.
+   * @param mapping    the mapping to update in place
+   * @param startIndex values greater than or equal to this index are shifted
+   * @param offset     amount added to each shifted value
+   */
+  private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex, int offset) {
+    for (Map.Entry<Integer, Integer> entry : mapping.entrySet()) {
+      // Update through the entry: no re-hash, and no structural change mid-iteration.
+      final int value = entry.getValue();
+      if (value >= startIndex) {
+        entry.setValue(value + offset);
+      }
+    }
   }
 
   public Frame getInvoke(RelNode r, RelNode parent) {
diff --git a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
index 78b842d..c128540 100644
--- a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
+++ b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
@@ -18,6 +18,7 @@ package org.apache.calcite.test;
 
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlFunction;
 import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlIdentifier;
@@ -27,6 +28,8 @@ import org.apache.calcite.sql.SqlOperatorBinding;
 import org.apache.calcite.sql.SqlOperatorTable;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
 import org.apache.calcite.sql.util.ListSqlOperatorTable;
@@ -64,6 +67,7 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable {
     opTab.addOperator(new RampFunction());
     opTab.addOperator(new DedupFunction());
     opTab.addOperator(new MyFunction());
+    opTab.addOperator(new MyAvgAggFunction());
   }
 
   /** "RAMP" user-defined function. */
@@ -125,6 +129,26 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable {
       return typeFactory.createSqlType(SqlTypeName.BIGINT);
     }
   }
+
+  /** "MY_AVG" user-defined aggregate function taking two NUMERIC operands. */
+  public static class MyAvgAggFunction extends SqlAggFunction {
+    public MyAvgAggFunction() {
+      super("MY_AVG",
+          null,
+          SqlKind.AVG,
+          ReturnTypes.AVG_AGG_FUNCTION,
+          null,
+          OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
+          SqlFunctionCategory.NUMERIC,
+          false,
+          false);
+    }
+
+    @Override public boolean isDeterministic() {
+      return false; // NOTE(review): presumably deliberate for these tests; confirm
+    }
+  }
+
 }
 
 // End MockSqlOperatorTable.java
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d70b320..38719ce 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -5278,6 +5278,44 @@ public class RelOptRulesTest extends RelOptTestBase {
   }
 
   /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-2744">[CALCITE-2744]
+   * RelDecorrelator use wrong output map for LogicalAggregate decorrelate</a>. */
+  @Test public void testDecorrelateAggWithConstantGroupKey() {
+    final String sql = "SELECT * FROM emp A where sal in \n"
+        + "(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')";
+    sql(sql)
+        .withLateDecorrelation(true)
+        .withTrim(true)
+        .with(HepProgram.builder().build())
+        .check();
+  }
+
+  /** Test case for CALCITE-2744: decorrelating an aggregate that has a
+   * multi-parameter aggregate call and no group key. */
+  @Test public void testDecorrelateAggWithMultiParamsAggCall() {
+    final String sql = "SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m,\n"
+        + " LATERAL TABLE(ramp(m.c)) AS T(s)";
+    sql(sql)
+        .withLateDecorrelation(true)
+        .withTrim(true)
+        .with(HepProgram.builder().build())
+        .checkUnchanged();
+  }
+
+  /** Same as {@link #testDecorrelateAggWithMultiParamsAggCall}
+   * but with a constant grouping key. */
+  @Test public void testDecorrelateAggWithMultiParamsAggCall2() {
+    final String sql = "SELECT * FROM "
+        + "(SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m,\n"
+        + " LATERAL TABLE(ramp(m.c)) AS T(s)";
+    sql(sql)
+        .withLateDecorrelation(true)
+        .withTrim(true)
+        .with(HepProgram.builder().build())
+        .checkUnchanged();
+  }
+
+  /** Test case for
    * <a href="https://issues.apache.org/jira/browse/CALCITE-434">[CALCITE-434]
    * Converting predicates on date dimension columns into date ranges</a>,
    * specifically a rule that converts {@code EXTRACT(YEAR FROM ...) = constant}
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index a359d64..21f99d2 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2684,6 +2684,106 @@ LogicalProject("K0"=[$0], "C1"=[$1], "F1"."A0"=[$2], "F2"."A0"=[$3], "F0"."C0"=[
 ]]>
         </Resource>
     </TestCase>
+    <TestCase name="testDecorrelateAggWithConstantGroupKey">
+        <Resource name="sql">
+            <![CDATA[SELECT * FROM emp A where sal in
+(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalFilter(condition=[=($cor0.SAL, $0)])
+      LogicalAggregate(group=[{0}])
+        LogicalProject(EXPR$0=[$2])
+          LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)])
+            LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5])
+              LogicalFilter(condition=[=($cor0.MGR, $0)])
+                LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planMid">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalFilter(condition=[=($cor0.SAL, $0)])
+      LogicalAggregate(group=[{0}])
+        LogicalProject(EXPR$0=[$2])
+          LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)])
+            LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5])
+              LogicalFilter(condition=[=($cor0.MGR, $0)])
+                LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+        <Resource name="planAfter">
+            <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+  LogicalJoin(condition=[AND(=($3, $10), =($5, $9))], joinType=[inner])
+    LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalAggregate(group=[{0, 1}])
+      LogicalProject(EXPR$0=[$2], EMPNO=[$1])
+        LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($3)])
+          LogicalProject(DEPTNO=[$7], EMPNO=[$0], $f1=['abc'], SAL=[$5])
+            LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testDecorrelateAggWithMultiParamsAggCall">
+        <Resource name="sql">
+            <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m,
+ LATERAL TABLE(ramp(m.c)) AS T(s)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+    LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)])
+      LogicalProject(SAL=[$5], $f1=[1])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+        </Resource>
+        <Resource name="planMid">
+            <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+    LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)])
+      LogicalProject(SAL=[$5], $f1=[1])
+        LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+        </Resource>
+    </TestCase>
+    <TestCase name="testDecorrelateAggWithMultiParamsAggCall2">
+        <Resource name="sql">
+            <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m,
+ LATERAL TABLE(ramp(m.c)) AS T(s)]]>
+        </Resource>
+        <Resource name="planBefore">
+            <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+    LogicalProject(C=[$2])
+      LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)])
+        LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+        </Resource>
+        <Resource name="planMid">
+            <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+  LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+    LogicalProject(C=[$2])
+      LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)])
+        LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1])
+          LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+        </Resource>
+    </TestCase>
     <TestCase name="testExtractYearMonthToRange">
         <Resource name="sql">
             <![CDATA[select *