You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by da...@apache.org on 2019/06/06 02:34:02 UTC
[calcite] 01/01: [CALCITE-2744] RelDecorrelator use wrong output
map for LogicalAggregate decorrelate (godfreyhe and Danny Chan)
This is an automated email from the ASF dual-hosted git repository.
danny0405 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git
commit d8f4cc4aa4f6c8b49bd76d5043e30270b7ff2546
Author: godfreyhe <go...@163.com>
AuthorDate: Tue Jun 4 16:37:54 2019 +0800
[CALCITE-2744] RelDecorrelator use wrong output map for LogicalAggregate decorrelate (godfreyhe and Danny Chan)
godfreyhe started the work by applying a new output map when registering
the decorrelated LogicalAggregate, and Danny Chan added shifts for the
constant-key mappings.
Also fix the test case name and comments.
close apache/calcite#1254
---
.../apache/calcite/sql2rel/RelDecorrelator.java | 38 ++++++--
.../apache/calcite/test/MockSqlOperatorTable.java | 24 +++++
.../org/apache/calcite/test/RelOptRulesTest.java | 38 ++++++++
.../org/apache/calcite/test/RelOptRulesTest.xml | 100 +++++++++++++++++++++
4 files changed, 195 insertions(+), 5 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
index 5e9a1c4..a11cb9e 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java
@@ -467,8 +467,11 @@ public class RelDecorrelator implements ReflectiveVisitor {
}
final RelNode newInput = frame.r;
+ // aggregate outputs mapping: group keys and aggregates
+ final Map<Integer, Integer> outputMap = new HashMap<>();
+
// map from newInput
- Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
+ final Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
final int oldGroupKeyCount = rel.getGroupSet().cardinality();
// Project projects the original expressions,
@@ -490,6 +493,9 @@ public class RelDecorrelator implements ReflectiveVisitor {
omittedConstants.put(i, constant);
continue;
}
+
+ // add mapping of group keys.
+ outputMap.put(i, newPos);
int newInputPos = frame.oldToNewOutputs.get(i);
projects.add(RexInputRef.of2(newInputPos, newInputOutput));
mapNewInputToProjOutputs.put(newInputPos, newPos);
@@ -593,7 +599,7 @@ public class RelDecorrelator implements ReflectiveVisitor {
// The old to new output position mapping will be the same as that
// of newProject, plus any aggregates that the oldAgg produces.
- combinedMap.put(
+ outputMap.put(
oldInputOutputFieldCount + i,
newInputOutputFieldCount + i);
}
@@ -605,15 +611,37 @@ public class RelDecorrelator implements ReflectiveVisitor {
final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
for (Map.Entry<Integer, RexLiteral> entry
: omittedConstants.descendingMap().entrySet()) {
- postProjects.add(entry.getKey() + frame.corDefOutputs.size(),
- entry.getValue());
+ int index = entry.getKey() + frame.corDefOutputs.size();
+ postProjects.add(index, entry.getValue());
+ // Shift the outputs whose index equals with or bigger than the added index
+ // with 1 offset.
+ shiftMapping(outputMap, index, 1);
+ // Then add the constant key mapping.
+ outputMap.put(entry.getKey(), index);
}
relBuilder.project(postProjects);
}
// Aggregate does not change input ordering so corVars will be
// located at the same position as the input newProject.
- return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
+ return register(rel, relBuilder.build(), outputMap, corDefOutputs);
+ }
+
+ /**
+ * Shift the mapping to fixed offset from the {@code startIndex}.
+ * @param mapping the original mapping
+ * @param startIndex any output whose index equals with or bigger than the starting index
+ * would be shift
+ * @param offset shift offset
+ */
+ private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex, int offset) {
+ for (Map.Entry<Integer, Integer> entry : mapping.entrySet()) {
+ if (entry.getValue() >= startIndex) {
+ mapping.put(entry.getKey(), entry.getValue() + offset);
+ } else {
+ mapping.put(entry.getKey(), entry.getValue());
+ }
+ }
}
public Frame getInvoke(RelNode r, RelNode parent) {
diff --git a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
index 78b842d..c128540 100644
--- a/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
+++ b/core/src/test/java/org/apache/calcite/test/MockSqlOperatorTable.java
@@ -18,6 +18,7 @@ package org.apache.calcite.test;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIdentifier;
@@ -27,6 +28,8 @@ import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
import org.apache.calcite.sql.util.ListSqlOperatorTable;
@@ -64,6 +67,7 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable {
opTab.addOperator(new RampFunction());
opTab.addOperator(new DedupFunction());
opTab.addOperator(new MyFunction());
+ opTab.addOperator(new MyAvgAggFunction());
}
/** "RAMP" user-defined function. */
@@ -125,6 +129,26 @@ public class MockSqlOperatorTable extends ChainedSqlOperatorTable {
return typeFactory.createSqlType(SqlTypeName.BIGINT);
}
}
+
+ /** "MY_AVG" user-defined aggregate function. */
+ public static class MyAvgAggFunction extends SqlAggFunction {
+ public MyAvgAggFunction() {
+ super("MY_AVG",
+ null,
+ SqlKind.AVG,
+ ReturnTypes.AVG_AGG_FUNCTION,
+ null,
+ OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
+ SqlFunctionCategory.NUMERIC,
+ false,
+ false);
+ }
+
+ @Override public boolean isDeterministic() {
+ return false;
+ }
+ }
+
}
// End MockSqlOperatorTable.java
diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
index d70b320..38719ce 100644
--- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
+++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
@@ -5278,6 +5278,44 @@ public class RelOptRulesTest extends RelOptTestBase {
}
/** Test case for
+ * <a href="https://issues.apache.org/jira/browse/CALCITE-2744">[CALCITE-2744]
+ * RelDecorrelator use wrong output map for LogicalAggregate decorrelate</a>. */
+ @Test public void testDecorrelateAggWithConstantGroupKey() {
+ final String sql = "SELECT * FROM emp A where sal in \n"
+ + "(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')";
+ sql(sql)
+ .withLateDecorrelation(true)
+ .withTrim(true)
+ .with(HepProgram.builder().build())
+ .check();
+ }
+
+ /** Test case for CALCITE-2744 for aggregate decorrelate with multi-param agg call
+ * but without group key. */
+ @Test public void testDecorrelateAggWithMultiParamsAggCall() {
+ final String sql = "SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m,\n"
+ + " LATERAL TABLE(ramp(m.c)) AS T(s)";
+ sql(sql)
+ .withLateDecorrelation(true)
+ .withTrim(true)
+ .with(HepProgram.builder().build())
+ .checkUnchanged();
+ }
+
+ /** Same as {@link #testDecorrelateAggWithMultiParamsAggCall}
+ * but with constant grouping key. */
+ @Test public void testDecorrelateAggWithMultiParamsAggCall2() {
+ final String sql = "SELECT * FROM "
+ + "(SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m,\n"
+ + " LATERAL TABLE(ramp(m.c)) AS T(s)";
+ sql(sql)
+ .withLateDecorrelation(true)
+ .withTrim(true)
+ .with(HepProgram.builder().build())
+ .checkUnchanged();
+ }
+
+ /** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-434">[CALCITE-434]
* Converting predicates on date dimension columns into date ranges</a>,
* specifically a rule that converts {@code EXTRACT(YEAR FROM ...) = constant}
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index a359d64..21f99d2 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -2684,6 +2684,106 @@ LogicalProject("K0"=[$0], "C1"=[$1], "F1"."A0"=[$2], "F2"."A0"=[$3], "F0"."C0"=[
]]>
</Resource>
</TestCase>
+ <TestCase name="testDecorrelateAggWithConstantGroupKey">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM emp A where sal in
+(SELECT max(sal) FROM emp B where A.mgr = B.empno group by deptno, 'abc')]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalFilter(condition=[=($cor0.SAL, $0)])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(EXPR$0=[$2])
+ LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)])
+ LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5])
+ LogicalFilter(condition=[=($cor0.MGR, $0)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planMid">
+ <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{3, 5}])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalFilter(condition=[=($cor0.SAL, $0)])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(EXPR$0=[$2])
+ LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($2)])
+ LogicalProject(DEPTNO=[$7], $f1=['abc'], SAL=[$5])
+ LogicalFilter(condition=[=($cor0.MGR, $0)])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ <Resource name="planAfter">
+ <![CDATA[
+LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], SLACKER=[$8])
+ LogicalJoin(condition=[AND(=($3, $10), =($5, $9))], joinType=[inner])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalAggregate(group=[{0, 1}])
+ LogicalProject(EXPR$0=[$2], EMPNO=[$1])
+ LogicalAggregate(group=[{0, 1}], EXPR$0=[MAX($3)])
+ LogicalProject(DEPTNO=[$7], EMPNO=[$0], $f1=['abc'], SAL=[$5])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDecorrelateAggWithMultiParamsAggCall">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp) as m,
+ LATERAL TABLE(ramp(m.c)) AS T(s)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+ LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)])
+ LogicalProject(SAL=[$5], $f1=[1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+ </Resource>
+ <Resource name="planMid">
+ <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+ LogicalAggregate(group=[{}], C=[MY_AVG($0, $1)])
+ LogicalProject(SAL=[$5], $f1=[1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDecorrelateAggWithMultiParamsAggCall2">
+ <Resource name="sql">
+ <![CDATA[SELECT * FROM (SELECT MY_AVG(sal, 1) AS c FROM emp group by empno, 'abc') as m,
+ LATERAL TABLE(ramp(m.c)) AS T(s)]]>
+ </Resource>
+ <Resource name="planBefore">
+ <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+ LogicalProject(C=[$2])
+ LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)])
+ LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+ </Resource>
+ <Resource name="planMid">
+ <![CDATA[
+LogicalProject(C=[$0], S=[$1])
+ LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
+ LogicalProject(C=[$2])
+ LogicalAggregate(group=[{0, 1}], C=[MY_AVG($2, $3)])
+ LogicalProject(EMPNO=[$0], $f1=['abc'], SAL=[$5], $f3=[1])
+ LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ LogicalTableFunctionScan(invocation=[RAMP($cor0.C)], rowType=[RecordType(INTEGER I)])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testExtractYearMonthToRange">
<Resource name="sql">
<![CDATA[select *