You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by si...@apache.org on 2023/02/02 00:49:44 UTC

[pinot] branch master updated: [multistage] add singleton instance stage (#10211)

This is an automated email from the ASF dual-hosted git repository.

siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f3e1dcc47 [multistage] add singleton instance stage (#10211)
5f3e1dcc47 is described below

commit 5f3e1dcc477798bcda1f6673d4e61bd3958c2710
Author: Rong Rong <ro...@apache.org>
AuthorDate: Wed Feb 1 16:49:37 2023 -0800

    [multistage] add singleton instance stage (#10211)
    
    * adding singleton-only stage concept
    
    * adding fixes to the rules that determine whether a singleton instance is needed for the intermediate vs. final stage of a RelNode
    
    ---------
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../rel/rules/PinotJoinExchangeNodeInsertRule.java |  2 +-
 .../calcite/rel/rules/PinotQueryRuleSets.java      | 11 ++++---
 .../rel/rules/PinotSortExchangeCopyRule.java       | 10 ++++--
 .../apache/pinot/query/planner/PlannerUtils.java   |  4 +++
 .../apache/pinot/query/planner/StageMetadata.java  | 20 ++++++++++++
 .../query/planner/hints/PinotRelationalHints.java  |  2 --
 .../query/planner/logical/RelToStageConverter.java |  2 +-
 .../query/planner/logical/RexExpressionUtils.java  |  2 +-
 .../pinot/query/planner/stage/AggregateNode.java   | 11 ++++++-
 .../apache/pinot/query/routing/WorkerManager.java  | 36 ++++++++++++++++------
 .../pinot/query/runtime/operator/SortOperator.java |  2 +-
 .../src/test/resources/queries/Parallelism.json    |  6 +---
 12 files changed, 80 insertions(+), 28 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
index 03e8d842bd..2253913927 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotJoinExchangeNodeInsertRule.java
@@ -65,7 +65,7 @@ public class PinotJoinExchangeNodeInsertRule extends RelOptRule {
 
     if (joinInfo.leftKeys.isEmpty()) {
       // when there's no JOIN key, use broadcast.
-      leftExchange = LogicalExchange.create(leftInput, RelDistributions.SINGLETON);
+      leftExchange = LogicalExchange.create(leftInput, RelDistributions.RANDOM_DISTRIBUTED);
       rightExchange = LogicalExchange.create(rightInput, RelDistributions.BROADCAST_DISTRIBUTED);
     } else {
       // when join key exists, use hash distribution.
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 43ec25dfba..1a3589e1c9 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -59,10 +59,6 @@ public class PinotQueryRuleSets {
           CoreRules.PROJECT_MERGE,
           // remove identity project
           CoreRules.PROJECT_REMOVE,
-          // add an extra exchange for sort
-          PinotSortExchangeNodeInsertRule.INSTANCE,
-          // copy exchanges down
-          PinotSortExchangeCopyRule.SORT_EXCHANGE_COPY,
           // reorder sort and projection
           CoreRules.SORT_PROJECT_TRANSPOSE,
 
@@ -95,6 +91,13 @@ public class PinotQueryRuleSets {
 
           // Pinot specific rules
           PinotFilterExpandSearchRule.INSTANCE,
+
+          // Pinot exchange rules
+          // add an extra exchange for sort
+          PinotSortExchangeNodeInsertRule.INSTANCE,
+          // copy exchanges down; this must be done after PinotSortExchangeNodeInsertRule
+          PinotSortExchangeCopyRule.SORT_EXCHANGE_COPY,
+
           PinotJoinExchangeNodeInsertRule.INSTANCE,
           PinotAggregateExchangeNodeInsertRule.INSTANCE
       );
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotSortExchangeCopyRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotSortExchangeCopyRule.java
index 7f163ca99b..3a15679fd7 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotSortExchangeCopyRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotSortExchangeCopyRule.java
@@ -30,6 +30,7 @@ import org.apache.calcite.rel.logical.LogicalSortExchange;
 import org.apache.calcite.rel.metadata.RelMdUtil;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
 import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.pinot.query.planner.logical.RexExpressionUtils;
@@ -42,6 +43,9 @@ public class PinotSortExchangeCopyRule extends RelRule<RelRule.Config> {
   public static final PinotSortExchangeCopyRule SORT_EXCHANGE_COPY =
       PinotSortExchangeCopyRule.Config.DEFAULT.toRule();
   private static final TypeFactory TYPE_FACTORY = new TypeFactory(new TypeSystem());
+  private static final RexBuilder REX_BUILDER = new RexBuilder(TYPE_FACTORY);
+  private static final RexLiteral REX_ZERO = REX_BUILDER.makeLiteral(0,
+      TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER));
 
   /**
    * Creates a PinotSortExchangeCopyRule.
@@ -80,14 +84,14 @@ public class PinotSortExchangeCopyRule extends RelRule<RelRule.Config> {
     } else if (sort.offset == null) {
       fetch = sort.fetch;
     } else {
-      RexBuilder rexBuilder = new RexBuilder(TYPE_FACTORY);
       int total = RexExpressionUtils.getValueAsInt(sort.fetch) + RexExpressionUtils.getValueAsInt(sort.offset);
-      fetch = rexBuilder.makeLiteral(total, TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER));
+      fetch = REX_BUILDER.makeLiteral(total, TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER));
     }
 
     final RelNode newExchangeInput = sort.copy(sort.getTraitSet(), exchange.getInput(), collation, null, fetch);
     final RelNode exchangeCopy = exchange.copy(exchange.getTraitSet(), newExchangeInput, exchange.getDistribution());
-    final RelNode sortCopy = sort.copy(sort.getTraitSet(), exchangeCopy, collation, sort.offset, sort.fetch);
+    final RelNode sortCopy = sort.copy(sort.getTraitSet(), exchangeCopy, collation,
+        sort.offset == null ? REX_ZERO : sort.offset, sort.fetch);
 
     call.transformTo(sortCopy);
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlannerUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlannerUtils.java
index 8cf8115a08..c3ce9fc116 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlannerUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/PlannerUtils.java
@@ -40,6 +40,10 @@ public class PlannerUtils {
     return stageId == 0;
   }
 
+  public static boolean isFinalStage(int stageId) {
+    return stageId == 1;
+  }
+
   public static String explainPlan(RelNode relRoot, SqlExplainFormat format, SqlExplainLevel explainLevel) {
     return RelOptUtil.dumpPlan("Execution Plan", relRoot, format, explainLevel);
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
index 225599e098..8ac6743f84 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/StageMetadata.java
@@ -25,6 +25,9 @@ import java.util.List;
 import java.util.Map;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.core.transport.ServerInstance;
+import org.apache.pinot.query.planner.hints.PinotRelationalHints;
+import org.apache.pinot.query.planner.stage.AggregateNode;
+import org.apache.pinot.query.planner.stage.SortNode;
 import org.apache.pinot.query.planner.stage.StageNode;
 import org.apache.pinot.query.planner.stage.TableScanNode;
 import org.apache.pinot.query.routing.VirtualServer;
@@ -54,18 +57,31 @@ public class StageMetadata implements Serializable {
   // time boundary info
   private TimeBoundaryInfo _timeBoundaryInfo;
 
+  // whether a stage requires singleton instance to execute, e.g. stage contains global reduce (sort/agg) operator.
+  private boolean _requiresSingletonInstance;
 
   public StageMetadata() {
     _scannedTables = new ArrayList<>();
     _serverInstances = new ArrayList<>();
     _serverInstanceToSegmentsMap = new HashMap<>();
     _timeBoundaryInfo = null;
+    _requiresSingletonInstance = false;
   }
 
   public void attach(StageNode stageNode) {
     if (stageNode instanceof TableScanNode) {
       _scannedTables.add(((TableScanNode) stageNode).getTableName());
     }
+    if (stageNode instanceof AggregateNode) {
+      AggregateNode aggNode = (AggregateNode) stageNode;
+      _requiresSingletonInstance = _requiresSingletonInstance || (aggNode.getGroupSet().size() == 0
+          && aggNode.getRelHints().contains(PinotRelationalHints.AGG_INTERMEDIATE_STAGE));
+    }
+    if (stageNode instanceof SortNode) {
+      SortNode sortNode = (SortNode) stageNode;
+      _requiresSingletonInstance = _requiresSingletonInstance || (sortNode.getCollationKeys().size() > 0
+          && sortNode.getOffset() != -1);
+    }
   }
 
   public List<String> getScannedTables() {
@@ -97,6 +113,10 @@ public class StageMetadata implements Serializable {
     return _timeBoundaryInfo;
   }
 
+  public boolean isRequiresSingletonInstance() {
+    return _requiresSingletonInstance;
+  }
+
   public void setTimeBoundaryInfo(TimeBoundaryInfo timeBoundaryInfo) {
     _timeBoundaryInfo = timeBoundaryInfo;
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
index 2c4cb976a6..a479834650 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/hints/PinotRelationalHints.java
@@ -25,8 +25,6 @@ import org.apache.calcite.rel.hint.RelHint;
  * Provide certain relational hint to query planner for better optimization.
  */
 public class PinotRelationalHints {
-  public static final RelHint USE_HASH_DISTRIBUTE = RelHint.builder("USE_HASH_DISTRIBUTE").build();
-  public static final RelHint USE_BROADCAST_DISTRIBUTE = RelHint.builder("USE_BROADCAST_DISTRIBUTE").build();
   public static final RelHint AGG_INTERMEDIATE_STAGE = RelHint.builder("AGG_INTERMEDIATE_STAGE").build();
   public static final RelHint AGG_LEAF_STAGE = RelHint.builder("AGG_LEAF_STAGE").build();
 
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
index 80218c6442..cede6a38e3 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToStageConverter.java
@@ -97,7 +97,7 @@ public final class RelToStageConverter {
 
   private static StageNode convertLogicalAggregate(LogicalAggregate node, int currentStageId) {
     return new AggregateNode(currentStageId, toDataSchema(node.getRowType()), node.getAggCallList(),
-        node.getGroupSet());
+        node.getGroupSet(), node.getHints());
   }
 
   private static StageNode convertLogicalProject(LogicalProject node, int currentStageId) {
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
index 364da7c164..c5a47fe789 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpressionUtils.java
@@ -91,7 +91,7 @@ public class RexExpressionUtils {
 
   public static Integer getValueAsInt(RexNode in) {
     if (in == null) {
-      return 0;
+      return -1;
     }
 
     Preconditions.checkArgument(in instanceof RexLiteral, "expected literal, got " + in);
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
index ea8dc2c1c1..251b7986d3 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/stage/AggregateNode.java
@@ -22,6 +22,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
@@ -29,6 +30,8 @@ import org.apache.pinot.query.planner.serde.ProtoProperties;
 
 
 public class AggregateNode extends AbstractStageNode {
+
+  private List<RelHint> _relHints;
   @ProtoProperties
   private List<RexExpression> _aggCalls;
   @ProtoProperties
@@ -38,13 +41,15 @@ public class AggregateNode extends AbstractStageNode {
     super(stageId);
   }
 
-  public AggregateNode(int stageId, DataSchema dataSchema, List<AggregateCall> aggCalls, ImmutableBitSet groupSet) {
+  public AggregateNode(int stageId, DataSchema dataSchema, List<AggregateCall> aggCalls, ImmutableBitSet groupSet,
+      List<RelHint> relHints) {
     super(stageId, dataSchema);
     _aggCalls = aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
     _groupSet = new ArrayList<>(groupSet.cardinality());
     for (Integer integer : groupSet) {
       _groupSet.add(new RexExpression.InputRef(integer));
     }
+    _relHints = relHints;
   }
 
   public List<RexExpression> getAggCalls() {
@@ -55,6 +60,10 @@ public class AggregateNode extends AbstractStageNode {
     return _groupSet;
   }
 
+  public List<RelHint> getRelHints() {
+    return _relHints;
+  }
+
   @Override
   public String explain() {
     return "AGGREGATE";
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
index 60930b6dce..c1f896c86d 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/routing/WorkerManager.java
@@ -25,6 +25,7 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Random;
 import java.util.stream.Collectors;
 import org.apache.pinot.core.routing.RoutingManager;
 import org.apache.pinot.core.routing.RoutingTable;
@@ -48,6 +49,7 @@ import org.apache.pinot.sql.parsers.CalciteSqlCompiler;
  * the worker manager later when we split out the query-spi layer.
  */
 public class WorkerManager {
+  private static final Random RANDOM = new Random();
 
   private final String _hostName;
   private final int _port;
@@ -63,6 +65,7 @@ public class WorkerManager {
       Map<String, String> options) {
     List<String> scannedTables = stageMetadata.getScannedTables();
     if (scannedTables.size() == 1) {
+      // --- LEAF STAGE ---
       // table scan stage, need to attach server as well as segment info for each physical table type.
       String logicalTableName = scannedTables.get(0);
       Map<String, RoutingTable> routingTableMap = getRoutingTable(logicalTableName, requestId);
@@ -102,30 +105,45 @@ public class WorkerManager {
               .collect(Collectors.toList())));
       stageMetadata.setServerInstanceToSegmentsMap(serverInstanceToSegmentsMap);
     } else if (PlannerUtils.isRootStage(stageId)) {
+      // --- ROOT STAGE / BROKER REDUCE STAGE ---
       // ROOT stage doesn't have a QueryServer as it is strictly only reducing results.
       // here we simply assign the worker instance with identical server/mailbox port number.
       stageMetadata.setServerInstances(Lists.newArrayList(
           new VirtualServer(new WorkerInstance(_hostName, _port, _port, _port, _port), 0)));
     } else {
-      stageMetadata.setServerInstances(assignServers(_routingManager.getEnabledServerInstanceMap().values(), options));
+      // --- INTERMEDIATE STAGES ---
+      // TODO: actually make assignment strategy decisions for intermediate stages
+      stageMetadata.setServerInstances(assignServers(_routingManager.getEnabledServerInstanceMap().values(),
+          stageMetadata.isRequiresSingletonInstance(), options));
     }
   }
 
-  private static List<VirtualServer> assignServers(Collection<ServerInstance> servers, Map<String, String> options) {
+  private static List<VirtualServer> assignServers(Collection<ServerInstance> servers,
+      boolean requiresSingletonInstance, Map<String, String> options) {
     int stageParallelism = Integer.parseInt(
         options.getOrDefault(CommonConstants.Broker.Request.QueryOptionKey.STAGE_PARALLELISM, "1"));
 
     List<VirtualServer> serverInstances = new ArrayList<>();
+    int idx = 0;
+    int matchingIdx = -1;
+    if (requiresSingletonInstance) {
+      matchingIdx = RANDOM.nextInt(servers.size());
+    }
     for (ServerInstance server : servers) {
-      String hostname = server.getHostname();
-      if (server.getQueryServicePort() > 0 && server.getQueryMailboxPort() > 0
-          && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_BROKER_INSTANCE)
-          && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_CONTROLLER_INSTANCE)
-          && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_MINION_INSTANCE)) {
-        for (int virtualId = 0; virtualId < stageParallelism; virtualId++) {
-          serverInstances.add(new VirtualServer(server, virtualId));
+      if (matchingIdx == -1 || idx == matchingIdx) {
+        String hostname = server.getHostname();
+        if (server.getQueryServicePort() > 0 && server.getQueryMailboxPort() > 0
+            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_BROKER_INSTANCE)
+            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_CONTROLLER_INSTANCE)
+            && !hostname.startsWith(CommonConstants.Helix.PREFIX_OF_MINION_INSTANCE)) {
+          for (int virtualId = 0; virtualId < stageParallelism; virtualId++) {
+            if (matchingIdx == -1 || virtualId == 0) {
+              serverInstances.add(new VirtualServer(server, virtualId));
+            }
+          }
         }
       }
+      idx++;
     }
     return serverInstances;
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
index 1dee1e60c5..cde9bba3da 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/SortOperator.java
@@ -65,7 +65,7 @@ public class SortOperator extends MultiStageOperator {
       int maxHolderCapacity, long requestId, int stageId) {
     _upstreamOperator = upstreamOperator;
     _fetch = fetch;
-    _offset = offset;
+    _offset = Math.max(offset, 0);
     _dataSchema = dataSchema;
     _upstreamErrorBlock = null;
     _isSortedBlockConstructed = false;
diff --git a/pinot-query-runtime/src/test/resources/queries/Parallelism.json b/pinot-query-runtime/src/test/resources/queries/Parallelism.json
index 5f5c3f6f55..e6e2db96ca 100644
--- a/pinot-query-runtime/src/test/resources/queries/Parallelism.json
+++ b/pinot-query-runtime/src/test/resources/queries/Parallelism.json
@@ -37,11 +37,7 @@
       {"sql": "SET stageParallelism=2; SELECT {l}.key, {l}.lval, {r}.rval FROM {l} JOIN {r} ON {l}.key = {r}.key"},
       {"sql": "SET stageParallelism=2; SELECT {l}.key, SUM({l}.lval + {r}.rval) FROM {l} JOIN {r} ON {l}.key = {r}.key GROUP BY {l}.key"},
       {"sql": "SET stageParallelism=2; SELECT * FROM {l} WHERE lval NOT IN (SELECT rval FROM {r} WHERE rval > 2)"},
-      {
-        "description": "current stage parallelism doesn't work with broadcast join",
-        "sql": "SET stageParallelism=2; SELECT * FROM {l}, {r}",
-        "expectedException": ".*Cannot issue query with stageParallelism > 1 for queries that use SINGLETON exchange.*"
-      }
+      {"sql": "SET stageParallelism=2; SELECT * FROM {l}, {r}"}
     ]
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org