You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2022/11/30 17:44:30 UTC
[pinot] branch master updated: [multistage][hotfix] shuffle rewrite (#9870)
This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 9bd420fc4a [multistage][hotfix] shuffle rewrite (#9870)
9bd420fc4a is described below
commit 9bd420fc4a53f13f19e921587cab8c9ec67cc233
Author: Rong Rong <ro...@apache.org>
AuthorDate: Wed Nov 30 09:44:23 2022 -0800
[multistage][hotfix] shuffle rewrite (#9870)
* fix incorrect shuffle rewrite for project and agg: the rewrite should never alter the partition keys.
* fix another partition key rewrite that didn't carry over inputRef index changes caused by node schema changes.
Co-authored-by: Rong Rong <ro...@startree.ai>
---
.../planner/logical/ShuffleRewriteVisitor.java | 60 ++++++++++++----------
.../apache/pinot/query/QueryCompilationTest.java | 38 ++++++++++++--
.../test/resources/queries/SelectExpressions.json | 3 +-
3 files changed, 69 insertions(+), 32 deletions(-)
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
index 58adfc96d1..b881982fe8 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
@@ -18,7 +18,10 @@
*/
package org.apache.pinot.query.planner.logical;
+import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
import java.util.Set;
import org.apache.calcite.rel.RelDistribution;
import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
@@ -66,20 +69,8 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
@Override
public Set<Integer> visitAggregate(AggregateNode node, Void context) {
Set<Integer> oldPartitionKeys = node.getInputs().get(0).visit(this, context);
-
- // any input reference directly carries over in group set of aggregation
- // should still be a partition key
- Set<Integer> partitionKeys = new HashSet<>();
- for (int i = 0; i < node.getGroupSet().size(); i++) {
- RexExpression rex = node.getGroupSet().get(i);
- if (rex instanceof RexExpression.InputRef) {
- if (oldPartitionKeys.contains(((RexExpression.InputRef) rex).getIndex())) {
- partitionKeys.add(i);
- }
- }
- }
-
- return partitionKeys;
+ List<RexExpression> groupSet = node.getGroupSet();
+ return deriveNewPartitionKeysFromRexExpressions(groupSet, oldPartitionKeys);
}
@Override
@@ -105,11 +96,15 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
if (leftPKs.contains(leftIdx)) {
partitionKeys.add(leftIdx);
}
+ // TODO: enable right key carrying. currently we only support left key carrying b/c of the partition key list
+ // doesn't understand equivalent partition key column or group partition key columns, yet.
+ /*
if (rightPks.contains(rightIdx)) {
// combined schema will have all the left fields before the right fields
// so add the leftDataSchemaSize before adding the key
partitionKeys.add(leftDataSchemaSize + rightIdx);
}
+ */
}
return partitionKeys;
@@ -150,19 +145,7 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
@Override
public Set<Integer> visitProject(ProjectNode node, Void context) {
Set<Integer> oldPartitionKeys = node.getInputs().get(0).visit(this, context);
-
- // all inputs carry over if they're still in the projection result
- Set<Integer> partitionKeys = new HashSet<>();
- for (int i = 0; i < node.getProjects().size(); i++) {
- RexExpression rex = node.getProjects().get(i);
- if (rex instanceof RexExpression.InputRef) {
- if (oldPartitionKeys.contains(((RexExpression.InputRef) rex).getIndex())) {
- partitionKeys.add(i);
- }
- }
- }
-
- return partitionKeys;
+ return deriveNewPartitionKeysFromRexExpressions(node.getProjects(), oldPartitionKeys);
}
@Override
@@ -189,4 +172,27 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
}
return false;
}
+
+ /**
+ * Re-maps the input node's partition keys onto the output schema of a node whose output columns are given by
+ * {@code rexExpressionList} (the project list of a PROJECT node, or the group set of an AGGREGATE node).
+ *
+ * <p>Only direct {@link RexExpression.InputRef} expressions carry a partition key forward. An InputRef at output
+ * position {@code i} that references input column {@code k} maps old key {@code k} to new key {@code i}.
+ *
+ * <p>The mapping is all-or-nothing:
+ * <ul>
+ * <li>If every old partition key is referenced by some InputRef, the re-mapped key set is returned.</li>
+ * <li>Otherwise an empty set is returned. For example, a group set that covers only some of the partition keys
+ * yields no partition keys, rather than a shrunken key set.</li>
+ * </ul>
+ *
+ * @param rexExpressionList output expressions of the node, in output column order
+ * @param oldPartitionKeys partition key indices in the input node's schema
+ * @return a new mutable set of partition key indices in the output schema. It is empty if any old key is not
+ * carried over, and also empty when {@code oldPartitionKeys} is empty.
+ */
+ private static Set<Integer> deriveNewPartitionKeysFromRexExpressions(List<RexExpression> rexExpressionList,
+ Set<Integer> oldPartitionKeys) {
+ // old input index -> new output index, recorded for every output column that is a plain input reference
+ Map<Integer, Integer> partitionKeyMap = new HashMap<>();
+ for (int i = 0; i < rexExpressionList.size(); i++) {
+ RexExpression rex = rexExpressionList.get(i);
+ if (rex instanceof RexExpression.InputRef) {
+ // put the old-index to new-index mapping
+ // TODO: duplicate references are not handled. If the same old partition key is referenced more than once,
+ // put() overwrites the earlier entry, so only the LAST output position is kept.
+ // (see JOIN handling on left/right as another example)
+ partitionKeyMap.put(((RexExpression.InputRef) rex).getIndex(), i);
+ }
+ }
+ // All-or-nothing: the rewrite must never alter (e.g. shrink) the partition key set.
+ if (partitionKeyMap.keySet().containsAll(oldPartitionKeys)) {
+ Set<Integer> newPartitionKeys = new HashSet<>();
+ for (int oldKey : oldPartitionKeys) {
+ // non-null: the containsAll check above guarantees every old key has a mapping
+ newPartitionKeys.add(partitionKeyMap.get(oldKey));
+ }
+ return newPartitionKeys;
+ } else {
+ // At least one old partition key is not a direct input reference in the output, so report no partition keys.
+ return new HashSet<>();
+ }
+ }
}
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 7ece7fce43..c562067d2a 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -79,13 +79,39 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
}
@Test
- public void testQueryGroupByAfterJoinShouldNotDoDataShuffle()
+ public void testQueryGroupByAfterJoinShouldProperlyRewriteShuffle()
throws Exception {
String query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 "
+ " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col2";
QueryPlan queryPlan = _queryEnvironment.planQuery(query);
- Assert.assertEquals(queryPlan.getQueryStageMap().size(), 5);
- Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 5);
+ assertGroupBySingletonAfterJoin(queryPlan, true);
+
+ // same query with selection list re-odering should also work
+ query = "SELECT a.col2, a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 "
+ + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col2, a.col1";
+ queryPlan = _queryEnvironment.planQuery(query);
+ assertGroupBySingletonAfterJoin(queryPlan, true);
+
+ // exact same group key should also work
+ query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+ + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col2";
+ queryPlan = _queryEnvironment.planQuery(query);
+ assertGroupBySingletonAfterJoin(queryPlan, true);
+
+ // shrinking group key should not rewrite into singleton
+ query = "SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+ + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1";
+ queryPlan = _queryEnvironment.planQuery(query);
+ assertGroupBySingletonAfterJoin(queryPlan, false);
+
+ // mismatched group key should not rewrite into singleton
+ query = "SELECT a.col3, a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+ + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col3";
+ queryPlan = _queryEnvironment.planQuery(query);
+ assertGroupBySingletonAfterJoin(queryPlan, false);
+ }
+
+ private static void assertGroupBySingletonAfterJoin(QueryPlan queryPlan, boolean shouldRewrite) throws Exception {
for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
if (e.getValue().getScannedTables().size() == 0 && !PlannerUtils.isRootStage(e.getKey())) {
StageNode node = queryPlan.getQueryStageMap().get(e.getKey());
@@ -101,7 +127,11 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
if (node instanceof AggregateNode && node.getInputs().get(0) instanceof MailboxReceiveNode) {
// AGG is exchanged with singleton since it has already been distributed by JOIN.
MailboxReceiveNode input = (MailboxReceiveNode) node.getInputs().get(0);
- Assert.assertEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+ if (shouldRewrite) {
+ Assert.assertEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+ } else {
+ Assert.assertNotEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+ }
break;
}
node = node.getInputs().get(0);
diff --git a/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json b/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
index 57c0b09f1a..f214307b2a 100644
--- a/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
+++ b/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
@@ -36,7 +36,8 @@
{ "sql": "SELECT intCol as \"value\", doubleCol + floatCol AS \"sum\" FROM {tbl1}"},
{ "sql": "SELECT intCol as \"from\" FROM {tbl1}"},
{ "sql": "SELECT intCol as key, SUM(doubleCol + floatCol) AS aggSum FROM {tbl1} GROUP BY intCol"},
- { "sql": "SELECT a.intCol as key, SUM(a.doubleCol + b.intCol) AS aggSum FROM {tbl1} AS a JOIN {tbl2} AS b ON a.intCol = b.intCol GROUP BY a.intCol"}
+ { "sql": "SELECT intCol, SUM(avgVal) FROM (SELECT strCol, intCol, AVG(doubleCol) AS avgVal FROM {tbl1} GROUP BY intCol, strCol) GROUP BY intCol"},
+ { "sql": "SELECT strCol, MAX(sumVal), MAX(sumVal + avgVal) AS transVal FROM (SELECT strCol, intCol, SUM(floatCol + 2 * intCol) AS sumVal, AVG(doubleCol) AS avgVal FROM {tbl1} GROUP BY strCol, intCol) GROUP BY strCol ORDER BY MAX(sumVal)" }
]
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org