You are viewing a plain-text version of this content. The canonical (hyperlinked) version is available from the mailing-list archive page this text was taken from.
Posted to commits@pinot.apache.org by ja...@apache.org on 2024/02/03 19:59:04 UTC

(pinot) branch master updated: [Multi-stage] Optimize query dispatch (#12358)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new c9a82c40a2 [Multi-stage] Optimize query dispatch (#12358)
c9a82c40a2 is described below

commit c9a82c40a2c8bed5e86d8278e0bb57bfc5bee86f
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Sat Feb 3 11:58:58 2024 -0800

    [Multi-stage] Optimize query dispatch (#12358)
---
 .../runtime/plan/serde/QueryPlanSerDeUtils.java    |  48 ++++------
 .../dispatch/AsyncQueryDispatchResponse.java       |  17 ++--
 .../query/service/dispatch/DispatchClient.java     |  22 +----
 .../query/service/dispatch/DispatchObserver.java   |  17 ++--
 .../query/service/dispatch/QueryDispatcher.java    | 104 +++++++++++++++------
 .../query/service/server/QueryServerTest.java      |  25 +++--
 6 files changed, 130 insertions(+), 103 deletions(-)

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
index c4bded9373..91bbcc2010 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/serde/QueryPlanSerDeUtils.java
@@ -27,7 +27,6 @@ import java.util.regex.Pattern;
 import org.apache.commons.lang.StringUtils;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
-import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
 import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.MailboxMetadata;
@@ -42,8 +41,8 @@ import org.apache.pinot.query.runtime.plan.StageMetadata;
  * This utility class serialize/deserialize between {@link Worker.StagePlan} elements to Planner elements.
  */
 public class QueryPlanSerDeUtils {
-  private static final Pattern VIRTUAL_SERVER_PATTERN = Pattern.compile(
-      "(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
+  private static final Pattern VIRTUAL_SERVER_PATTERN =
+      Pattern.compile("(?<virtualid>[0-9]+)@(?<host>[^:]+):(?<port>[0-9]+)");
 
   private QueryPlanSerDeUtils() {
     // do not instantiate.
@@ -57,18 +56,6 @@ public class QueryPlanSerDeUtils {
     return distributedStagePlans;
   }
 
-  public static Worker.StagePlan serialize(DispatchableSubPlan dispatchableSubPlan, int stageId,
-      QueryServerInstance queryServerInstance, List<Integer> workerIds) {
-    return Worker.StagePlan.newBuilder()
-        .setStageId(stageId)
-        .setStageRoot(StageNodeSerDeUtils.serializeStageNode(
-            (AbstractPlanNode) dispatchableSubPlan.getQueryStageList().get(stageId).getPlanFragment()
-                .getFragmentRoot()))
-        .setStageMetadata(
-            toProtoStageMetadata(dispatchableSubPlan.getQueryStageList().get(stageId), queryServerInstance, workerIds))
-        .build();
-  }
-
   public static VirtualServerAddress protoToAddress(String virtualAddressStr) {
     Matcher matcher = VIRTUAL_SERVER_PATTERN.matcher(virtualAddressStr);
     if (!matcher.matches()) {
@@ -78,8 +65,8 @@ public class QueryPlanSerDeUtils {
     }
 
     // Skipped netty and grpc port as they are not used in worker instance.
-    return new VirtualServerAddress(matcher.group("host"),
-        Integer.parseInt(matcher.group("port")), Integer.parseInt(matcher.group("virtualid")));
+    return new VirtualServerAddress(matcher.group("host"), Integer.parseInt(matcher.group("port")),
+        Integer.parseInt(matcher.group("virtualid")));
   }
 
   public static String addressToProto(VirtualServerAddress serverAddress) {
@@ -145,17 +132,21 @@ public class QueryPlanSerDeUtils {
     return mailboxMetadata;
   }
 
-  private static Worker.StageMetadata toProtoStageMetadata(DispatchablePlanFragment planFragment,
-      QueryServerInstance queryServerInstance, List<Integer> workerIds) {
-    Worker.StageMetadata.Builder builder = Worker.StageMetadata.newBuilder();
-    for (WorkerMetadata workerMetadata : planFragment.getWorkerMetadataList()) {
-      builder.addWorkerMetadata(toProtoWorkerMetadata(workerMetadata));
+  public static Worker.StageMetadata toProtoStageMetadata(List<Worker.WorkerMetadata> workerMetadataList,
+      Map<String, String> customProperties, QueryServerInstance serverInstance, List<Integer> workerIds) {
+    return Worker.StageMetadata.newBuilder().addAllWorkerMetadata(workerMetadataList)
+        .putAllCustomProperty(customProperties)
+        .setServerAddress(String.format("%s:%d", serverInstance.getHostname(), serverInstance.getQueryMailboxPort()))
+        .addAllWorkerIds(workerIds).build();
+  }
+
+  public static List<Worker.WorkerMetadata> toProtoWorkerMetadataList(DispatchablePlanFragment planFragment) {
+    List<WorkerMetadata> workerMetadataList = planFragment.getWorkerMetadataList();
+    List<Worker.WorkerMetadata> protoWorkerMetadataList = new ArrayList<>(workerMetadataList.size());
+    for (WorkerMetadata workerMetadata : workerMetadataList) {
+      protoWorkerMetadataList.add(toProtoWorkerMetadata(workerMetadata));
     }
-    builder.putAllCustomProperty(planFragment.getCustomProperties());
-    builder.setServerAddress(String.format("%s:%d", queryServerInstance.getHostname(),
-        queryServerInstance.getQueryMailboxPort()));
-    builder.addAllWorkerIds(workerIds);
-    return builder.build();
+    return protoWorkerMetadataList;
   }
 
   private static Worker.WorkerMetadata toProtoWorkerMetadata(WorkerMetadata workerMetadata) {
@@ -166,8 +157,7 @@ public class QueryPlanSerDeUtils {
     return builder.build();
   }
 
-  private static Map<Integer, Worker.MailboxMetadata> toProtoMailboxMap(
-      Map<Integer, MailboxMetadata> mailBoxInfosMap) {
+  private static Map<Integer, Worker.MailboxMetadata> toProtoMailboxMap(Map<Integer, MailboxMetadata> mailBoxInfosMap) {
     Map<Integer, Worker.MailboxMetadata> mailboxMetadataMap = new HashMap<>();
     for (Map.Entry<Integer, MailboxMetadata> entry : mailBoxInfosMap.entrySet()) {
       mailboxMetadataMap.put(entry.getKey(), toProtoMailbox(entry.getValue()));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
index 185ba4f607..076d8ce221 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/AsyncQueryDispatchResponse.java
@@ -30,27 +30,22 @@ import org.apache.pinot.query.routing.QueryServerInstance;
  * {@link #getThrowable()} to check if it is null.
  */
 class AsyncQueryDispatchResponse {
-  private final QueryServerInstance _virtualServer;
-  private final int _stageId;
+  private final QueryServerInstance _serverInstance;
   private final Worker.QueryResponse _queryResponse;
   private final Throwable _throwable;
 
-  public AsyncQueryDispatchResponse(QueryServerInstance virtualServer, int stageId, Worker.QueryResponse queryResponse,
+  public AsyncQueryDispatchResponse(QueryServerInstance serverInstance, @Nullable Worker.QueryResponse queryResponse,
       @Nullable Throwable throwable) {
-    _virtualServer = virtualServer;
-    _stageId = stageId;
+    _serverInstance = serverInstance;
     _queryResponse = queryResponse;
     _throwable = throwable;
   }
 
-  public QueryServerInstance getVirtualServer() {
-    return _virtualServer;
-  }
-
-  public int getStageId() {
-    return _stageId;
+  public QueryServerInstance getServerInstance() {
+    return _serverInstance;
   }
 
+  @Nullable
   public Worker.QueryResponse getQueryResponse() {
     return _queryResponse;
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
index 03861a436e..5b036930ce 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java
@@ -26,8 +26,6 @@ import java.util.function.Consumer;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.query.routing.QueryServerInstance;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 
 /**
@@ -37,8 +35,8 @@ import org.slf4j.LoggerFactory;
  *       let that take care of pooling. (2) Create a DispatchClient interface and implement pooled/non-pooled versions.
  */
 class DispatchClient {
-  private static final Logger LOGGER = LoggerFactory.getLogger(DispatchClient.class);
   private static final StreamObserver<Worker.CancelResponse> NO_OP_CANCEL_STREAM_OBSERVER = new CancelObserver();
+
   private final ManagedChannel _channel;
   private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;
 
@@ -51,23 +49,13 @@ class DispatchClient {
     return _channel;
   }
 
-  public void submit(Worker.QueryRequest request, int stageId, QueryServerInstance virtualServer, Deadline deadline,
+  public void submit(Worker.QueryRequest request, QueryServerInstance virtualServer, Deadline deadline,
       Consumer<AsyncQueryDispatchResponse> callback) {
-    try {
-      _dispatchStub.withDeadline(deadline).submit(request, new DispatchObserver(stageId, virtualServer, callback));
-    } catch (Exception e) {
-      LOGGER.error("Query Dispatch failed at client-side", e);
-      callback.accept(new AsyncQueryDispatchResponse(
-          virtualServer, stageId, Worker.QueryResponse.getDefaultInstance(), e));
-    }
+    _dispatchStub.withDeadline(deadline).submit(request, new DispatchObserver(virtualServer, callback));
   }
 
   public void cancel(long requestId) {
-    try {
-      Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
-      _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
-    } catch (Exception e) {
-      LOGGER.error("Query Cancellation failed at client-side", e);
-    }
+    Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
+    _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
index 2a7425dd99..9b99691655 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchObserver.java
@@ -28,15 +28,13 @@ import org.apache.pinot.query.routing.QueryServerInstance;
  * A {@link StreamObserver} used by {@link DispatchClient} to subscribe to the response of a async Query Dispatch call.
  */
 class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
-  private int _stageId;
-  private QueryServerInstance _virtualServer;
-  private Consumer<AsyncQueryDispatchResponse> _callback;
+  private final QueryServerInstance _serverInstance;
+  private final Consumer<AsyncQueryDispatchResponse> _callback;
+
   private Worker.QueryResponse _queryResponse;
 
-  public DispatchObserver(int stageId, QueryServerInstance virtualServer,
-      Consumer<AsyncQueryDispatchResponse> callback) {
-    _stageId = stageId;
-    _virtualServer = virtualServer;
+  public DispatchObserver(QueryServerInstance serverInstance, Consumer<AsyncQueryDispatchResponse> callback) {
+    _serverInstance = serverInstance;
     _callback = callback;
   }
 
@@ -48,12 +46,11 @@ class DispatchObserver implements StreamObserver<Worker.QueryResponse> {
   @Override
   public void onError(Throwable throwable) {
     _callback.accept(
-        new AsyncQueryDispatchResponse(_virtualServer, _stageId, Worker.QueryResponse.getDefaultInstance(),
-            throwable));
+        new AsyncQueryDispatchResponse(_serverInstance, Worker.QueryResponse.getDefaultInstance(), throwable));
   }
 
   @Override
   public void onCompleted() {
-    _callback.accept(new AsyncQueryDispatchResponse(_virtualServer, _stageId, _queryResponse, null));
+    _callback.accept(new AsyncQueryDispatchResponse(_serverInstance, _queryResponse, null));
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
index 3f1f43c1eb..2029e31a6f 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java
@@ -27,16 +27,17 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
-import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 import javax.annotation.Nullable;
 import org.apache.calcite.util.Pair;
 import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.common.response.broker.ResultTable;
 import org.apache.pinot.common.utils.DataSchema;
@@ -48,8 +49,10 @@ import org.apache.pinot.query.mailbox.MailboxService;
 import org.apache.pinot.query.planner.PlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
+import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.MailboxReceiveNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
@@ -107,50 +110,76 @@ public class QueryDispatcher {
   void submit(long requestId, DispatchableSubPlan dispatchableSubPlan, long timeoutMs, Map<String, String> queryOptions)
       throws Exception {
     Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS);
-    BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new LinkedBlockingQueue<>();
     List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
-    int numDispatchCalls = 0;
-    // Do not submit the reduce stage (stage 0)
+    Set<QueryServerInstance> serverInstances = new HashSet<>();
+    // TODO: If serialization is slow, consider serializing each stage in parallel
+    StageInfo[] stageInfoMap = new StageInfo[numStages];
+    // Ignore the reduce stage (stage 0)
     for (int stageId = 1; stageId < numStages; stageId++) {
-      for (Map.Entry<QueryServerInstance, List<Integer>> entry : stagePlans.get(stageId)
-          .getServerInstanceToWorkerIdMap().entrySet()) {
-        QueryServerInstance queryServerInstance = entry.getKey();
-        Worker.QueryRequest.Builder queryRequestBuilder = Worker.QueryRequest.newBuilder();
-        queryRequestBuilder.addStagePlan(
-            QueryPlanSerDeUtils.serialize(dispatchableSubPlan, stageId, queryServerInstance, entry.getValue()));
-        Worker.QueryRequest queryRequest =
-            queryRequestBuilder.putMetadata(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID,
-                    String.valueOf(requestId))
-                .putMetadata(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, String.valueOf(timeoutMs))
-                .putAllMetadata(queryOptions).build();
-        DispatchClient client = getOrCreateDispatchClient(queryServerInstance);
-        int finalStageId = stageId;
-        _executorService.submit(
-            () -> client.submit(queryRequest, finalStageId, queryServerInstance, deadline, dispatchCallbacks::offer));
-        numDispatchCalls++;
-      }
+      DispatchablePlanFragment stagePlan = stagePlans.get(stageId);
+      serverInstances.addAll(stagePlan.getServerInstanceToWorkerIdMap().keySet());
+      Plan.StageNode rootNode =
+          StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) stagePlan.getPlanFragment().getFragmentRoot());
+      List<Worker.WorkerMetadata> workerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(stagePlan);
+      stageInfoMap[stageId] = new StageInfo(rootNode, workerMetadataList, stagePlan.getCustomProperties());
+    }
+    Map<String, String> requestMetadata = new HashMap<>();
+    requestMetadata.put(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID, Long.toString(requestId));
+    requestMetadata.put(CommonConstants.Broker.Request.QueryOptionKey.TIMEOUT_MS, Long.toString(timeoutMs));
+    requestMetadata.putAll(queryOptions);
+
+    // Submit the query plan to all servers in parallel
+    int numServers = serverInstances.size();
+    BlockingQueue<AsyncQueryDispatchResponse> dispatchCallbacks = new ArrayBlockingQueue<>(numServers);
+    for (QueryServerInstance serverInstance : serverInstances) {
+      _executorService.submit(() -> {
+        try {
+          Worker.QueryRequest.Builder requestBuilder = Worker.QueryRequest.newBuilder();
+          for (int stageId = 1; stageId < numStages; stageId++) {
+            List<Integer> workerIds = stagePlans.get(stageId).getServerInstanceToWorkerIdMap().get(serverInstance);
+            if (workerIds != null) {
+              StageInfo stageInfo = stageInfoMap[stageId];
+              Worker.StageMetadata stageMetadata =
+                  QueryPlanSerDeUtils.toProtoStageMetadata(stageInfo._workerMetadataList, stageInfo._customProperties,
+                      serverInstance, workerIds);
+              Worker.StagePlan stagePlan =
+                  Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageInfo._rootNode)
+                      .setStageMetadata(stageMetadata).build();
+              requestBuilder.addStagePlan(stagePlan);
+            }
+          }
+          requestBuilder.putAllMetadata(requestMetadata);
+          getOrCreateDispatchClient(serverInstance).submit(requestBuilder.build(), serverInstance, deadline,
+              dispatchCallbacks::offer);
+        } catch (Throwable t) {
+          LOGGER.warn("Caught exception while dispatching query: {} to server: {}", requestId, serverInstance, t);
+          dispatchCallbacks.offer(new AsyncQueryDispatchResponse(serverInstance, null, t));
+        }
+      });
     }
-    int successfulDispatchCalls = 0;
+
+    int numSuccessCalls = 0;
     // TODO: Cancel all dispatched requests if one of the dispatch errors out or deadline is breached.
-    while (!deadline.isExpired() && successfulDispatchCalls < numDispatchCalls) {
+    while (!deadline.isExpired() && numSuccessCalls < numServers) {
       AsyncQueryDispatchResponse resp =
           dispatchCallbacks.poll(deadline.timeRemaining(TimeUnit.MILLISECONDS), TimeUnit.MILLISECONDS);
       if (resp != null) {
         if (resp.getThrowable() != null) {
           throw new RuntimeException(
-              String.format("Error dispatching query to server=%s stage=%s", resp.getVirtualServer(),
-                  resp.getStageId()), resp.getThrowable());
+              String.format("Error dispatching query: %d to server: %s", requestId, resp.getServerInstance()),
+              resp.getThrowable());
         } else {
           Worker.QueryResponse response = resp.getQueryResponse();
+          assert response != null;
           if (response.containsMetadata(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR)) {
             throw new RuntimeException(
-                String.format("Unable to execute query plan at stage %s on server %s: ERROR: %s", resp.getStageId(),
-                    resp.getVirtualServer(),
+                String.format("Unable to execute query plan for request: %d on server: %s, ERROR: %s", requestId,
+                    resp.getServerInstance(),
                     response.getMetadataOrDefault(CommonConstants.Query.Response.ServerResponseStatus.STATUS_ERROR,
                         "null")));
           }
-          successfulDispatchCalls++;
+          numSuccessCalls++;
         }
       }
     }
@@ -159,6 +188,19 @@ public class QueryDispatcher {
     }
   }
 
+  private static class StageInfo {
+    final Plan.StageNode _rootNode;
+    final List<Worker.WorkerMetadata> _workerMetadataList;
+    final Map<String, String> _customProperties;
+
+    StageInfo(Plan.StageNode rootNode, List<Worker.WorkerMetadata> workerMetadataList,
+        Map<String, String> customProperties) {
+      _rootNode = rootNode;
+      _workerMetadataList = workerMetadataList;
+      _customProperties = customProperties;
+    }
+  }
+
   private void cancel(long requestId, DispatchableSubPlan dispatchableSubPlan) {
     List<DispatchablePlanFragment> stagePlans = dispatchableSubPlan.getQueryStageList();
     int numStages = stagePlans.size();
@@ -168,7 +210,11 @@ public class QueryDispatcher {
       serversToCancel.addAll(stagePlans.get(stageId).getServerInstanceToWorkerIdMap().keySet());
     }
     for (QueryServerInstance queryServerInstance : serversToCancel) {
-      getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
+      try {
+        getOrCreateDispatchClient(queryServerInstance).cancel(requestId);
+      } catch (Throwable t) {
+        LOGGER.warn("Caught exception while cancelling query: {} on server: {}", requestId, queryServerInstance, t);
+      }
     }
   }
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
index 140851f666..4e5a003427 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.service.server;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import io.grpc.Deadline;
 import io.grpc.ManagedChannel;
@@ -30,6 +29,7 @@ import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
+import org.apache.pinot.common.proto.Plan;
 import org.apache.pinot.common.proto.Worker;
 import org.apache.pinot.core.routing.TimeBoundaryInfo;
 import org.apache.pinot.query.QueryEnvironment;
@@ -37,7 +37,9 @@ import org.apache.pinot.query.QueryEnvironmentTestBase;
 import org.apache.pinot.query.QueryTestSet;
 import org.apache.pinot.query.planner.physical.DispatchablePlanFragment;
 import org.apache.pinot.query.planner.physical.DispatchableSubPlan;
+import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.PlanNode;
+import org.apache.pinot.query.planner.plannode.StageNodeSerDeUtils;
 import org.apache.pinot.query.routing.QueryServerInstance;
 import org.apache.pinot.query.routing.WorkerMetadata;
 import org.apache.pinot.query.runtime.QueryRunner;
@@ -228,15 +230,24 @@ public class QueryServerTest extends QueryTestSet {
   }
 
   private Worker.QueryRequest getQueryRequest(DispatchableSubPlan dispatchableSubPlan, int stageId) {
-    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap =
-        dispatchableSubPlan.getQueryStageList().get(stageId).getServerInstanceToWorkerIdMap();
+    DispatchablePlanFragment planFragment = dispatchableSubPlan.getQueryStageList().get(stageId);
+    Map<QueryServerInstance, List<Integer>> serverInstanceToWorkerIdMap = planFragment.getServerInstanceToWorkerIdMap();
     // this particular test set requires the request to have a single QueryServerInstance to dispatch to
     // as it is not testing the multi-tenancy dispatch (which is in the QueryDispatcherTest)
-    QueryServerInstance serverInstance = serverInstanceToWorkerIdMap.keySet().iterator().next();
-    int workerId = serverInstanceToWorkerIdMap.get(serverInstance).get(0);
+    Map.Entry<QueryServerInstance, List<Integer>> entry = serverInstanceToWorkerIdMap.entrySet().iterator().next();
+    QueryServerInstance serverInstance = entry.getKey();
+    List<Integer> workerIds = entry.getValue();
+    Plan.StageNode stageRoot =
+        StageNodeSerDeUtils.serializeStageNode((AbstractPlanNode) planFragment.getPlanFragment().getFragmentRoot());
+    List<Worker.WorkerMetadata> protoWorkerMetadataList = QueryPlanSerDeUtils.toProtoWorkerMetadataList(planFragment);
+    Worker.StageMetadata stageMetadata =
+        QueryPlanSerDeUtils.toProtoStageMetadata(protoWorkerMetadataList, planFragment.getCustomProperties(),
+            serverInstance, workerIds);
+    Worker.StagePlan stagePlan =
+        Worker.StagePlan.newBuilder().setStageId(stageId).setStageRoot(stageRoot).setStageMetadata(stageMetadata)
+            .build();
 
-    return Worker.QueryRequest.newBuilder().addStagePlan(
-            QueryPlanSerDeUtils.serialize(dispatchableSubPlan, stageId, serverInstance, ImmutableList.of(workerId)))
+    return Worker.QueryRequest.newBuilder().addStagePlan(stagePlan)
         // the default configurations that must exist.
         .putMetadata(CommonConstants.Query.Request.MetadataKeys.REQUEST_ID,
             String.valueOf(RANDOM_REQUEST_ID_GEN.nextLong()))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org