You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2022/11/30 17:44:30 UTC

[pinot] branch master updated: [multistage][hotfix] shuffle rewrite (#9870)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 9bd420fc4a [multistage][hotfix] shuffle rewrite (#9870)
9bd420fc4a is described below

commit 9bd420fc4a53f13f19e921587cab8c9ec67cc233
Author: Rong Rong <ro...@apache.org>
AuthorDate: Wed Nov 30 09:44:23 2022 -0800

    [multistage][hotfix] shuffle rewrite (#9870)
    
    * fix incorrect shuffle rewrite for project and agg; the rewrite should never alter the partition keys.
    * fix another issue where the partition key rewrite doesn't carry over inputRef changes caused by node schema changes.
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../planner/logical/ShuffleRewriteVisitor.java     | 60 ++++++++++++----------
 .../apache/pinot/query/QueryCompilationTest.java   | 38 ++++++++++++--
 .../test/resources/queries/SelectExpressions.json  |  3 +-
 3 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
index 58adfc96d1..b881982fe8 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/ShuffleRewriteVisitor.java
@@ -18,7 +18,10 @@
  */
 package org.apache.pinot.query.planner.logical;
 
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import org.apache.calcite.rel.RelDistribution;
 import org.apache.pinot.query.planner.partitioning.FieldSelectionKeySelector;
@@ -66,20 +69,8 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
   @Override
   public Set<Integer> visitAggregate(AggregateNode node, Void context) {
     Set<Integer> oldPartitionKeys = node.getInputs().get(0).visit(this, context);
-
-    // any input reference directly carries over in group set of aggregation
-    // should still be a partition key
-    Set<Integer> partitionKeys = new HashSet<>();
-    for (int i = 0; i < node.getGroupSet().size(); i++) {
-      RexExpression rex = node.getGroupSet().get(i);
-      if (rex instanceof RexExpression.InputRef) {
-        if (oldPartitionKeys.contains(((RexExpression.InputRef) rex).getIndex())) {
-          partitionKeys.add(i);
-        }
-      }
-    }
-
-    return partitionKeys;
+    List<RexExpression> groupSet = node.getGroupSet();
+    return deriveNewPartitionKeysFromRexExpressions(groupSet, oldPartitionKeys);
   }
 
   @Override
@@ -105,11 +96,15 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
       if (leftPKs.contains(leftIdx)) {
         partitionKeys.add(leftIdx);
       }
+      // TODO: enable right key carrying. Currently we only support left key carrying b/c the partition key list
+      // doesn't understand equivalent partition key columns or group partition key columns yet.
+      /*
       if (rightPks.contains(rightIdx)) {
         // combined schema will have all the left fields before the right fields
         // so add the leftDataSchemaSize before adding the key
         partitionKeys.add(leftDataSchemaSize + rightIdx);
       }
+      */
     }
 
     return partitionKeys;
@@ -150,19 +145,7 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
   @Override
   public Set<Integer> visitProject(ProjectNode node, Void context) {
     Set<Integer> oldPartitionKeys = node.getInputs().get(0).visit(this, context);
-
-    // all inputs carry over if they're still in the projection result
-    Set<Integer> partitionKeys = new HashSet<>();
-    for (int i = 0; i < node.getProjects().size(); i++) {
-      RexExpression rex = node.getProjects().get(i);
-      if (rex instanceof RexExpression.InputRef) {
-        if (oldPartitionKeys.contains(((RexExpression.InputRef) rex).getIndex())) {
-          partitionKeys.add(i);
-        }
-      }
-    }
-
-    return partitionKeys;
+    return deriveNewPartitionKeysFromRexExpressions(node.getProjects(), oldPartitionKeys);
   }
 
   @Override
@@ -189,4 +172,27 @@ public class ShuffleRewriteVisitor implements StageNodeVisitor<Set<Integer>, Voi
     }
     return false;
   }
+
+  private static Set<Integer> deriveNewPartitionKeysFromRexExpressions(List<RexExpression> rexExpressionList,
+      Set<Integer> oldPartitionKeys) {
+    Map<Integer, Integer> partitionKeyMap = new HashMap<>();
+    for (int i = 0; i < rexExpressionList.size(); i++) {
+      RexExpression rex = rexExpressionList.get(i);
+      if (rex instanceof RexExpression.InputRef) {
+        // put the old-index to new-index mapping
+        // TODO: duplicate references are not handled, e.g. if the same old partition key is referenced twice, only
+        // the last reference is kept. (see JOIN handling on left/right as another example)
+        partitionKeyMap.put(((RexExpression.InputRef) rex).getIndex(), i);
+      }
+    }
+    if (partitionKeyMap.keySet().containsAll(oldPartitionKeys)) {
+      Set<Integer> newPartitionKeys = new HashSet<>();
+      for (int oldKey : oldPartitionKeys) {
+        newPartitionKeys.add(partitionKeyMap.get(oldKey));
+      }
+      return newPartitionKeys;
+    } else {
+      return new HashSet<>();
+    }
+  }
 }
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 7ece7fce43..c562067d2a 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -79,13 +79,39 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
   }
 
   @Test
-  public void testQueryGroupByAfterJoinShouldNotDoDataShuffle()
+  public void testQueryGroupByAfterJoinShouldProperlyRewriteShuffle()
       throws Exception {
     String query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 "
         + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col2";
     QueryPlan queryPlan = _queryEnvironment.planQuery(query);
-    Assert.assertEquals(queryPlan.getQueryStageMap().size(), 5);
-    Assert.assertEquals(queryPlan.getStageMetadataMap().size(), 5);
+    assertGroupBySingletonAfterJoin(queryPlan, true);
+
+    // same query with selection list re-ordering should also work
+    query = "SELECT a.col2, a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 "
+        + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col2, a.col1";
+    queryPlan = _queryEnvironment.planQuery(query);
+    assertGroupBySingletonAfterJoin(queryPlan, true);
+
+    // exact same group key should also work
+    query = "SELECT a.col1, a.col2, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+        + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col2";
+    queryPlan = _queryEnvironment.planQuery(query);
+    assertGroupBySingletonAfterJoin(queryPlan, true);
+
+    // shrinking group key should not rewrite into singleton
+    query = "SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+        + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1";
+    queryPlan = _queryEnvironment.planQuery(query);
+    assertGroupBySingletonAfterJoin(queryPlan, false);
+
+    // mismatched group key should not rewrite into singleton
+    query = "SELECT a.col3, a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2 AND a.col2 = b.col2"
+        + " WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1, a.col3";
+    queryPlan = _queryEnvironment.planQuery(query);
+    assertGroupBySingletonAfterJoin(queryPlan, false);
+  }
+
+  private static void assertGroupBySingletonAfterJoin(QueryPlan queryPlan, boolean shouldRewrite) throws Exception {
     for (Map.Entry<Integer, StageMetadata> e : queryPlan.getStageMetadataMap().entrySet()) {
       if (e.getValue().getScannedTables().size() == 0 && !PlannerUtils.isRootStage(e.getKey())) {
         StageNode node = queryPlan.getQueryStageMap().get(e.getKey());
@@ -101,7 +127,11 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
           if (node instanceof AggregateNode && node.getInputs().get(0) instanceof MailboxReceiveNode) {
             // AGG is exchanged with singleton since it has already been distributed by JOIN.
             MailboxReceiveNode input = (MailboxReceiveNode) node.getInputs().get(0);
-            Assert.assertEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+            if (shouldRewrite) {
+              Assert.assertEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+            } else {
+              Assert.assertNotEquals(input.getExchangeType(), RelDistribution.Type.SINGLETON);
+            }
             break;
           }
           node = node.getInputs().get(0);
diff --git a/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json b/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
index 57c0b09f1a..f214307b2a 100644
--- a/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
+++ b/pinot-query-runtime/src/test/resources/queries/SelectExpressions.json
@@ -36,7 +36,8 @@
       { "sql": "SELECT intCol as \"value\", doubleCol + floatCol AS \"sum\" FROM {tbl1}"},
       { "sql": "SELECT intCol as \"from\" FROM {tbl1}"},
       { "sql": "SELECT intCol as key, SUM(doubleCol + floatCol) AS aggSum FROM {tbl1} GROUP BY intCol"},
-      { "sql": "SELECT a.intCol as key, SUM(a.doubleCol + b.intCol) AS aggSum FROM {tbl1} AS a JOIN {tbl2} AS b ON a.intCol = b.intCol GROUP BY a.intCol"}
+      { "sql": "SELECT intCol, SUM(avgVal) FROM (SELECT strCol, intCol, AVG(doubleCol) AS avgVal FROM {tbl1} GROUP BY intCol, strCol) GROUP BY intCol"},
+      { "sql": "SELECT strCol, MAX(sumVal), MAX(sumVal + avgVal) AS transVal FROM (SELECT strCol, intCol, SUM(floatCol + 2 * intCol) AS sumVal, AVG(doubleCol) AS avgVal FROM {tbl1} GROUP BY strCol, intCol) GROUP BY strCol ORDER BY MAX(sumVal)" }
     ]
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org