You are viewing a plain-text version of this content. The canonical link to the original message is not preserved in this copy.
Posted to commits@pinot.apache.org by ja...@apache.org on 2024/02/04 19:20:12 UTC

(pinot) branch master updated: [Multi-stage] Ser/de stage plan in parallel (#12363)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a5a134fb5 [Multi-stage] Ser/de stage plan in parallel (#12363)
0a5a134fb5 is described below

commit 0a5a134fb57fa24450fb3d385cc4f3756a32b343
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Sun Feb 4 11:20:06 2024 -0800

    [Multi-stage] Ser/de stage plan in parallel (#12363)
---
 .../query/runtime/plan/DistributedStagePlan.java   | 24 ++------
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    | 30 ++++------
 .../query/service/dispatch/QueryDispatcher.java    | 69 +++++++++++-----------
 .../pinot/query/service/server/QueryServer.java    | 68 +++++++++++++--------
 .../service/dispatch/QueryDispatcherTest.java      | 13 ++--
 5 files changed, 97 insertions(+), 107 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
index 2aa269e6aa..62e8d19254 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/DistributedStagePlan.java
@@ -32,14 +32,10 @@ import org.apache.pinot.query.routing.WorkerMetadata;
  * <p>It is also the extended version of the {@link org.apache.pinot.core.query.request.ServerQueryRequest}.
  */
 public class DistributedStagePlan {
-  private int _stageId;
-  private VirtualServerAddress _server;
-  private PlanNode _stageRoot;
-  private StageMetadata _stageMetadata;
-
-  public DistributedStagePlan(int stageId) {
-    _stageId = stageId;
-  }
+  private final int _stageId;
+  private final VirtualServerAddress _server;
+  private final PlanNode _stageRoot;
+  private final StageMetadata _stageMetadata;
 
   public DistributedStagePlan(int stageId, VirtualServerAddress server, PlanNode stageRoot,
       StageMetadata stageMetadata) {
@@ -65,18 +61,6 @@ public class DistributedStagePlan {
     return _stageMetadata;
   }
 
-  public void setServer(VirtualServerAddress serverAddress) {
-    _server = serverAddress;
-  }
-
-  public void setStageRoot(PlanNode stageRoot) {
-    _stageRoot = stageRoot;
-  }
-
-  public void setStageMetadata(StageMetadata stageMetadata) {
-    _stageMetadata = stageMetadata;
-  }
-
   public WorkerMetadata getCurrentWorkerMetadata() {
     return _stageMetadata.getWorkerMetadataList().get(_server.workerId());
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index 91bbcc2010..f4b34a145a 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -48,14 +48,6 @@ public class QueryPlanSerDeUtils {
     // do not instantiate.
   }
 
-  public static List<DistributedStagePlan> deserializeStagePlan(Worker.QueryRequest request) {
-    List<DistributedStagePlan> distributedStagePlans = new ArrayList<>();
-    for (Worker.StagePlan stagePlan : request.getStagePlanList()) {
-      distributedStagePlans.addAll(deserializeStagePlan(stagePlan));
-    }
-    return distributedStagePlans;
-  }
-
   public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
     Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
     if (!matcher.matches()) {
@@ -73,21 +65,21 @@ public class QueryPlanSerDeUtils {
     return String.format("%s@%s:%s", serverAddress.workerId(), serverAddress.hostname(), serverAddress.port());
   }
 
-  private static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
-    List<DistributedStagePlan> distributedStagePlans = new ArrayList<>();
-    String serverAddress = stagePlan.getStageMetadata().getServerAddress();
+  public static List<DistributedStagePlan> deserializeStagePlan(Worker.StagePlan stagePlan) {
+    int stageId = stagePlan.getStageId();
+    Worker.StageMetadata protoStageMetadata = stagePlan.getStageMetadata();
+    String serverAddress = protoStageMetadata.getServerAddress();
     String[] hostPort = StringUtils.split(serverAddress, ':');
     String hostname = hostPort[0];
     int port = Integer.parseInt(hostPort[1]);
     AbstractPlanNode stageRoot = StageNodeSerDeUtils.deserializeStageNode(stagePlan.getStageRoot());
-    StageMetadata stageMetadata = fromProtoStageMetadata(stagePlan.getStageMetadata());
-    for (int workerId : stagePlan.getStageMetadata().getWorkerIdsList()) {
-      DistributedStagePlan distributedStagePlan = new DistributedStagePlan(stagePlan.getStageId());
-      VirtualServerAddress virtualServerAddress = new VirtualServerAddress(hostname, port, workerId);
-      distributedStagePlan.setServer(virtualServerAddress);
-      distributedStagePlan.setStageRoot(stageRoot);
-      distributedStagePlan.setStageMetadata(stageMetadata);
-      distributedStagePlans.add(distributedStagePlan);
+    StageMetadata stageMetadata = fromProtoStageMetadata(protoStageMetadata);
+    List<Integer> workerIds = protoStageMetadata.getWorkerIdsList();
+    List<DistributedStagePlan> distributedStagePlans = new ArrayList<>(workerIds.size());
+    for (int workerId : workerIds) {
+      distributedStagePlans.add(
+          new DistributedStagePlan(stageId, new VirtualServerAddress(hostname, port, workerId), stageRoot,
+              stageMetadata));
     }
     return distributedStagePlans;
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 2029e31a6f..8336d9aa27 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -29,6 +29,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -111,22 +112,39 @@ public class QueryDispatcher {
       throws Exception {
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
     List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
-    int numStages = stagePlans.size();
-    Set<QueryServerInstance> serverInstances = new HashSet<>();
-    // TODO: If serialization is slow, consider serializing each stage in parallel
-    StageInfo[] stageInfoMap = new StageInfo[numStages];
     // Ignore the reduce stage (stage 0)
-    for (int stageId = 1; stageId < numStages; stageId++) {
-      DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+    int numStages = stagePlans.size() - 1;
+    Set<QueryServerInstance> serverInstances = new HashSet<>();
+    // Serialize the stage plans in parallel
+    Plan.StageNode[] stageRootNodes = new Plan.StageNode[numStages];
+    //noinspection unchecked
+    List<Worker.WorkerMetadata>[] stageWorkerMetadataLists = new List[numStages];
+    CompletableFuture<?>[] stagePlanSerializationStubs = new CompletableFuture[2 * numStages];
+    for (int i = 0; i < numStages; i++) {
+      DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
       serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
-      Plan.StageNode rootNode =
-          StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
-      List<Worker.WorkerMetadata> workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
-      stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
+      int finalI = i;
+      stagePlanSerializationStubs[2 * i] = CompletableFuture.runAsync(() -> stageRootNodes[finalI] =
+              StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot()),
+          _executorService);
+      stagePlanSerializationStubs[2 * i + 1] = CompletableFuture.runAsync(
+          () -> stageWorkerMetadataLists[finalI] = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan),
+          _executorService);
+    }
+    try {
+      CompletableFuture.allOf(stagePlanSerializationStubs)
+          .get(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
+    } finally {
+      for (CompletableFuture<?> future : stagePlanSerializationStubs) {
+        if (!future.isDone()) {
+          future.cancel(true);
+        }
+      }
     }
     Map<String, String> requestMetadata = new HashMap<>();
     requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
-    requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
+    requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS,
+        Long.toString(deadline.timeRemaining(TimeUnit.MILLISECONDS)));
     requestMetadata.putAll(queryOptions);
 
     // Submit the query plan to all servers in parallel
@@ -136,17 +154,13 @@ public class QueryDispatcher {
       _executorService.submit(() -> {
         try {
           Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
-          for (int stageId = 1; stageId < numStages; stageId++) {
-            List<Integer> workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
+          for (int i = 0; i < numStages; i++) {
+            DispatchablePlanFragment stagePlan = stagePlans.get(i + 1);
+            List<Integer> workerIds = stagePlan.getServerInstanceToWorkerIdMap().get(serverInstance);
             if (workerIds != null) {
-              StageInfo stageInfo = stageInfoMap[stageId];
-              Worker.StageMetadata stageMetadata =
-                  QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
-                      serverInstance, workerIds);
-              Worker.StagePlan stagePlan =
-                  Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
-                      .setStageMetadata(stageMetadata).build();
-              requestBuilder.addStagePlan(stagePlan);
+              requestBuilder.addStagePlan(Worker.StagePlan.newBuilder().setStageId(i).setStageRoot(stageRootNodes[i])
+                  .setStageMetadata(QueryPlanSerDeUtils.toProtoStageMetadata(stageWorkerMetadataLists[i],
+                      stagePlan.getCustomProperties(), serverInstance, workerIds)).build());
             }
           }
           requestBuilder.putAllMetadata(requestMetadata);
@@ -188,19 +202,6 @@ public class QueryDispatcher {
     }
   }
 
-  private static class StageInfo {
-    final Plan.StageNode _rootNode;
-    final List<Worker.WorkerMetadata> _workerMetadataList;
-    final Map<String, String> _customProperties;
-
-    StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> workerMetadataList,
-        Map<String, String> customProperties) {
-      _rootNode = rootNode;
-      _workerMetadataList = workerMetadataList;
-      _customProperties = customProperties;
-    }
-  }
-
   private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
     List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
index 4a4daa148b..ecfa9b09f8 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java
@@ -20,9 +20,7 @@ package org.apache.pinot.query.service.server;
 
 import io.grpc.Server;
 import io.grpc.ServerBuilder;
-import io.grpc.Status;
 import io.grpc.stub.StreamObserver;
-import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -97,42 +95,60 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
 
   @Override
   public void submit(Worker.QueryRequest request, StreamObserver<Worker.QueryResponse> responseObserver) {
-    // Deserialize the request
-    List<DistributedStagePlan> distributedStagePlans;
-    Map<String, String> requestMetadata;
-    requestMetadata = Collections.unmodifiableMap(request.getMetadataMap());
+    Map<String, String> requestMetadata = request.getMetadataMap();
     long requestId = Long.parseLong(requestMetadata.get(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID));
     long timeoutMs = Long.parseLong(requestMetadata.get(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS));
     long deadlineMs = System.currentTimeMillis() + timeoutMs;
-    // 1. Deserialized request
-    try {
-      distributedStagePlans = QueryPlanSerDeUtils.deserializeStagePlan(request);
-    } catch (Exception e) {
-      LOGGER.error("Caught exception while deserializing the request: {}", requestId, e);
-      responseObserver.onError(Status.INVALID_ARGUMENT.withDescription("Bad request").withCause(e).asException());
-      return;
-    }
-    // 2. Submit distributed stage plans, await response successful or any failure which cancels all other tasks.
-    int numSubmission = distributedStagePlans.size();
-    CompletableFuture<?>[] submissionStubs = new CompletableFuture[numSubmission];
-    for (int i = 0; i < numSubmission; i++) {
-      DistributedStagePlan distributedStagePlan = distributedStagePlans.get(i);
-      submissionStubs[i] =
-          CompletableFuture.runAsync(() -> _queryRunner.processQuery(distributedStagePlan, requestMetadata),
-              _querySubmissionExecutorService);
+
+    List<Worker.StagePlan> stagePlans = request.getStagePlanList();
+    int numStages = stagePlans.size();
+    CompletableFuture<?>[] stageSubmissionStubs = new CompletableFuture[numStages];
+    for (int i = 0; i < numStages; i++) {
+      Worker.StagePlan stagePlan = stagePlans.get(i);
+      stageSubmissionStubs[i] = CompletableFuture.runAsync(() -> {
+        List<DistributedStagePlan> workerPlans;
+        try {
+          workerPlans = QueryPlanSerDeUtils.deserializeStagePlan(stagePlan);
+        } catch (Exception e) {
+          throw new RuntimeException(
+              String.format("Caught exception while deserializing stage plan for request: %d, stage id: %d", requestId,
+                  stagePlan.getStageId()), e);
+        }
+        int numWorkers = workerPlans.size();
+        CompletableFuture<?>[] workerSubmissionStubs = new CompletableFuture[numWorkers];
+        for (int j = 0; j < numWorkers; j++) {
+          DistributedStagePlan workerPlan = workerPlans.get(j);
+          workerSubmissionStubs[j] =
+              CompletableFuture.runAsync(() -> _queryRunner.processQuery(workerPlan, requestMetadata),
+                  _querySubmissionExecutorService);
+        }
+        try {
+          CompletableFuture.allOf(workerSubmissionStubs)
+              .get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+        } catch (Exception e) {
+          throw new RuntimeException(
+              String.format("Caught exception while submitting request: %d, stage id: %d", requestId,
+                  stagePlan.getStageId()), e);
+        } finally {
+          for (CompletableFuture<?> future : workerSubmissionStubs) {
+            if (!future.isDone()) {
+              future.cancel(true);
+            }
+          }
+        }
+      }, _querySubmissionExecutorService);
     }
     try {
-      CompletableFuture.allOf(submissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
+      CompletableFuture.allOf(stageSubmissionStubs).get(deadlineMs - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
     } catch (Exception e) {
-      LOGGER.error("error occurred during stage submission for {}:\n{}", requestId, e);
+      LOGGER.error("Caught exception while submitting request: {}", requestId, e);
       responseObserver.onNext(Worker.QueryResponse.newBuilder()
           .putMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
               QueryException.getTruncatedStackTrace(e)).build());
       responseObserver.onCompleted();
       return;
     } finally {
-      // Cancel all ongoing submission
-      for (CompletableFuture<?> future : submissionStubs) {
+      for (CompletableFuture<?> future : stageSubmissionStubs) {
         if (!future.isDone()) {
           future.cancel(true);
         }
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
index c7be429297..5af8f038c3 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.QueryEnvironment;
@@ -197,15 +198,11 @@ public class QueryDispatcherTest extends QueryTestSet {
     Mockito.reset(failingQueryServer);
   }
 
-  @Test
-  public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled() {
+  @Test(expectedExceptions = TimeoutException.class)
+  public void testQueryDispatcherThrowsWhenDeadlinePreExpiredAndAsyncResponseNotPolled()
+      throws Exception {
     String sql = "SELECT * FROM a WHERE col1 = 'foo'";
     DispatchableSubPlan dispatchableSubPlan = _queryEnvironment.planQuery(sql);
-    try {
-      _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
-      Assert.fail("Method call above should have failed");
-    } catch (Exception e) {
-      Assert.assertTrue(e.getMessage().contains("Timed out waiting"));
-    }
+    _queryDispatcher.submit(REQUEST_ID_GEN.getAndIncrement(), dispatchableSubPlan, 0L, Collections.emptyMap());
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org