You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2023/02/24 06:09:20 UTC

[iotdb] branch master updated: [IOTDB-5454] Support shuffle function of DataExchangeModule

This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 854e4c8b9c [IOTDB-5454] Support shuffle function of DataExchangeModule
854e4c8b9c is described below

commit 854e4c8b9cb43b607e0d36ffe1b6d354eb43553b
Author: Liao Lanyu <14...@qq.com>
AuthorDate: Fri Feb 24 14:09:13 2023 +0800

    [IOTDB-5454] Support shuffle function of DataExchangeModule
---
 .../db/it/alignbydevice/IoTDBShuffleSink1IT.java   |  150 +++
 .../db/it/alignbydevice/IoTDBShuffleSink2IT.java   |  205 +++
 .../iotdb/db/mpp/execution/driver/Driver.java      |   18 +-
 .../db/mpp/execution/driver/DriverContext.java     |   12 +-
 .../iotdb/db/mpp/execution/driver/IDriver.java     |    6 +-
 .../exchange/IMPPDataExchangeManager.java          |   34 +-
 .../execution/exchange/MPPDataExchangeManager.java |  333 +++--
 .../mpp/execution/exchange/SharedTsBlockQueue.java |   26 +-
 .../exchange/sink/DownStreamChannelIndex.java}     |   27 +-
 .../exchange/sink/DownStreamChannelLocation.java   |  111 ++
 .../exchange/{ISinkHandle.java => sink/ISink.java} |   41 +-
 .../exchange/sink/ISinkChannel.java}               |   28 +-
 .../exchange/sink/ISinkHandle.java}                |   34 +-
 .../LocalSinkChannel.java}                         |   89 +-
 .../execution/exchange/sink/ShuffleSinkHandle.java |  299 +++++
 .../{SinkHandle.java => sink/SinkChannel.java}     |  163 +--
 .../exchange/{ => source}/ISourceHandle.java       |    2 +-
 .../exchange/{ => source}/LocalSourceHandle.java   |   24 +-
 .../exchange/{ => source}/SourceHandle.java        |   44 +-
 .../fragment/FragmentInstanceExecution.java        |   20 +-
 .../fragment/FragmentInstanceManager.java          |   14 +-
 .../operator/sink/IdentitySinkOperator.java        |  132 ++
 .../operator/sink/ShuffleHelperOperator.java       |  134 ++
 .../operator/source/ExchangeOperator.java          |    2 +-
 .../db/mpp/execution/schedule/task/DriverTask.java |    4 +-
 .../db/mpp/plan/execution/QueryExecution.java      |    5 +-
 .../plan/execution/memory/MemorySourceHandle.java  |    2 +-
 .../plan/planner/LocalExecutionPlanContext.java    |   10 +-
 .../db/mpp/plan/planner/OperatorTreeGenerator.java |  123 +-
 .../planner/distribution/DistributionPlanner.java  |  140 +-
 .../planner/distribution/ExchangeNodeAdder.java    |  196 +--
 .../planner/distribution/NodeGroupContext.java     |   49 +-
 .../SimpleFragmentParallelPlanner.java             |   44 +-
 .../db/mpp/plan/planner/plan/FragmentInstance.java |   15 -
 .../plan/planner/plan/node/PlanGraphPrinter.java   |   27 +-
 .../mpp/plan/planner/plan/node/PlanNodeType.java   |   13 +-
 .../db/mpp/plan/planner/plan/node/PlanVisitor.java |   15 +-
 .../planner/plan/node/process/ExchangeNode.java    |   31 +-
 .../plan/node/process/HorizontallyConcatNode.java  |    2 +-
 .../planner/plan/node/sink/FragmentSinkNode.java   |  202 ---
 .../planner/plan/node/sink/IdentitySinkNode.java   |   99 ++
 .../plan/node/sink/MultiChildrenSinkNode.java      |  117 ++
 .../planner/plan/node/sink/ShuffleSinkNode.java    |  105 ++
 .../iotdb/db/mpp/execution/DataDriverTest.java     |    8 +-
 ...nkHandleTest.java => LocalSinkChannelTest.java} |   96 +-
 .../execution/exchange/LocalSourceHandleTest.java  |    1 +
 .../exchange/MPPDataExchangeManagerTest.java       |   65 +-
 .../{SinkHandleTest.java => SinkChannelTest.java}  |  270 ++--
 .../mpp/execution/exchange/SourceHandleTest.java   |    6 +
 .../{StubSinkHandle.java => StubSink.java}         |   10 +-
 .../distribution/AggregationDistributionTest.java  |   17 +-
 .../plan/distribution/AlignedByDeviceTest.java     | 1390 +++++++++++++++++++-
 .../read/DeviceSchemaScanNodeSerdeTest.java        |   22 +-
 .../NodeManagementMemoryMergeNodeSerdeTest.java    |   24 +-
 .../metadata/read/SchemaCountNodeSerdeTest.java    |   44 +-
 .../read/TimeSeriesSchemaScanNodeSerdeTest.java    |   24 +-
 .../plan/node/process/ExchangeNodeSerdeTest.java   |   24 +-
 ...rdeTest.java => IdentitySinkNodeSerdeTest.java} |   45 +-
 ...st.java => ShuffleSinkHandleNodeSerdeTest.java} |   46 +-
 thrift/src/main/thrift/datanode.thrift             |    4 +
 60 files changed, 3979 insertions(+), 1264 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink1IT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink1IT.java
new file mode 100644
index 0000000000..73cc3c71de
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink1IT.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.it.alignbydevice;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.ClusterIT;
+import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({LocalStandaloneIT.class, ClusterIT.class})
+public class IoTDBShuffleSink1IT { // align-by-device count ITs, 2 data region groups per DB
+  private static final String[] SINGLE_SERIES = // only measurement s1 exists under root.single
+      new String[] {
+        "create database root.single",
+        "insert into root.single.d1(time,s1) values (1,2)",
+        "insert into root.single.d1(time,s1) values (now(),3)",
+        "insert into root.single.d2(time,s1) values (1,4)",
+        "insert into root.single.d2(time,s1) values (now(),5)"
+      };
+  // two devices (d1, d2) with s1 and s2; each has a row at time 1 and one at now()
+  private static final String[] MULTI_SERIES =
+      new String[] {
+        "create database root.sg",
+        "insert into root.sg.d1(time,s1,s2) values (1,2,2)",
+        "insert into root.sg.d1(time,s1,s2) values (now(),3,3)",
+        "insert into root.sg.d2(time,s1,s2) values (1,4,4)",
+        "insert into root.sg.d2(time,s1,s2) values (now(),5,5)"
+      };
+
+  @BeforeClass
+  public static void setUp() throws Exception { // CUSTOM policy, 2 data region groups per DB
+    EnvFactory.getEnv().getConfig().getCommonConfig().setDataRegionGroupExtensionPolicy("CUSTOM");
+    EnvFactory.getEnv().getConfig().getCommonConfig().setDefaultDataRegionGroupNumPerDatabase(2);
+    EnvFactory.getEnv().initClusterEnvironment();
+    prepareData(SINGLE_SERIES);
+    prepareData(MULTI_SERIES);
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    EnvFactory.getEnv().cleanClusterEnvironment();
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByDeviceWithoutValueFilter() {
+    // SINGLE_SERIES: each device holds two s1 points
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 = new String[] {"root.single.d1,2,", "root.single.d2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1) from root.single.** align by device", expectedHeader1, retArray1);
+
+    // MULTI_SERIES: each device holds two rows with both s1 and s2
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** align by device", expectedHeader2, retArray2);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByDeviceWithValueFilter() {
+    // SINGLE_SERIES: s1 <= 4 keeps d1's 2, 3 and d2's 4; no s2 exists, so no count(s2)
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 = new String[] {"root.single.d1,2,", "root.single.d2,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.single.** where s1 <= 4 align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: s1 <= 4 keeps both d1 rows (2, 3) but only d2's row with s1 = 4
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,1,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** where s1 <= 4 align by device",
+        expectedHeader2,
+        retArray2);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByTimeWithoutValueFilter() {
+    // SINGLE_SERIES: two s1 points per device; no s2 exists, so no count(s2) column
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 = new String[] {"root.single.d1,2,", "root.single.d2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.single.** order by time align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: each device holds two rows with both s1 and s2
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** order by time align by device",
+        expectedHeader2,
+        retArray2);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByTimeWithValueFilter() {
+    // SINGLE_SERIES: s1 <= 4 keeps d1's 2, 3 and d2's 4; no s2 exists, so no count(s2)
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 = new String[] {"root.single.d1,2,", "root.single.d2,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.single.** where s1 <= 4 order by time align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: s1 <= 4 keeps both d1 rows (2, 3) but only d2's row with s1 = 4
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,1,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** where s1 <= 4 order by time align by device",
+        expectedHeader2,
+        retArray2);
+  }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink2IT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink2IT.java
new file mode 100644
index 0000000000..8558e1f644
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/db/it/alignbydevice/IoTDBShuffleSink2IT.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.it.alignbydevice;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.ClusterIT;
+import org.apache.iotdb.itbase.category.LocalStandaloneIT;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({LocalStandaloneIT.class, ClusterIT.class})
+public class IoTDBShuffleSink2IT { // align-by-device count ITs, 3 data region groups per DB
+  private static final String[] SINGLE_SERIES = // only measurement s1 exists under root.single
+      new String[] {
+        "create database root.single",
+        "insert into root.single.d1(time,s1) values (1,1)",
+        "insert into root.single.d1(time,s1) values (now(),2)",
+        "insert into root.single.d2(time,s1) values (now(),3)",
+        "insert into root.single.d2(time,s1) values (1,4)",
+        "insert into root.single.d3(time,s1) values (now(),5)",
+        "insert into root.single.d3(time,s1) values (1,6)"
+      };
+
+  // three devices (d1-d3) with s1 and s2, each with a row at time 1 and one at now()
+  private static final String[] MULTI_SERIES =
+      new String[] {
+        "create database root.sg",
+        "insert into root.sg.d1(time,s1,s2) values (1,1,1)",
+        "insert into root.sg.d1(time,s1,s2) values (now(),2,2)",
+        "insert into root.sg.d2(time,s1,s2) values (now(),3,3)",
+        "insert into root.sg.d2(time,s1,s2) values (1,4,4)",
+        "insert into root.sg.d3(time,s1,s2) values (now(),5,5)",
+        "insert into root.sg.d3(time,s1,s2) values (1,6,6)"
+      };
+
+  // same layout as MULTI_SERIES, but d3 has a single row (time 1), presumably in one region
+  private static final String[] SECOND_MULTI_SERIES =
+      new String[] {
+        "create database root.sg1",
+        "insert into root.sg1.d1(time,s1,s2) values (1,1,1)",
+        "insert into root.sg1.d1(time,s1,s2) values (now(),2,2)",
+        "insert into root.sg1.d2(time,s1,s2) values (now(),3,3)",
+        "insert into root.sg1.d2(time,s1,s2) values (1,4,4)",
+        "insert into root.sg1.d3(time,s1,s2) values (1,6,6)"
+      };
+
+  @BeforeClass
+  public static void setUp() throws Exception { // CUSTOM policy, 3 data region groups per DB
+    EnvFactory.getEnv().getConfig().getCommonConfig().setDataRegionGroupExtensionPolicy("CUSTOM");
+    EnvFactory.getEnv().getConfig().getCommonConfig().setDefaultDataRegionGroupNumPerDatabase(3);
+    EnvFactory.getEnv().initClusterEnvironment();
+    prepareData(SINGLE_SERIES);
+    prepareData(MULTI_SERIES);
+    prepareData(SECOND_MULTI_SERIES);
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    EnvFactory.getEnv().cleanClusterEnvironment();
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByDeviceWithoutValueFilter() {
+    // SINGLE_SERIES: two s1 points per device
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 =
+        new String[] {"root.single.d1,2,", "root.single.d2,2,", "root.single.d3,2,"};
+
+    resultSetEqualTest(
+        "select count(s1) from root.single.** align by device", expectedHeader1, retArray1);
+
+    // MULTI_SERIES: two rows per device with both s1 and s2
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,", "root.sg.d3,2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** align by device", expectedHeader2, retArray2);
+
+    // SECOND_MULTI_SERIES: d3 holds only one row
+    String expectedHeader3 = "Device,count(s1),count(s2),";
+    String[] retArray3 = new String[] {"root.sg1.d1,2,2,", "root.sg1.d2,2,2,", "root.sg1.d3,1,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg1.** align by device", expectedHeader3, retArray3);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByDeviceWithValueFilter() {
+    // SINGLE_SERIES: s1 <= 4 removes d3's rows (5, 6); no s2 exists, so no count(s2)
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 =
+        new String[] {"root.single.d1,2,", "root.single.d2,2,", "root.single.d3,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.single.** where s1 <= 4 align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: s1 <= 4 removes both of d3's rows (5, 6), so d3 counts are 0
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,", "root.sg.d3,0,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** where s1 <= 4 align by device",
+        expectedHeader2,
+        retArray2);
+
+    // SECOND_MULTI_SERIES: d3's only row has s1 = 6, so d3 counts are 0
+    String expectedHeader3 = "Device,count(s1),count(s2),";
+    String[] retArray3 = new String[] {"root.sg1.d1,2,2,", "root.sg1.d2,2,2,", "root.sg1.d3,0,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg1.** where s1 <= 4 align by device",
+        expectedHeader3,
+        retArray3);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByTimeWithoutValueFilter() {
+    // SINGLE_SERIES: two s1 points per device
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 =
+        new String[] {"root.single.d1,2,", "root.single.d2,2,", "root.single.d3,2,"};
+
+    resultSetEqualTest(
+        "select count(s1) from root.single.** order by time align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: two rows per device with both s1 and s2
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,", "root.sg.d3,2,2,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** order by time align by device",
+        expectedHeader2,
+        retArray2);
+
+    // SECOND_MULTI_SERIES: d3 holds only one row
+    String expectedHeader3 = "Device,count(s1),count(s2),";
+    String[] retArray3 = new String[] {"root.sg1.d1,2,2,", "root.sg1.d2,2,2,", "root.sg1.d3,1,1,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg1.** order by time align by device",
+        expectedHeader3,
+        retArray3);
+  }
+
+  @Test
+  public void testCountAlignByDeviceOrderByTimeWithValueFilter() {
+    // SINGLE_SERIES: s1 <= 4 removes d3's rows (5, 6); no s2 exists, so no count(s2)
+    String expectedHeader1 = "Device,count(s1),";
+    String[] retArray1 =
+        new String[] {"root.single.d1,2,", "root.single.d2,2,", "root.single.d3,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.single.** where s1 <= 4 order by time align by device",
+        expectedHeader1,
+        retArray1);
+
+    // MULTI_SERIES: s1 <= 4 removes both of d3's rows (5, 6), so d3 counts are 0
+    String expectedHeader2 = "Device,count(s1),count(s2),";
+    String[] retArray2 = new String[] {"root.sg.d1,2,2,", "root.sg.d2,2,2,", "root.sg.d3,0,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg.** where s1 <= 4 order by time align by device",
+        expectedHeader2,
+        retArray2);
+
+    // SECOND_MULTI_SERIES: d3's only row has s1 = 6, so d3 counts are 0
+    String expectedHeader3 = "Device,count(s1),count(s2),";
+    String[] retArray3 = new String[] {"root.sg1.d1,2,2,", "root.sg1.d2,2,2,", "root.sg1.d3,0,0,"};
+
+    resultSetEqualTest(
+        "select count(s1),count(s2) from root.sg1.** where s1 <= 4 order by time align by device",
+        expectedHeader3,
+        retArray3);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/Driver.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/Driver.java
index a1ab1c18db..2fd9a56911 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/Driver.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/Driver.java
@@ -18,7 +18,7 @@
  */
 package org.apache.iotdb.db.mpp.execution.driver;
 
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.operator.Operator;
 import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
 import org.apache.iotdb.db.mpp.execution.schedule.task.DriverTaskId;
@@ -55,7 +55,7 @@ public abstract class Driver implements IDriver {
 
   protected final DriverContext driverContext;
   protected final Operator root;
-  protected final ISinkHandle sinkHandle;
+  protected final ISink sink;
   protected final AtomicReference<SettableFuture<?>> driverBlockedFuture = new AtomicReference<>();
   protected final AtomicReference<State> state = new AtomicReference<>(State.ALIVE);
 
@@ -71,10 +71,10 @@ public abstract class Driver implements IDriver {
 
   protected Driver(Operator root, DriverContext driverContext) {
     checkNotNull(root, "root Operator should not be null");
-    checkNotNull(driverContext.getSinkHandle(), "SinkHandle should not be null");
+    checkNotNull(driverContext.getSink(), "Sink should not be null");
     this.driverContext = driverContext;
     this.root = root;
-    this.sinkHandle = driverContext.getSinkHandle();
+    this.sink = driverContext.getSink();
 
     // initially the driverBlockedFuture is not blocked (it is completed)
     SettableFuture<Void> future = SettableFuture.create();
@@ -182,8 +182,8 @@ public abstract class Driver implements IDriver {
   }
 
   @Override
-  public ISinkHandle getSinkHandle() {
-    return sinkHandle;
+  public ISink getSink() {
+    return sink;
   }
 
   @GuardedBy("exclusiveLock")
@@ -204,14 +204,14 @@ public abstract class Driver implements IDriver {
       if (!blocked.isDone()) {
         return blocked;
       }
-      blocked = sinkHandle.isFull();
+      blocked = sink.isFull();
       if (!blocked.isDone()) {
         return blocked;
       }
       if (root.hasNextWithTimer()) {
         TsBlock tsBlock = root.nextWithTimer();
         if (tsBlock != null && !tsBlock.isEmpty()) {
-          sinkHandle.send(tsBlock);
+          sink.send(tsBlock);
         }
       }
       return NOT_BLOCKED;
@@ -363,7 +363,7 @@ public abstract class Driver implements IDriver {
 
     try {
       root.close();
-      sinkHandle.setNoMoreTsBlocks();
+      sink.setNoMoreTsBlocks();
 
       // record operator execution statistics to metrics
       List<OperatorContext> operatorContexts = driverContext.getOperatorContexts();
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/DriverContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/DriverContext.java
index ea369d8fad..58a131545f 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/DriverContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/DriverContext.java
@@ -18,7 +18,7 @@
  */
 package org.apache.iotdb.db.mpp.execution.driver;
 
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
 import org.apache.iotdb.db.mpp.execution.schedule.task.DriverTaskId;
@@ -37,7 +37,7 @@ public class DriverContext {
   private DriverTaskId driverTaskID;
   private final FragmentInstanceContext fragmentInstanceContext;
   private final List<OperatorContext> operatorContexts = new ArrayList<>();
-  private ISinkHandle sinkHandle;
+  private ISink sink;
   private final RuleBasedTimeSliceAllocator timeSliceAllocator;
   private int dependencyDriverIndex = -1;
 
@@ -78,12 +78,12 @@ public class DriverContext {
     return dependencyDriverIndex;
   }
 
-  public void setSinkHandle(ISinkHandle sinkHandle) {
-    this.sinkHandle = sinkHandle;
+  public void setSink(ISink sink) {
+    this.sink = sink;
   }
 
-  public ISinkHandle getSinkHandle() {
-    return sinkHandle;
+  public ISink getSink() {
+    return sink;
   }
 
   public boolean isInputDriver() {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/IDriver.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/IDriver.java
index ff55d5456c..88514bddcc 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/IDriver.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/driver/IDriver.java
@@ -19,7 +19,7 @@
 package org.apache.iotdb.db.mpp.execution.driver;
 
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.schedule.task.DriverTaskId;
 
 import com.google.common.util.concurrent.ListenableFuture;
@@ -69,8 +69,8 @@ public interface IDriver {
    */
   void failed(Throwable t);
 
-  /** @return get SinkHandle of current IDriver */
-  ISinkHandle getSinkHandle();
+  /** @return get Sink of current IDriver */
+  ISink getSink();
 
   int getDependencyDriverIndex();
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
index c02595374c..bbbce26e7c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
@@ -20,9 +20,18 @@
 package org.apache.iotdb.db.mpp.execution.exchange;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.db.mpp.execution.driver.DriverContext;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ShuffleSinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 
+import java.util.List;
+
 public interface IMPPDataExchangeManager {
   /**
    * Create a sink handle who sends data blocks to a remote downstream fragment instance in async
@@ -30,26 +39,17 @@ public interface IMPPDataExchangeManager {
    *
    * @param localFragmentInstanceId ID of the local fragment instance who generates and sends data
    *     blocks to the sink handle.
-   * @param remoteEndpoint Hostname and Port of the remote fragment instance where the data blocks
-   *     should be sent to.
-   * @param remotePlanNodeId The sink plan node ID of the remote fragment instance.
-   * @param remotePlanNodeId The plan node ID of the local fragment instance.
    * @param instanceContext The context of local fragment instance.
    */
-  ISinkHandle createSinkHandle(
+  ISinkHandle createShuffleSinkHandle(
+      List<DownStreamChannelLocation> downStreamChannelLocationList,
+      DownStreamChannelIndex downStreamChannelIndex,
+      ShuffleSinkHandle.ShuffleStrategyEnum shuffleStrategyEnum,
       TFragmentInstanceId localFragmentInstanceId,
-      TEndPoint remoteEndpoint,
-      TFragmentInstanceId remoteFragmentInstanceId,
-      String remotePlanNodeId,
       String localPlanNodeId,
       FragmentInstanceContext instanceContext);
 
-  ISinkHandle createLocalSinkHandleForFragment(
-      TFragmentInstanceId localFragmentInstanceId,
-      TFragmentInstanceId remoteFragmentInstanceId,
-      String remotePlanNodeId,
-      FragmentInstanceContext instanceContext);
-
+  ISinkChannel createLocalSinkChannelForPipeline(DriverContext driverContext, String planNodeId);
   /**
    * Create a source handle who fetches data blocks from a remote upstream fragment instance for a
    * plan node of a local fragment instance in async manner.
@@ -65,6 +65,7 @@ public interface IMPPDataExchangeManager {
   ISourceHandle createSourceHandle(
       TFragmentInstanceId localFragmentInstanceId,
       String localPlanNodeId,
+      int indexOfUpstreamSinkHandle,
       TEndPoint remoteEndpoint,
       TFragmentInstanceId remoteFragmentInstanceId,
       IMPPDataExchangeManagerCallback<Throwable> onFailureCallback);
@@ -72,9 +73,14 @@ public interface IMPPDataExchangeManager {
   ISourceHandle createLocalSourceHandleForFragment(
       TFragmentInstanceId localFragmentInstanceId,
       String localPlanNodeId,
+      String remotePlanNodeId,
       TFragmentInstanceId remoteFragmentInstanceId,
+      int index,
       IMPPDataExchangeManagerCallback<Throwable> onFailureCallback);
 
+  /** SharedTsBlockQueue must belong to corresponding LocalSinkChannel */
+  ISourceHandle createLocalSourceHandleForPipeline(SharedTsBlockQueue queue, DriverContext context);
+
   /**
    * Release all the related resources of a fragment instance, including data blocks that are not
    * yet fetched by downstream fragment instances.
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
index 2bf11da42f..39cf36a5c0 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
@@ -23,6 +23,17 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.client.IClientManager;
 import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceClient;
 import org.apache.iotdb.db.mpp.execution.driver.DriverContext;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.LocalSinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ShuffleSinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.SinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.LocalSourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.SourceHandle;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
@@ -43,13 +54,16 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
+import static org.apache.iotdb.db.mpp.common.DataNodeEndPoints.isSameNode;
 import static org.apache.iotdb.db.mpp.common.FragmentInstanceId.createFullId;
 import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.GET_DATA_BLOCK_TASK_SERVER;
 import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.ON_ACKNOWLEDGE_DATA_BLOCK_EVENT_TASK_SERVER;
@@ -60,25 +74,9 @@ import static org.apache.iotdb.db.mpp.metric.DataExchangeCountMetricSet.SEND_NEW
 
 public class MPPDataExchangeManager implements IMPPDataExchangeManager {
 
-  private static final Logger logger = LoggerFactory.getLogger(MPPDataExchangeManager.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(MPPDataExchangeManager.class);
 
-  public interface SourceHandleListener {
-    void onFinished(ISourceHandle sourceHandle);
-
-    void onAborted(ISourceHandle sourceHandle);
-
-    void onFailure(ISourceHandle sourceHandle, Throwable t);
-  }
-
-  public interface SinkHandleListener {
-    void onFinish(ISinkHandle sinkHandle);
-
-    void onEndOfBlocks(ISinkHandle sinkHandle);
-
-    Optional<Throwable> onAborted(ISinkHandle sinkHandle);
-
-    void onFailure(ISinkHandle sinkHandle, Throwable t);
-  }
+  // region =========== MPPDataExchangeServiceImpl ===========
 
   /** Handle thrift communications. */
   class MPPDataExchangeServiceImpl implements MPPDataExchangeService.Iface {
@@ -94,21 +92,26 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
                   req.sourceFragmentInstanceId.queryId,
                   req.sourceFragmentInstanceId.fragmentId,
                   req.sourceFragmentInstanceId.instanceId))) {
-        logger.debug(
+        LOGGER.debug(
             "[ProcessGetTsBlockRequest] sequence ID in [{}, {})",
             req.getStartSequenceId(),
             req.getEndSequenceId());
-        if (!sinkHandles.containsKey(req.getSourceFragmentInstanceId())) {
+        if (!shuffleSinkHandles.containsKey(req.getSourceFragmentInstanceId())) {
           throw new TException(
               "Source fragment instance not found. Fragment instance ID: "
                   + req.getSourceFragmentInstanceId()
                   + ".");
         }
         TGetDataBlockResponse resp = new TGetDataBlockResponse();
-        SinkHandle sinkHandle = (SinkHandle) sinkHandles.get(req.getSourceFragmentInstanceId());
+        // index of the channel must be a SinkChannel
+        SinkChannel sinkChannelHandle =
+            (SinkChannel)
+                (shuffleSinkHandles
+                    .get(req.getSourceFragmentInstanceId())
+                    .getChannel(req.getIndex()));
         for (int i = req.getStartSequenceId(); i < req.getEndSequenceId(); i++) {
           try {
-            ByteBuffer serializedTsBlock = sinkHandle.getSerializedTsBlock(i);
+            ByteBuffer serializedTsBlock = sinkChannelHandle.getSerializedTsBlock(i);
             resp.addToTsBlocks(serializedTsBlock);
           } catch (IllegalStateException | IOException e) {
             throw new TException(e);
@@ -132,21 +135,23 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
                   e.sourceFragmentInstanceId.queryId,
                   e.sourceFragmentInstanceId.fragmentId,
                   e.sourceFragmentInstanceId.instanceId))) {
-        logger.debug(
+        LOGGER.debug(
             "Acknowledge data block event received, for data blocks whose sequence ID in [{}, {}) from {}.",
             e.getStartSequenceId(),
             e.getEndSequenceId(),
             e.getSourceFragmentInstanceId());
-        if (!sinkHandles.containsKey(e.getSourceFragmentInstanceId())) {
-          logger.debug(
+        if (!shuffleSinkHandles.containsKey(e.getSourceFragmentInstanceId())) {
+          LOGGER.debug(
               "received ACK event but target FragmentInstance[{}] is not found.",
               e.getSourceFragmentInstanceId());
           return;
         }
-        ((SinkHandle) sinkHandles.get(e.getSourceFragmentInstanceId()))
+        // index of the channel must be a SinkChannel
+        ((SinkChannel)
+                (shuffleSinkHandles.get(e.getSourceFragmentInstanceId()).getChannel(e.getIndex())))
             .acknowledgeTsBlock(e.getStartSequenceId(), e.getEndSequenceId());
       } catch (Throwable t) {
-        logger.warn(
+        LOGGER.warn(
             "ack TsBlock [{}, {}) failed.", e.getStartSequenceId(), e.getEndSequenceId(), t);
         throw t;
       } finally {
@@ -162,7 +167,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
       long startTime = System.nanoTime();
       try (SetThreadName fragmentInstanceName =
           new SetThreadName(createFullIdFrom(e.targetFragmentInstanceId, e.targetPlanNodeId))) {
-        logger.debug(
+        LOGGER.debug(
             "New data block event received, for plan node {} of {} from {}.",
             e.getTargetPlanNodeId(),
             e.getTargetFragmentInstanceId(),
@@ -180,7 +185,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
           // may
           // have already been stopped. For example, in the query whit LimitOperator, the downstream
           // FragmentInstance may be finished, although the upstream is still working.
-          logger.debug(
+          LOGGER.debug(
               "received NewDataBlockEvent but the downstream FragmentInstance[{}] is not found",
               e.getTargetFragmentInstanceId());
           return;
@@ -198,7 +203,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     public void onEndOfDataBlockEvent(TEndOfDataBlockEvent e) throws TException {
       try (SetThreadName fragmentInstanceName =
           new SetThreadName(createFullIdFrom(e.targetFragmentInstanceId, e.targetPlanNodeId))) {
-        logger.debug(
+        LOGGER.debug(
             "End of data block event received, for plan node {} of {} from {}.",
             e.getTargetPlanNodeId(),
             e.getTargetFragmentInstanceId(),
@@ -212,7 +217,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
                 : (SourceHandle) sourceHandleMap.get(e.getTargetPlanNodeId());
 
         if (sourceHandle == null || sourceHandle.isAborted() || sourceHandle.isFinished()) {
-          logger.debug(
+          LOGGER.debug(
               "received onEndOfDataBlockEvent but the downstream FragmentInstance[{}] is not found",
               e.getTargetFragmentInstanceId());
           return;
@@ -223,6 +228,28 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     }
   }
 
+  // endregion
+
+  // region =========== listener ===========
+
+  /** Callbacks fired on state changes of an {@link ISourceHandle}. */
+  public interface SourceHandleListener {
+    /** Invoked when the source handle finishes. */
+    void onFinished(ISourceHandle sourceHandle);
+
+    /** Invoked when the source handle is aborted. */
+    void onAborted(ISourceHandle sourceHandle);
+
+    /** Invoked when the source handle fails; {@code t} is the cause. */
+    void onFailure(ISourceHandle sourceHandle, Throwable t);
+  }
+
+  /**
+   * Callbacks fired on state changes of an {@link ISink}. In this class, {@code
+   * ShuffleSinkListenerImpl} backs a fragment instance's ShuffleSinkHandle and {@code
+   * SinkListenerImpl} backs individual sink channels (including pipeline-local ones).
+   */
+  public interface SinkListener {
+    /** Invoked when the sink has finished. */
+    void onFinish(ISink sink);
+
+    /** Invoked when the sink reaches the end of its TsBlocks. */
+    void onEndOfBlocks(ISink sink);
+
+    /**
+     * Invoked when the sink is aborted.
+     *
+     * @return the failure cause to report, if any (the implementations in this class return the
+     *     fragment instance context's failure cause)
+     */
+    Optional<Throwable> onAborted(ISink sink);
+
+    /** Invoked when the sink fails; {@code t} is the cause. */
+    void onFailure(ISink sink, Throwable t);
+  }
+
   /** Listen to the state changes of a source handle. */
   class SourceHandleListenerImpl implements SourceHandleListener {
 
@@ -234,12 +261,12 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
 
     @Override
     public void onFinished(ISourceHandle sourceHandle) {
-      logger.debug("[ScHListenerOnFinish]");
+      LOGGER.debug("[ScHListenerOnFinish]");
       Map<String, ISourceHandle> sourceHandleMap =
           sourceHandles.get(sourceHandle.getLocalFragmentInstanceId());
       if (sourceHandleMap == null
           || sourceHandleMap.remove(sourceHandle.getLocalPlanNodeId()) == null) {
-        logger.debug("[ScHListenerAlreadyReleased]");
+        LOGGER.debug("[ScHListenerAlreadyReleased]");
       }
 
       if (sourceHandleMap != null && sourceHandleMap.isEmpty()) {
@@ -249,13 +276,13 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
 
     @Override
     public void onAborted(ISourceHandle sourceHandle) {
-      logger.debug("[ScHListenerOnAbort]");
+      LOGGER.debug("[ScHListenerOnAbort]");
       onFinished(sourceHandle);
     }
 
     @Override
     public void onFailure(ISourceHandle sourceHandle, Throwable t) {
-      logger.warn("Source handle failed due to: ", t);
+      LOGGER.warn("Source handle failed due to: ", t);
       if (onFailureCallback != null) {
         onFailureCallback.call(t);
       }
@@ -277,17 +304,17 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
 
     @Override
     public void onFinished(ISourceHandle sourceHandle) {
-      logger.debug("[ScHListenerOnFinish]");
+      LOGGER.debug("[ScHListenerOnFinish]");
     }
 
     @Override
     public void onAborted(ISourceHandle sourceHandle) {
-      logger.debug("[ScHListenerOnAbort]");
+      LOGGER.debug("[ScHListenerOnAbort]");
     }
 
     @Override
     public void onFailure(ISourceHandle sourceHandle, Throwable t) {
-      logger.warn("Source handle failed due to: ", t);
+      LOGGER.warn("Source handle failed due to: ", t);
       if (onFailureCallback != null) {
         onFailureCallback.call(t);
       }
@@ -295,12 +322,12 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
   }
 
   /** Listen to the state changes of a sink handle. */
-  class SinkHandleListenerImpl implements SinkHandleListener {
+  class ShuffleSinkListenerImpl implements SinkListener {
 
     private final FragmentInstanceContext context;
     private final IMPPDataExchangeManagerCallback<Throwable> onFailureCallback;
 
-    public SinkHandleListenerImpl(
+    public ShuffleSinkListenerImpl(
         FragmentInstanceContext context,
         IMPPDataExchangeManagerCallback<Throwable> onFailureCallback) {
       this.context = context;
@@ -308,37 +335,29 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     }
 
+    /** Drops the sink's FI from {@code shuffleSinkHandles} and marks the FI context finished. */
     @Override
-    public void onFinish(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnFinish]");
-      removeFromMPPDataExchangeManager(sinkHandle);
+    public void onFinish(ISink sink) {
+      LOGGER.debug("[ShuffleSinkHandleListenerOnFinish]");
+      shuffleSinkHandles.remove(sink.getLocalFragmentInstanceId());
       context.finished();
     }
 
+    /** Moves the FI context to the flushing state once the sink reports end of blocks. */
     @Override
-    public void onEndOfBlocks(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnEndOfTsBlocks]");
+    public void onEndOfBlocks(ISink sink) {
+      LOGGER.debug("[ShuffleSinkHandleListenerOnEndOfTsBlocks]");
       context.transitionToFlushing();
     }
 
+    /**
+     * Drops the sink's FI from {@code shuffleSinkHandles} and returns the FI context's failure
+     * cause, if any.
+     */
     @Override
-    public Optional<Throwable> onAborted(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnAbort]");
-      removeFromMPPDataExchangeManager(sinkHandle);
+    public Optional<Throwable> onAborted(ISink sink) {
+      LOGGER.debug("[ShuffleSinkHandleListenerOnAbort]");
+      shuffleSinkHandles.remove(sink.getLocalFragmentInstanceId());
       return context.getFailureCause();
     }
 
-    private void removeFromMPPDataExchangeManager(ISinkHandle sinkHandle) {
-      if (sinkHandles.remove(sinkHandle.getLocalFragmentInstanceId()) == null) {
-        logger.debug("[RemoveNoSinkHandle]");
-      } else {
-        logger.debug("[RemoveSinkHandle]");
-      }
-    }
-
     @Override
-    public void onFailure(ISinkHandle sinkHandle, Throwable t) {
-      // TODO: (xingtanzjr) should we remove the sinkHandle from MPPDataExchangeManager ?
-      logger.warn("Sink handle failed due to", t);
+    public void onFailure(ISink sink, Throwable t) {
+      // TODO: (xingtanzjr) should we remove the sink from MPPDataExchangeManager ?
+      LOGGER.warn("Sink failed due to", t);
       if (onFailureCallback != null) {
         onFailureCallback.call(t);
       }
@@ -350,12 +369,12 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
    * handle doesn't equal the finish of the whole fragment, therefore we don't need to notify
    * fragment context. But if it's aborted or failed, it can lead to the total fail.
    */
-  static class PipelineSinkHandleListenerImpl implements SinkHandleListener {
+  static class SinkListenerImpl implements SinkListener {
 
     private final FragmentInstanceContext context;
     private final IMPPDataExchangeManagerCallback<Throwable> onFailureCallback;
 
-    public PipelineSinkHandleListenerImpl(
+    public SinkListenerImpl(
         FragmentInstanceContext context,
         IMPPDataExchangeManagerCallback<Throwable> onFailureCallback) {
       this.context = context;
@@ -363,37 +382,43 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     }
 
+    /** Only logs: a single channel finishing does not finish the whole fragment instance. */
     @Override
-    public void onFinish(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnFinish]");
+    public void onFinish(ISink sink) {
+      LOGGER.debug("[SkHListenerOnFinish]");
     }
 
+    /** Only logs; the FI context's state is not changed here. */
     @Override
-    public void onEndOfBlocks(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnEndOfTsBlocks]");
+    public void onEndOfBlocks(ISink sink) {
+      LOGGER.debug("[SkHListenerOnEndOfTsBlocks]");
     }
 
+    /** Leaves the manager's maps untouched; returns the FI context's failure cause, if any. */
     @Override
-    public Optional<Throwable> onAborted(ISinkHandle sinkHandle) {
-      logger.debug("[SkHListenerOnAbort]");
+    public Optional<Throwable> onAborted(ISink sink) {
+      LOGGER.debug("[SkHListenerOnAbort]");
       return context.getFailureCause();
     }
 
+    /** Logs the failure and forwards it to {@code onFailureCallback} when one was supplied. */
     @Override
-    public void onFailure(ISinkHandle sinkHandle, Throwable t) {
-      logger.warn("Sink handle failed due to", t);
+    public void onFailure(ISink sink, Throwable t) {
+      LOGGER.warn("Sink handle failed due to", t);
       if (onFailureCallback != null) {
         onFailureCallback.call(t);
       }
     }
   }
 
+  // endregion
+
+  // region =========== MPPDataExchangeManager ===========
+
   private final LocalMemoryManager localMemoryManager;
   private final Supplier<TsBlockSerde> tsBlockSerdeFactory;
   private final ExecutorService executorService;
   private final IClientManager<TEndPoint, SyncDataNodeMPPDataExchangeServiceClient>
       mppDataExchangeServiceClientManager;
   private final Map<TFragmentInstanceId, Map<String, ISourceHandle>> sourceHandles;
-  private final Map<TFragmentInstanceId, ISinkHandle> sinkHandles;
+
+  /** Each FI has only one ShuffleSinkHandle. */
+  private final Map<TFragmentInstanceId, ISinkHandle> shuffleSinkHandles;
 
   private MPPDataExchangeServiceImpl mppDataExchangeService;
 
@@ -409,7 +434,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     this.mppDataExchangeServiceClientManager =
         Validate.notNull(mppDataExchangeServiceClientManager);
     sourceHandles = new ConcurrentHashMap<>();
-    sinkHandles = new ConcurrentHashMap<>();
+    shuffleSinkHandles = new ConcurrentHashMap<>();
   }
 
   public MPPDataExchangeServiceImpl getOrCreateMPPDataExchangeServiceImpl() {
@@ -419,20 +444,16 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     return mppDataExchangeService;
   }
 
-  @Override
-  public synchronized ISinkHandle createLocalSinkHandleForFragment(
+  private synchronized ISinkChannel createLocalSinkChannel(
       TFragmentInstanceId localFragmentInstanceId,
       TFragmentInstanceId remoteFragmentInstanceId,
       String remotePlanNodeId,
+      String localPlanNodeId,
       // TODO: replace with callbacks to decouple MPPDataExchangeManager from
       // FragmentInstanceContext
       FragmentInstanceContext instanceContext) {
-    if (sinkHandles.containsKey(localFragmentInstanceId)) {
-      throw new IllegalStateException(
-          "Local sink handle for " + localFragmentInstanceId + " exists.");
-    }
 
-    logger.debug(
+    LOGGER.debug(
         "Create local sink handle to plan node {} of {} for {}",
         remotePlanNodeId,
         remoteFragmentInstanceId,
@@ -443,46 +464,40 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     LocalSourceHandle localSourceHandle =
         sourceHandleMap == null ? null : (LocalSourceHandle) sourceHandleMap.get(remotePlanNodeId);
     if (localSourceHandle != null) {
-      logger.debug("Get shared tsblock queue from local source handle");
+      LOGGER.debug("Get SharedTsBlockQueue from local source handle");
       queue =
           ((LocalSourceHandle) sourceHandles.get(remoteFragmentInstanceId).get(remotePlanNodeId))
               .getSharedTsBlockQueue();
     } else {
-      logger.debug("Create shared tsblock queue");
-      queue =
-          new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, localMemoryManager);
+      LOGGER.debug("Create SharedTsBlockQueue");
+      queue = new SharedTsBlockQueue(localFragmentInstanceId, localPlanNodeId, localMemoryManager);
     }
 
-    LocalSinkHandle localSinkHandle =
-        new LocalSinkHandle(
-            localFragmentInstanceId,
-            queue,
-            new SinkHandleListenerImpl(instanceContext, instanceContext::failed));
-    sinkHandles.put(localFragmentInstanceId, localSinkHandle);
-    return localSinkHandle;
+    return new LocalSinkChannel(
+        localFragmentInstanceId,
+        queue,
+        new SinkListenerImpl(instanceContext, instanceContext::failed));
   }
 
   /**
    * As we know the upstream and downstream node of shared queue, we don't need to put it into the
-   * sinkHandle map.
+   * sink map.
+   *
+   * @param driverContext context of the driver this channel is created for; its FI id is recorded
+   *     on the new queue and its {@code failed} method receives failures
+   * @param planNodeId plan node ID recorded on the new SharedTsBlockQueue
+   * @return a LocalSinkChannel over a freshly created SharedTsBlockQueue
    */
-  public ISinkHandle createLocalSinkHandleForPipeline(
+  public ISinkChannel createLocalSinkChannelForPipeline(
       DriverContext driverContext, String planNodeId) {
-    logger.debug("Create local sink handle for {}", driverContext.getDriverTaskID());
+    LOGGER.debug("Create local sink handle for {}", driverContext.getDriverTaskID());
     SharedTsBlockQueue queue =
         new SharedTsBlockQueue(
             driverContext.getDriverTaskID().getFragmentInstanceId().toThrift(),
             planNodeId,
             localMemoryManager);
     queue.allowAddingTsBlock();
-    return new LocalSinkHandle(
+    return new LocalSinkChannel(
         queue,
-        new PipelineSinkHandleListenerImpl(
-            driverContext.getFragmentInstanceContext(), driverContext::failed));
+        new SinkListenerImpl(driverContext.getFragmentInstanceContext(), driverContext::failed));
   }
 
-  @Override
-  public ISinkHandle createSinkHandle(
+  private ISinkChannel createSinkChannel(
       TFragmentInstanceId localFragmentInstanceId,
       TEndPoint remoteEndpoint,
       TFragmentInstanceId remoteFragmentInstanceId,
@@ -491,30 +506,85 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
       // TODO: replace with callbacks to decouple MPPDataExchangeManager from
       // FragmentInstanceContext
       FragmentInstanceContext instanceContext) {
-    if (sinkHandles.containsKey(localFragmentInstanceId)) {
-      throw new IllegalStateException("Sink handle for " + localFragmentInstanceId + " exists.");
-    }
 
-    logger.debug(
+    LOGGER.debug(
         "Create sink handle to plan node {} of {} for {}",
         remotePlanNodeId,
         remoteFragmentInstanceId,
         localFragmentInstanceId);
 
-    SinkHandle sinkHandle =
-        new SinkHandle(
-            remoteEndpoint,
-            remoteFragmentInstanceId,
-            remotePlanNodeId,
-            localPlanNodeId,
+    return new SinkChannel(
+        remoteEndpoint,
+        remoteFragmentInstanceId,
+        remotePlanNodeId,
+        localPlanNodeId,
+        localFragmentInstanceId,
+        localMemoryManager,
+        executorService,
+        tsBlockSerdeFactory.get(),
+        new SinkListenerImpl(instanceContext, instanceContext::failed),
+        mppDataExchangeServiceClientManager);
+  }
+
+  /**
+   * Creates the single {@link ShuffleSinkHandle} of a fragment instance and registers it in {@code
+   * shuffleSinkHandles}. One {@link ISinkChannel} is built per downstream location (the i-th
+   * channel from the i-th location): a {@link LocalSinkChannel} when the location's endpoint is
+   * this node, otherwise a remote {@link SinkChannel}.
+   *
+   * <p>This method is {@code synchronized}, like {@link #createLocalSourceHandleForFragment}. Local
+   * channels are created first and the handle is registered afterwards. Without the lock, a local
+   * source handle created between those two steps would not see the handle. It would then
+   * allocate its own {@link SharedTsBlockQueue` while the sink channel already holds a different
+   * one, leaving producer and consumer on separate queues. The lock also makes the duplicate check
+   * and the {@code put} atomic.
+   *
+   * @param downStreamChannelLocationList where each downstream consumer lives, one channel each
+   * @param downStreamChannelIndex channel-index cursor handed to the ShuffleSinkHandle
+   * @param shuffleStrategyEnum shuffle strategy handed to the ShuffleSinkHandle
+   * @param localFragmentInstanceId the fragment instance producing the data
+   * @param localPlanNodeId plan node ID of the local sink
+   * @param instanceContext FI context that the sink listeners notify
+   * @return the newly registered ShuffleSinkHandle
+   * @throws IllegalStateException if a ShuffleSinkHandle is already registered for the FI
+   */
+  @Override
+  public synchronized ISinkHandle createShuffleSinkHandle(
+      List<DownStreamChannelLocation> downStreamChannelLocationList,
+      DownStreamChannelIndex downStreamChannelIndex,
+      ShuffleSinkHandle.ShuffleStrategyEnum shuffleStrategyEnum,
+      TFragmentInstanceId localFragmentInstanceId,
+      String localPlanNodeId,
+      // TODO: replace with callbacks to decouple MPPDataExchangeManager from
+      // FragmentInstanceContext
+      FragmentInstanceContext instanceContext) {
+    if (shuffleSinkHandles.containsKey(localFragmentInstanceId)) {
+      throw new IllegalStateException(
+          "ShuffleSinkHandle for " + localFragmentInstanceId + " exists.");
+    }
+
+    // One channel per downstream location; list order is preserved, so channel i <-> location i.
+    List<ISinkChannel> downStreamChannelList =
+        downStreamChannelLocationList.stream()
+            .map(
+                downStreamChannelLocation ->
+                    createChannelForShuffleSink(
+                        localFragmentInstanceId,
+                        localPlanNodeId,
+                        downStreamChannelLocation,
+                        instanceContext))
+            .collect(Collectors.toList());
+
+    ShuffleSinkHandle shuffleSinkHandle =
+        new ShuffleSinkHandle(
+            localFragmentInstanceId,
+            downStreamChannelList,
+            downStreamChannelIndex,
+            shuffleStrategyEnum,
+            localPlanNodeId,
+            new ShuffleSinkListenerImpl(instanceContext, instanceContext::failed));
+    // Registered while still holding the lock, so local source handles see it atomically.
+    shuffleSinkHandles.put(localFragmentInstanceId, shuffleSinkHandle);
+    return shuffleSinkHandle;
+  }
+
+  /**
+   * Builds the channel that feeds one downstream location of a ShuffleSinkHandle: a local channel
+   * when {@code isSameNode} reports the location's endpoint as this node, a remote SinkChannel
+   * otherwise.
+   */
+  private ISinkChannel createChannelForShuffleSink(
+      TFragmentInstanceId localFragmentInstanceId,
+      String localPlanNodeId,
+      DownStreamChannelLocation downStreamChannelLocation,
+      FragmentInstanceContext instanceContext) {
+    TEndPoint targetEndpoint = downStreamChannelLocation.getRemoteEndpoint();
+    TFragmentInstanceId targetInstanceId = downStreamChannelLocation.getRemoteFragmentInstanceId();
+    String targetPlanNodeId = downStreamChannelLocation.getRemotePlanNodeId();
+
+    if (!isSameNode(targetEndpoint)) {
+      // The consumer lives on another node: use a remote channel.
+      return createSinkChannel(
+          localFragmentInstanceId,
+          targetEndpoint,
+          targetInstanceId,
+          targetPlanNodeId,
+          localPlanNodeId,
+          instanceContext);
+    }
+    // Same node: use a local channel backed by a SharedTsBlockQueue.
+    return createLocalSinkChannel(
+        localFragmentInstanceId,
+        targetInstanceId,
+        targetPlanNodeId,
+        localPlanNodeId,
+        instanceContext);
+  }
 
   /**
@@ -523,18 +593,19 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
    */
   public ISourceHandle createLocalSourceHandleForPipeline(
       SharedTsBlockQueue queue, DriverContext context) {
-    logger.debug("Create local source handle for {}", context.getDriverTaskID());
+    LOGGER.debug("Create local source handle for {}", context.getDriverTaskID());
+    // Not registered in sourceHandles; failures are routed to the driver context instead.
     return new LocalSourceHandle(
         queue,
         new PipelineSourceHandleListenerImpl(context::failed),
         context.getDriverTaskID().toString());
   }
 
-  @Override
   public synchronized ISourceHandle createLocalSourceHandleForFragment(
       TFragmentInstanceId localFragmentInstanceId,
       String localPlanNodeId,
+      String remotePlanNodeId,
       TFragmentInstanceId remoteFragmentInstanceId,
+      int index,
       IMPPDataExchangeManagerCallback<Throwable> onFailureCallback) {
     if (sourceHandles.containsKey(localFragmentInstanceId)
         && sourceHandles.get(localFragmentInstanceId).containsKey(localPlanNodeId)) {
@@ -546,18 +617,21 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
               + " exists.");
     }
 
-    logger.debug(
+    LOGGER.debug(
         "Create local source handle from {} for plan node {} of {}",
         remoteFragmentInstanceId,
         localPlanNodeId,
         localFragmentInstanceId);
     SharedTsBlockQueue queue;
-    if (sinkHandles.containsKey(remoteFragmentInstanceId)) {
-      logger.debug("Get shared tsblock queue from local sink handle");
-      queue = ((LocalSinkHandle) sinkHandles.get(remoteFragmentInstanceId)).getSharedTsBlockQueue();
+    if (shuffleSinkHandles.containsKey(remoteFragmentInstanceId)) {
+      LOGGER.debug("Get SharedTsBlockQueue from local sink handle");
+      queue =
+          ((LocalSinkChannel) shuffleSinkHandles.get(remoteFragmentInstanceId).getChannel(index))
+              .getSharedTsBlockQueue();
     } else {
-      logger.debug("Create shared tsblock queue");
-      queue = new SharedTsBlockQueue(localFragmentInstanceId, localPlanNodeId, localMemoryManager);
+      LOGGER.debug("Create SharedTsBlockQueue");
+      queue =
+          new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, localMemoryManager);
     }
     LocalSourceHandle localSourceHandle =
         new LocalSourceHandle(
@@ -575,6 +649,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
   public ISourceHandle createSourceHandle(
       TFragmentInstanceId localFragmentInstanceId,
       String localPlanNodeId,
+      int indexOfUpstreamSinkHandle,
       TEndPoint remoteEndpoint,
       TFragmentInstanceId remoteFragmentInstanceId,
       IMPPDataExchangeManagerCallback<Throwable> onFailureCallback) {
@@ -588,7 +663,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
               + " exists.");
     }
 
-    logger.debug(
+    LOGGER.debug(
         "Create source handle from {} for plan node {} of {}",
         remoteFragmentInstanceId,
         localPlanNodeId,
@@ -600,6 +675,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            indexOfUpstreamSinkHandle,
             localMemoryManager,
             executorService,
             tsBlockSerdeFactory.get(),
@@ -618,21 +694,21 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
    * <p>This method should be called when a fragment instance finished in an abnormal state.
    */
   public void forceDeregisterFragmentInstance(TFragmentInstanceId fragmentInstanceId) {
-    logger.debug("[StartForceReleaseFIDataExchangeResource]");
-    ISinkHandle sinkHandle = sinkHandles.get(fragmentInstanceId);
+    LOGGER.debug("[StartForceReleaseFIDataExchangeResource]");
+    // Abort and unregister the FI's ShuffleSinkHandle, if one was created.
+    ISink sinkHandle = shuffleSinkHandles.get(fragmentInstanceId);
     if (sinkHandle != null) {
       sinkHandle.abort();
-      sinkHandles.remove(fragmentInstanceId);
+      shuffleSinkHandles.remove(fragmentInstanceId);
     }
+    // Abort every source handle of the FI, then drop the FI's whole source-handle map.
     Map<String, ISourceHandle> planNodeIdToSourceHandle = sourceHandles.get(fragmentInstanceId);
     if (planNodeIdToSourceHandle != null) {
       for (Entry<String, ISourceHandle> entry : planNodeIdToSourceHandle.entrySet()) {
-        logger.debug("[CloseSourceHandle] {}", entry.getKey());
+        LOGGER.debug("[CloseSourceHandle] {}", entry.getKey());
         entry.getValue().abort();
       }
       sourceHandles.remove(fragmentInstanceId);
     }
-    logger.debug("[EndForceReleaseFIDataExchangeResource]");
+    LOGGER.debug("[EndForceReleaseFIDataExchangeResource]");
   }
 
   /** @param suffix should be like [PlanNodeId].SourceHandle/SinHandle */
@@ -644,4 +720,5 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
         + "."
         + suffix;
   }
+  // endregion
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
index 1574d8be3d..b62fffc8c6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
@@ -21,6 +21,8 @@ package org.apache.iotdb.db.mpp.execution.exchange;
 
 import org.apache.iotdb.db.conf.IoTDBDescriptor;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.LocalSinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.source.LocalSourceHandle;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
@@ -44,7 +46,7 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
 @NotThreadSafe
 public class SharedTsBlockQueue {
 
-  private static final Logger logger = LoggerFactory.getLogger(SharedTsBlockQueue.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(SharedTsBlockQueue.class);
 
   private final TFragmentInstanceId localFragmentInstanceId;
 
@@ -73,7 +75,7 @@ public class SharedTsBlockQueue {
   private boolean closed = false;
 
   private LocalSourceHandle sourceHandle;
-  private LocalSinkHandle sinkHandle;
+  private LocalSinkChannel sinkChannel;
 
   private long maxBytesCanReserve =
       IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
@@ -129,8 +131,12 @@ public class SharedTsBlockQueue {
     return queue.isEmpty();
   }
 
-  public void setSinkHandle(LocalSinkHandle sinkHandle) {
-    this.sinkHandle = sinkHandle;
+  /** Returns the number of TsBlocks currently buffered in this queue. */
+  public int getNumOfBufferedTsBlocks() {
+    return queue.size();
+  }
+
+  /**
+   * Sets the producing LocalSinkChannel; its {@code checkAndInvokeOnFinished} is called each time a
+   * TsBlock is removed from the queue.
+   */
+  public void setSinkChannel(LocalSinkChannel sinkChannel) {
+    this.sinkChannel = sinkChannel;
   }
 
   public void setSourceHandle(LocalSourceHandle sourceHandle) {
@@ -139,9 +145,9 @@ public class SharedTsBlockQueue {
 
   /** Notify no more tsblocks will be added to the queue. */
   public void setNoMoreTsBlocks(boolean noMoreTsBlocks) {
-    logger.debug("[SignalNoMoreTsBlockOnQueue]");
+    LOGGER.debug("[SignalNoMoreTsBlockOnQueue]");
     if (closed) {
-      logger.warn("queue has been destroyed");
+      LOGGER.warn("queue has been destroyed");
       return;
     }
     this.noMoreTsBlocks = noMoreTsBlocks;
@@ -163,9 +169,9 @@ public class SharedTsBlockQueue {
     }
     TsBlock tsBlock = queue.remove();
     // Every time LocalSourceHandle consumes a TsBlock, it needs to send the event to
-    // corresponding LocalSinkHandle.
-    if (sinkHandle != null) {
-      sinkHandle.checkAndInvokeOnFinished();
+    // corresponding LocalSinkChannel.
+    if (sinkChannel != null) {
+      sinkChannel.checkAndInvokeOnFinished();
     }
     localMemoryManager
         .getQueryPool()
@@ -187,7 +193,7 @@ public class SharedTsBlockQueue {
    */
   public ListenableFuture<Void> add(TsBlock tsBlock) {
     if (closed) {
-      logger.warn("queue has been destroyed");
+      LOGGER.warn("queue has been destroyed");
       return immediateVoidFuture();
     }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelIndex.java
similarity index 51%
copy from server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
copy to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelIndex.java
index ba3ddb400c..cd36b18ae6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelIndex.java
@@ -17,28 +17,21 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.plan.planner.distribution;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
-import org.apache.iotdb.db.mpp.common.MPPQueryContext;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+public class DownStreamChannelIndex {
+  /** CurrentIndex of downstream ISourceHandle */
+  private int currentIndex;
 
-import java.util.HashMap;
-import java.util.Map;
-
-public class NodeGroupContext {
-  protected MPPQueryContext queryContext;
-  protected Map<PlanNodeId, NodeDistribution> nodeDistributionMap;
-
-  public NodeGroupContext(MPPQueryContext queryContext) {
-    this.queryContext = queryContext;
-    this.nodeDistributionMap = new HashMap<>();
+  public DownStreamChannelIndex(int currentIndex) {
+    this.currentIndex = currentIndex;
   }
 
-  public void putNodeDistribution(PlanNodeId nodeId, NodeDistribution distribution) {
-    this.nodeDistributionMap.put(nodeId, distribution);
+  public int getCurrentIndex() {
+    return currentIndex;
   }
 
-  public NodeDistribution getNodeDistribution(PlanNodeId nodeId) {
-    return this.nodeDistributionMap.get(nodeId);
+  public void setCurrentIndex(int currentIndex) {
+    this.currentIndex = currentIndex;
   }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelLocation.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelLocation.java
new file mode 100644
index 0000000000..106dce0c72
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/DownStreamChannelLocation.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
+
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * Where a downstream channel sends its data blocks: the remote endpoint, the remote fragment
+ * instance and the plan node ID of the remote exchange node.
+ */
+public class DownStreamChannelLocation {
+
+  // The endpoint and fragment instance ID may be null until the downstream FI is produced; they
+  // are then filled in through the setters.
+  private TEndPoint remoteEndpoint;
+  private TFragmentInstanceId remoteFragmentInstanceId;
+
+  private final String remotePlanNodeId;
+
+  /**
+   * @param remoteEndpoint Hostname and Port of the remote fragment instance where the data blocks
+   *     should be sent to.
+   * @param remoteFragmentInstanceId The ID of the remote fragment instance.
+   * @param remotePlanNodeId The plan node ID of the remote exchangeNode.
+   */
+  public DownStreamChannelLocation(
+      TEndPoint remoteEndpoint,
+      TFragmentInstanceId remoteFragmentInstanceId,
+      String remotePlanNodeId) {
+    this.remoteEndpoint = remoteEndpoint;
+    this.remoteFragmentInstanceId = remoteFragmentInstanceId;
+    this.remotePlanNodeId = remotePlanNodeId;
+  }
+
+  /**
+   * Create a location whose endpoint and fragment instance ID are filled in later through the
+   * setters.
+   *
+   * @param remotePlanNodeId The plan node ID of the remote exchangeNode.
+   */
+  public DownStreamChannelLocation(String remotePlanNodeId) {
+    this.remoteEndpoint = null;
+    this.remoteFragmentInstanceId = null;
+    this.remotePlanNodeId = remotePlanNodeId;
+  }
+
+  public void setRemoteEndpoint(TEndPoint remoteEndpoint) {
+    this.remoteEndpoint = remoteEndpoint;
+  }
+
+  public void setRemoteFragmentInstanceId(TFragmentInstanceId remoteFragmentInstanceId) {
+    this.remoteFragmentInstanceId = remoteFragmentInstanceId;
+  }
+
+  public TEndPoint getRemoteEndpoint() {
+    return remoteEndpoint;
+  }
+
+  public TFragmentInstanceId getRemoteFragmentInstanceId() {
+    return remoteFragmentInstanceId;
+  }
+
+  public String getRemotePlanNodeId() {
+    return remotePlanNodeId;
+  }
+
+  /**
+   * Write this location to the buffer in the layout read by {@link #deserialize(ByteBuffer)}.
+   *
+   * @throws IllegalStateException if the endpoint or fragment instance ID has not been set.
+   */
+  public void serialize(ByteBuffer byteBuffer) {
+    checkRemoteInfoSet();
+    ReadWriteIOUtils.write(remoteEndpoint.getIp(), byteBuffer);
+    ReadWriteIOUtils.write(remoteEndpoint.getPort(), byteBuffer);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getQueryId(), byteBuffer);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getFragmentId(), byteBuffer);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getInstanceId(), byteBuffer);
+
+    ReadWriteIOUtils.write(remotePlanNodeId, byteBuffer);
+  }
+
+  /**
+   * Write this location to the stream in the same layout as {@link #serialize(ByteBuffer)}.
+   *
+   * @throws IllegalStateException if the endpoint or fragment instance ID has not been set.
+   */
+  public void serialize(DataOutputStream stream) throws IOException {
+    checkRemoteInfoSet();
+    ReadWriteIOUtils.write(remoteEndpoint.getIp(), stream);
+    ReadWriteIOUtils.write(remoteEndpoint.getPort(), stream);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getQueryId(), stream);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getFragmentId(), stream);
+    ReadWriteIOUtils.write(remoteFragmentInstanceId.getInstanceId(), stream);
+
+    ReadWriteIOUtils.write(remotePlanNodeId, stream);
+  }
+
+  /** Read a location written by one of the serialize methods. */
+  public static DownStreamChannelLocation deserialize(ByteBuffer byteBuffer) {
+    TEndPoint endPoint =
+        new TEndPoint(
+            ReadWriteIOUtils.readString(byteBuffer), ReadWriteIOUtils.readInt(byteBuffer));
+    TFragmentInstanceId fragmentInstanceId =
+        new TFragmentInstanceId(
+            ReadWriteIOUtils.readString(byteBuffer),
+            ReadWriteIOUtils.readInt(byteBuffer),
+            ReadWriteIOUtils.readString(byteBuffer));
+    String remotePlanNodeId = ReadWriteIOUtils.readString(byteBuffer);
+    return new DownStreamChannelLocation(endPoint, fragmentInstanceId, remotePlanNodeId);
+  }
+
+  /** Fail with a descriptive message, rather than an NPE, if the location is incomplete. */
+  private void checkRemoteInfoSet() {
+    if (remoteEndpoint == null || remoteFragmentInstanceId == null) {
+      throw new IllegalStateException(
+          "DownStreamChannelLocation of plan node "
+              + remotePlanNodeId
+              + " is serialized before its endpoint and fragment instance ID are set");
+    }
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISink.java
similarity index 59%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISink.java
index 4bceda0e15..5af20abf60 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISink.java
@@ -16,71 +16,68 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 
 import com.google.common.util.concurrent.ListenableFuture;
 
-import java.util.List;
-
-public interface ISinkHandle {
+/**
+ * Base interface of {@link ISinkChannel} and {@link ISinkHandle}. It defines the operations
+ * needed to transfer TsBlocks to a downstream ISourceHandle.
+ */
+public interface ISink {
 
-  /** Get the local fragment instance ID that this sink handle belongs to. */
+  /** Get the local fragment instance ID that this ISink belongs to. */
   TFragmentInstanceId getLocalFragmentInstanceId();
 
-  /** Get the total amount of memory used by buffered tsblocks. */
+  /** Get the total amount of memory used by buffered TsBlocks. */
   long getBufferRetainedSizeInBytes();
 
   /** Get a future that will be completed when the output buffer is not full. */
   ListenableFuture<?> isFull();
 
   /**
-   * Send a list of tsblocks to an unpartitioned output buffer. If no-more-tsblocks has been set,
+   * Send a {@link TsBlock} to an un-partitioned output buffer. If no-more-TsBlocks has been set,
    * the invocation will be ignored. This can happen with limit queries. A {@link RuntimeException}
    * will be thrown if any exception happened during the data transmission.
    */
   void send(TsBlock tsBlock);
 
   /**
-   * Send a {@link TsBlock} to a specific partition. If no-more-tsblocks has been set, the send
-   * tsblock call is ignored. This can happen with limit queries. A {@link RuntimeException} will be
-   * thrown if any exception happened * during the data transmission.
-   */
-  void send(int partition, List<TsBlock> tsBlocks);
-
-  /**
-   * Notify the handle that there are no more tsblocks. Any future calls to send a tsblock should be
+   * Notify the ISink that there are no more TsBlocks. Any future calls to send a TsBlock should be
    * ignored.
    */
   void setNoMoreTsBlocks();
 
-  /** If the handle is aborted. */
+  /** Whether the ISink has been aborted. */
   boolean isAborted();
 
   /**
-   * If there are no more tsblocks to be sent and all the tsblocks have been fetched by downstream
+   * If there are no more TsBlocks to be sent and all the TsBlocks have been fetched by downstream
    * fragment instances.
    */
   boolean isFinished();
 
   /**
-   * Abort the sink handle. Discard all tsblocks which may still be in the memory buffer and cancel
-   * the future returned by {@link #isFull()}.
+   * Abort the ISink. If this is an ISinkHandle, we should abort all its channels. If this is an
+   * ISinkChannel, we discard all TsBlocks which may still be in the memory buffer and cancel the
+   * future returned by {@link #isFull()}.
    *
    * <p>Should only be called in abnormal case
    */
   void abort();
 
   /**
-   * Close the sink handle. Discard all tsblocks which may still be in the memory buffer and
-   * complete the future returned by {@link #isFull()}.
+   * Close the ISink. If this is an ISinkHandle, we should close all its channels. If this is an
+   * ISinkChannel, we discard all TsBlocks which may still be in the memory buffer and complete the
+   * future returned by {@link #isFull()}.
    *
    * <p>Should only be called in normal case.
    */
   void close();
 
-  /** Set max bytes this handle can reserve from memory pool */
+  /** Set the max bytes this ISink can reserve from the memory pool. */
   void setMaxBytesCanReserve(long maxBytesCanReserve);
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkChannel.java
similarity index 50%
copy from server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
copy to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkChannel.java
index ba3ddb400c..44c74479ec 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkChannel.java
@@ -17,28 +17,16 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.plan.planner.distribution;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
-import org.apache.iotdb.db.mpp.common.MPPQueryContext;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+public interface ISinkChannel extends ISink {
 
-import java.util.HashMap;
-import java.util.Map;
+  /** Open the channel. Should only be called once on each ISinkChannel. */
+  void open();
 
-public class NodeGroupContext {
-  protected MPPQueryContext queryContext;
-  protected Map<PlanNodeId, NodeDistribution> nodeDistributionMap;
+  /** Return true if no more TsBlocks will be sent through this channel. */
+  boolean isNoMoreTsBlocks();
 
-  public NodeGroupContext(MPPQueryContext queryContext) {
-    this.queryContext = queryContext;
-    this.nodeDistributionMap = new HashMap<>();
-  }
-
-  public void putNodeDistribution(PlanNodeId nodeId, NodeDistribution distribution) {
-    this.nodeDistributionMap.put(nodeId, distribution);
-  }
-
-  public NodeDistribution getNodeDistribution(PlanNodeId nodeId) {
-    return this.nodeDistributionMap.get(nodeId);
-  }
+  /** Return the number of TsBlocks currently buffered in this channel. */
+  int getNumOfBufferedTsBlocks();
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkHandle.java
similarity index 50%
copy from server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
copy to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkHandle.java
index ba3ddb400c..c600f60eb6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ISinkHandle.java
@@ -17,28 +17,20 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.plan.planner.distribution;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
-import org.apache.iotdb.db.mpp.common.MPPQueryContext;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+public interface ISinkHandle extends ISink {
+  /** Get the channel at the specified index. */
+  ISinkChannel getChannel(int index);
 
-import java.util.HashMap;
-import java.util.Map;
+  /**
+   * Notify the handle that there are no more TsBlocks for the specified channel. Any future calls
+   * to send a TsBlock to the specified channel should be ignored.
+   *
+   * @param channelIndex index of the channel that will receive no more TsBlocks
+   */
+  void setNoMoreTsBlocksOfOneChannel(int channelIndex);
 
-public class NodeGroupContext {
-  protected MPPQueryContext queryContext;
-  protected Map<PlanNodeId, NodeDistribution> nodeDistributionMap;
-
-  public NodeGroupContext(MPPQueryContext queryContext) {
-    this.queryContext = queryContext;
-    this.nodeDistributionMap = new HashMap<>();
-  }
-
-  public void putNodeDistribution(PlanNodeId nodeId, NodeDistribution distribution) {
-    this.nodeDistributionMap.put(nodeId, distribution);
-  }
-
-  public NodeDistribution getNodeDistribution(PlanNodeId nodeId) {
-    return this.nodeDistributionMap.get(nodeId);
-  }
+  /** Open the channel at the specified index if it has not been opened yet. */
+  void tryOpenChannel(int channelIndex);
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/LocalSinkChannel.java
similarity index 69%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/LocalSinkChannel.java
index 8afb07fe6e..29ff8b7a12 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/LocalSinkChannel.java
@@ -17,9 +17,10 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
-import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkListener;
+import org.apache.iotdb.db.mpp.execution.exchange.SharedTsBlockQueue;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
@@ -29,18 +30,17 @@ import org.apache.commons.lang3.Validate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.List;
 import java.util.Optional;
 
 import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
 import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.SINK_HANDLE_SEND_TSBLOCK_LOCAL;
 
-public class LocalSinkHandle implements ISinkHandle {
+public class LocalSinkChannel implements ISinkChannel {
 
-  private static final Logger logger = LoggerFactory.getLogger(LocalSinkHandle.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(LocalSinkChannel.class);
 
   private TFragmentInstanceId localFragmentInstanceId;
-  private final SinkHandleListener sinkHandleListener;
+  private final SinkListener sinkListener;
 
   private final SharedTsBlockQueue queue;
   private volatile ListenableFuture<Void> blocked;
@@ -49,22 +49,22 @@ public class LocalSinkHandle implements ISinkHandle {
 
   private static final QueryMetricsManager QUERY_METRICS = QueryMetricsManager.getInstance();
 
-  public LocalSinkHandle(SharedTsBlockQueue queue, SinkHandleListener sinkHandleListener) {
-    this.sinkHandleListener = Validate.notNull(sinkHandleListener);
+  public LocalSinkChannel(SharedTsBlockQueue queue, SinkListener sinkListener) {
+    this.sinkListener = Validate.notNull(sinkListener);
     this.queue = Validate.notNull(queue);
-    this.queue.setSinkHandle(this);
+    this.queue.setSinkChannel(this);
     blocked = queue.getCanAddTsBlock();
   }
 
-  public LocalSinkHandle(
+  public LocalSinkChannel(
       TFragmentInstanceId localFragmentInstanceId,
       SharedTsBlockQueue queue,
-      SinkHandleListener sinkHandleListener) {
+      SinkListener sinkListener) {
     this.localFragmentInstanceId = Validate.notNull(localFragmentInstanceId);
-    this.sinkHandleListener = Validate.notNull(sinkHandleListener);
+    this.sinkListener = Validate.notNull(sinkListener);
     this.queue = Validate.notNull(queue);
-    this.queue.setSinkHandle(this);
-    // SinkHandle can send data after SourceHandle asks it to
+    this.queue.setSinkChannel(this);
+    // SinkChannel can send data after SourceHandle asks it to
     blocked = queue.getCanAddTsBlock();
   }
 
@@ -75,7 +75,9 @@ public class LocalSinkHandle implements ISinkHandle {
 
   @Override
   public long getBufferRetainedSizeInBytes() {
-    return queue.getBufferRetainedSizeInBytes();
+    synchronized (queue) {
+      return queue.getBufferRetainedSizeInBytes();
+    }
   }
 
   @Override
@@ -100,7 +102,7 @@ public class LocalSinkHandle implements ISinkHandle {
     synchronized (queue) {
       if (isFinished()) {
         synchronized (this) {
-          sinkHandleListener.onFinish(this);
+          sinkListener.onFinish(this);
         }
       }
     }
@@ -122,7 +124,7 @@ public class LocalSinkHandle implements ISinkHandle {
         if (queue.hasNoMoreTsBlocks()) {
           return;
         }
-        logger.debug("[StartSendTsBlockOnLocal]");
+        LOGGER.debug("[StartSendTsBlockOnLocal]");
         synchronized (this) {
           blocked = queue.add(tsBlock);
         }
@@ -133,37 +135,32 @@ public class LocalSinkHandle implements ISinkHandle {
       }
     }
   }
 
-  @Override
-  public synchronized void send(int partition, List<TsBlock> tsBlocks) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public void setNoMoreTsBlocks() {
     synchronized (queue) {
       synchronized (this) {
-        logger.debug("[StartSetNoMoreTsBlocksOnLocal]");
+        LOGGER.debug("[StartSetNoMoreTsBlocksOnLocal]");
         if (aborted || closed) {
           return;
         }
         queue.setNoMoreTsBlocks(true);
-        sinkHandleListener.onEndOfBlocks(this);
+        sinkListener.onEndOfBlocks(this);
       }
     }
     checkAndInvokeOnFinished();
-    logger.debug("[EndSetNoMoreTsBlocksOnLocal]");
+    LOGGER.debug("[EndSetNoMoreTsBlocksOnLocal]");
   }
 
   @Override
   public void abort() {
-    logger.debug("[StartAbortLocalSinkHandle]");
+    LOGGER.debug("[StartAbortLocalSinkChannel]");
     synchronized (queue) {
       synchronized (this) {
         if (aborted || closed) {
           return;
         }
         aborted = true;
-        Optional<Throwable> t = sinkHandleListener.onAborted(this);
+        Optional<Throwable> t = sinkListener.onAborted(this);
         if (t.isPresent()) {
           queue.abort(t.get());
         } else {
@@ -171,12 +168,12 @@ public class LocalSinkHandle implements ISinkHandle {
         }
       }
     }
-    logger.debug("[EndAbortLocalSinkHandle]");
+    LOGGER.debug("[EndAbortLocalSinkChannel]");
   }
 
   @Override
   public void close() {
-    logger.debug("[StartCloseLocalSinkHandle]");
+    LOGGER.debug("[StartCloseLocalSinkChannel]");
     synchronized (queue) {
       synchronized (this) {
         if (aborted || closed) {
@@ -184,10 +181,10 @@ public class LocalSinkHandle implements ISinkHandle {
         }
         closed = true;
         queue.close();
-        sinkHandleListener.onFinish(this);
+        sinkListener.onFinish(this);
       }
     }
-    logger.debug("[EndCloseLocalSinkHandle]");
+    LOGGER.debug("[EndCloseLocalSinkChannel]");
   }
 
   public SharedTsBlockQueue getSharedTsBlockQueue() {
@@ -196,15 +193,37 @@ public class LocalSinkHandle implements ISinkHandle {
 
   private void checkState() {
     if (aborted) {
-      throw new IllegalStateException("Sink handle is aborted.");
+      throw new IllegalStateException("LocalSinkChannel is aborted.");
     } else if (closed) {
-      throw new IllegalStateException("Sink Handle is closed.");
+      throw new IllegalStateException("LocalSinkChannel is closed.");
     }
   }
 
   @Override
   public void setMaxBytesCanReserve(long maxBytesCanReserve) {
-    // do nothing, the maxBytesCanReserve of SharedTsBlockQueue should be set by corresponding
-    // LocalSourceHandle
+    if (maxBytesCanReserve < queue.getMaxBytesCanReserve()) {
+      queue.setMaxBytesCanReserve(maxBytesCanReserve);
+    }
+  }
+
+  // region ============ ISinkChannel related ============
+
+  @Override
+  public void open() {}
+
+  @Override
+  public boolean isNoMoreTsBlocks() {
+    synchronized (queue) {
+      return queue.hasNoMoreTsBlocks();
+    }
+  }
+
+  @Override
+  public int getNumOfBufferedTsBlocks() {
+    synchronized (queue) {
+      return queue.getNumOfBufferedTsBlocks();
+    }
   }
+
+  // endregion
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ShuffleSinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ShuffleSinkHandle.java
new file mode 100644
index 0000000000..9b18f4f771
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/ShuffleSinkHandle.java
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
+
+import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager;
+import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.commons.lang3.Validate;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+
+import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.SINK_HANDLE_SEND_TSBLOCK_REMOTE;
+
+/**
+ * An {@link ISinkHandle} that owns several downstream {@link ISinkChannel}s and, after every
+ * {@link TsBlock} sent, lets a shuffle strategy ({@link ShuffleStrategyEnum}) choose the channel
+ * for the next one.
+ */
+public class ShuffleSinkHandle implements ISinkHandle {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ShuffleSinkHandle.class);
+
+  /** Each ISinkChannel in the list matches one downstream ISourceHandle. */
+  private final List<ISinkChannel> downStreamChannelList;
+
+  /** Whether setNoMoreTsBlocks() has already been called on the channel at each index. */
+  private final boolean[] hasSetNoMoreTsBlocks;
+
+  /** Whether the channel at each index has already been opened. */
+  private final boolean[] channelOpened;
+
+  /** Index of the channel that the next TsBlock is sent to. */
+  private final DownStreamChannelIndex downStreamChannelIndex;
+
+  private final int channelNum;
+
+  private final ShuffleStrategy shuffleStrategy;
+
+  private final String localPlanNodeId;
+
+  private final TFragmentInstanceId localFragmentInstanceId;
+
+  private final MPPDataExchangeManager.SinkListener sinkListener;
+
+  private boolean aborted = false;
+
+  private boolean closed = false;
+
+  private static final QueryMetricsManager QUERY_METRICS = QueryMetricsManager.getInstance();
+
+  /**
+   * Max bytes this ShuffleSinkHandle can reserve. Volatile because setMaxBytesCanReserve() writes
+   * it without holding the lock under which the round-robin strategy reads it.
+   */
+  private volatile long maxBytesCanReserve =
+      IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
+
+  /**
+   * @param localFragmentInstanceId ID of the fragment instance this handle belongs to.
+   * @param downStreamChannelList one channel per downstream ISourceHandle; must not be empty.
+   * @param downStreamChannelIndex holder of the index of the channel currently written to.
+   * @param shuffleStrategyEnum strategy used to choose the next channel after each send.
+   * @param localPlanNodeId ID of the local plan node.
+   * @param sinkListener notified on end of blocks, finish and abort of this handle.
+   * @throws IllegalArgumentException if {@code downStreamChannelList} is empty.
+   */
+  public ShuffleSinkHandle(
+      TFragmentInstanceId localFragmentInstanceId,
+      List<ISinkChannel> downStreamChannelList,
+      DownStreamChannelIndex downStreamChannelIndex,
+      ShuffleStrategyEnum shuffleStrategyEnum,
+      String localPlanNodeId,
+      MPPDataExchangeManager.SinkListener sinkListener) {
+    this.localFragmentInstanceId = Validate.notNull(localFragmentInstanceId);
+    this.downStreamChannelList = Validate.notNull(downStreamChannelList);
+    // Fail fast here instead of with an index-out-of-bounds error in tryOpenChannel(0) below.
+    Validate.isTrue(
+        !downStreamChannelList.isEmpty(), "ShuffleSinkHandle needs at least one channel");
+    this.downStreamChannelIndex = Validate.notNull(downStreamChannelIndex);
+    this.localPlanNodeId = Validate.notNull(localPlanNodeId);
+    this.sinkListener = Validate.notNull(sinkListener);
+    this.channelNum = downStreamChannelList.size();
+    this.shuffleStrategy = getShuffleStrategy(shuffleStrategyEnum);
+    this.hasSetNoMoreTsBlocks = new boolean[channelNum];
+    this.channelOpened = new boolean[channelNum];
+    // open first channel
+    tryOpenChannel(0);
+  }
+
+  @Override
+  public TFragmentInstanceId getLocalFragmentInstanceId() {
+    return localFragmentInstanceId;
+  }
+
+  @Override
+  public ISinkChannel getChannel(int index) {
+    return downStreamChannelList.get(index);
+  }
+
+  @Override
+  public synchronized ListenableFuture<?> isFull() {
+    // It is safe to use currentChannel.isFull() to judge whether we can send a TsBlock only when
+    // downStreamChannelIndex will not be changed between we call isFull() and send() of
+    // ShuffleSinkHandle
+    ISinkChannel currentChannel =
+        downStreamChannelList.get(downStreamChannelIndex.getCurrentIndex());
+    return currentChannel.isFull();
+  }
+
+  /**
+   * Send the TsBlock to the current channel, then let the shuffle strategy choose (and open) the
+   * channel for the next TsBlock.
+   *
+   * @throws IllegalStateException if this handle has been aborted or closed.
+   */
+  @Override
+  public synchronized void send(TsBlock tsBlock) {
+    // Check the state outside the try block: an aborted or closed handle must not switch to, and
+    // thereby open, another channel in the finally block.
+    checkState();
+    long startTime = System.nanoTime();
+    try {
+      ISinkChannel currentChannel =
+          downStreamChannelList.get(downStreamChannelIndex.getCurrentIndex());
+      currentChannel.send(tsBlock);
+    } finally {
+      switchChannelIfNecessary();
+      QUERY_METRICS.recordDataExchangeCost(
+          SINK_HANDLE_SEND_TSBLOCK_REMOTE, System.nanoTime() - startTime);
+    }
+  }
+
+  /** Forward setNoMoreTsBlocks() to every channel that has not received it yet. */
+  @Override
+  public synchronized void setNoMoreTsBlocks() {
+    for (int i = 0; i < downStreamChannelList.size(); i++) {
+      if (!hasSetNoMoreTsBlocks[i]) {
+        downStreamChannelList.get(i).setNoMoreTsBlocks();
+        hasSetNoMoreTsBlocks[i] = true;
+      }
+    }
+    sinkListener.onEndOfBlocks(this);
+  }
+
+  @Override
+  public synchronized void setNoMoreTsBlocksOfOneChannel(int channelIndex) {
+    if (!hasSetNoMoreTsBlocks[channelIndex]) {
+      downStreamChannelList.get(channelIndex).setNoMoreTsBlocks();
+      hasSetNoMoreTsBlocks[channelIndex] = true;
+    }
+  }
+
+  @Override
+  public synchronized boolean isAborted() {
+    return aborted;
+  }
+
+  /** This handle is finished only when every one of its channels is finished. */
+  @Override
+  public synchronized boolean isFinished() {
+    for (ISink channel : downStreamChannelList) {
+      if (!channel.isFinished()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  @Override
+  public synchronized void abort() {
+    // Ignore the call once aborted or closed, so that the listener is notified only once.
+    if (aborted || closed) {
+      return;
+    }
+    LOGGER.debug("[StartAbortShuffleSinkHandle]");
+    for (ISink channel : downStreamChannelList) {
+      try {
+        channel.abort();
+      } catch (Exception e) {
+        LOGGER.warn("Error occurred when try to abort channel.", e);
+      }
+    }
+    aborted = true;
+    sinkListener.onAborted(this);
+    LOGGER.debug("[EndAbortShuffleSinkHandle]");
+  }
+
+  @Override
+  public synchronized void close() {
+    // Ignore the call once aborted or closed, so that the listener is notified only once.
+    if (aborted || closed) {
+      return;
+    }
+    LOGGER.debug("[StartCloseShuffleSinkHandle]");
+    for (ISink channel : downStreamChannelList) {
+      try {
+        channel.close();
+      } catch (Exception e) {
+        LOGGER.warn("Error occurred when try to close channel.", e);
+      }
+    }
+    closed = true;
+    sinkListener.onFinish(this);
+    LOGGER.debug("[EndCloseShuffleSinkHandle]");
+  }
+
+  @Override
+  public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+    this.maxBytesCanReserve = maxBytesCanReserve;
+    downStreamChannelList.forEach(channel -> channel.setMaxBytesCanReserve(maxBytesCanReserve));
+  }
+
+  private void checkState() {
+    if (aborted) {
+      throw new IllegalStateException("ShuffleSinkHandle is aborted.");
+    } else if (closed) {
+      throw new IllegalStateException("ShuffleSinkHandle is closed.");
+    }
+  }
+
+  /** Let the shuffle strategy choose the next channel and make sure that channel is opened. */
+  private void switchChannelIfNecessary() {
+    shuffleStrategy.shuffle();
+    tryOpenChannel(downStreamChannelIndex.getCurrentIndex());
+  }
+
+  @Override
+  public void tryOpenChannel(int channelIndex) {
+    if (!channelOpened[channelIndex]) {
+      downStreamChannelList.get(channelIndex).open();
+      channelOpened[channelIndex] = true;
+    }
+  }
+
+  // region ============ Shuffle Related ============
+  public enum ShuffleStrategyEnum {
+    PLAIN,
+    SIMPLE_ROUND_ROBIN,
+  }
+
+  @FunctionalInterface
+  interface ShuffleStrategy {
+    /**
+     * SinkHandle may have multiple channels; choose the channel for the next TsBlock by updating
+     * downStreamChannelIndex. Called after each send.
+     */
+    void shuffle();
+  }
+
+  /** Never switches: every TsBlock goes to the current channel. */
+  class PlainShuffleStrategy implements ShuffleStrategy {
+
+    @Override
+    public void shuffle() {
+      // do nothing
+      if (LOGGER.isDebugEnabled()) {
+        LOGGER.debug(
+            "PlainShuffleStrategy needs to do nothing, current channel index is {}",
+            downStreamChannelIndex.getCurrentIndex());
+      }
+    }
+  }
+
+  /**
+   * Moves to the first following channel, in cyclic order, that still accepts TsBlocks and whose
+   * buffer is below both the memory and the TsBlock-count threshold; keeps the current channel if
+   * no other channel qualifies.
+   */
+  class SimpleRoundRobinStrategy implements ShuffleStrategy {
+
+    @Override
+    public void shuffle() {
+      int currentIndex = downStreamChannelIndex.getCurrentIndex();
+      for (int i = 1; i < channelNum; i++) {
+        int nextIndex = (currentIndex + i) % channelNum;
+        if (satisfy(nextIndex)) {
+          downStreamChannelIndex.setCurrentIndex(nextIndex);
+          return;
+        }
+      }
+    }
+
+    private boolean satisfy(int channelIndex) {
+      // downStreamChannel is always an ISinkChannel
+      ISinkChannel channel = downStreamChannelList.get(channelIndex);
+      if (channel.isNoMoreTsBlocks()) {
+        return false;
+      }
+      // Derived on every call, not cached at construction, so setMaxBytesCanReserve() applies.
+      long channelMemoryThreshold = maxBytesCanReserve / channelNum * 3;
+      return channel.getBufferRetainedSizeInBytes() <= channelMemoryThreshold
+          && channel.getNumOfBufferedTsBlocks() < 3;
+    }
+  }
+
+  private ShuffleStrategy getShuffleStrategy(ShuffleStrategyEnum strategyEnum) {
+    switch (strategyEnum) {
+      case PLAIN:
+        return new PlainShuffleStrategy();
+      case SIMPLE_ROUND_ROBIN:
+        return new SimpleRoundRobinStrategy();
+      default:
+        throw new UnsupportedOperationException("Unsupported type of shuffle strategy");
+    }
+  }
+
+  // endregion
+
+  // region ============= Test Only =============
+  @TestOnly
+  @Override
+  public long getBufferRetainedSizeInBytes() {
+    return downStreamChannelList.stream().mapToLong(ISink::getBufferRetainedSizeInBytes).sum();
+  }
+  // endregion
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
similarity index 84%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
index fbc36ed34d..9ad82768c6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.sink;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.client.IClientManager;
@@ -25,7 +25,7 @@ import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceCl
 import org.apache.iotdb.commons.utils.TestOnly;
 import org.apache.iotdb.db.conf.IoTDBDescriptor;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkListener;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
 import org.apache.iotdb.db.utils.SetThreadName;
@@ -57,9 +57,9 @@ import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.SINK_HAND
 import static org.apache.iotdb.db.mpp.metric.DataExchangeCountMetricSet.SEND_NEW_DATA_BLOCK_NUM_CALLER;
 import static org.apache.iotdb.tsfile.read.common.block.TsBlockBuilderStatus.DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
 
-public class SinkHandle implements ISinkHandle {
+public class SinkChannel implements ISinkChannel {
 
-  private static final Logger logger = LoggerFactory.getLogger(SinkHandle.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(SinkChannel.class);
 
   public static final int MAX_ATTEMPT_TIMES = 3;
   private static final long DEFAULT_RETRY_INTERVAL_IN_MS = 1000L;
@@ -75,12 +75,12 @@ public class SinkHandle implements ISinkHandle {
   private final LocalMemoryManager localMemoryManager;
   private final ExecutorService executorService;
   private final TsBlockSerde serde;
-  private final SinkHandleListener sinkHandleListener;
+  private final SinkListener sinkListener;
   private final String threadName;
   private long retryIntervalInMs;
 
   // Use LinkedHashMap to meet 2 needs,
-  //   1. Predictable iteration order so that removing buffered tsblocks can be efficient.
+  //   1. Predictable iteration order so that removing buffered TsBlocks can be efficient.
   //   2. Fast lookup.
   private final LinkedHashMap<Integer, Pair<TsBlock, Long>> sequenceIdToTsBlock =
       new LinkedHashMap<>();
@@ -102,13 +102,13 @@ public class SinkHandle implements ISinkHandle {
 
   private boolean noMoreTsBlocks = false;
 
-  /** max bytes this SourceHandle can reserve. */
+  /** max bytes this SinkChannel can reserve. */
   private long maxBytesCanReserve =
       IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
 
   private static final QueryMetricsManager QUERY_METRICS = QueryMetricsManager.getInstance();
 
-  public SinkHandle(
+  public SinkChannel(
       TEndPoint remoteEndpoint,
       TFragmentInstanceId remoteFragmentInstanceId,
       String remotePlanNodeId,
@@ -117,7 +117,7 @@ public class SinkHandle implements ISinkHandle {
       LocalMemoryManager localMemoryManager,
       ExecutorService executorService,
       TsBlockSerde serde,
-      SinkHandleListener sinkHandleListener,
+      SinkListener sinkListener,
       IClientManager<TEndPoint, SyncDataNodeMPPDataExchangeServiceClient>
           mppDataExchangeServiceClientManager) {
     this.remoteEndpoint = Validate.notNull(remoteEndpoint);
@@ -130,7 +130,7 @@ public class SinkHandle implements ISinkHandle {
     this.localMemoryManager = Validate.notNull(localMemoryManager);
     this.executorService = Validate.notNull(executorService);
     this.serde = Validate.notNull(serde);
-    this.sinkHandleListener = Validate.notNull(sinkHandleListener);
+    this.sinkListener = Validate.notNull(sinkListener);
     this.mppDataExchangeServiceClientManager = mppDataExchangeServiceClientManager;
     this.retryIntervalInMs = DEFAULT_RETRY_INTERVAL_IN_MS;
     this.threadName =
@@ -138,18 +138,6 @@ public class SinkHandle implements ISinkHandle {
             localFragmentInstanceId.queryId,
             localFragmentInstanceId.fragmentId,
             localFragmentInstanceId.instanceId);
-    this.blocked =
-        localMemoryManager
-            .getQueryPool()
-            .reserve(
-                localFragmentInstanceId.getQueryId(),
-                fullFragmentInstanceId,
-                localPlanNodeId,
-                DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-                DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES) // actually we only know maxBytesCanReserve after
-            // the handle is created, so we use DEFAULT here. It is ok to use DEFAULT here because
-            // at first this SinkHandle has not reserved memory.
-            .left;
     this.bufferRetainedSizeInBytes = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
     this.currentTsBlockSize = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
   }
@@ -203,14 +191,9 @@ public class SinkHandle implements ISinkHandle {
     }
   }
 
-  @Override
-  public synchronized void send(int partition, List<TsBlock> tsBlocks) {
-    throw new UnsupportedOperationException();
-  }
-
   @Override
   public synchronized void setNoMoreTsBlocks() {
-    logger.debug("[StartSetNoMoreTsBlocks]");
+    LOGGER.debug("[StartSetNoMoreTsBlocks]");
     if (aborted || closed) {
       return;
     }
@@ -219,9 +202,11 @@ public class SinkHandle implements ISinkHandle {
 
   @Override
   public synchronized void abort() {
-    logger.debug("[StartAbortSinkHandle]");
+    LOGGER.debug("[StartAbortSinkChannel]");
+    if (aborted) {
+      return;
+    }
     sequenceIdToTsBlock.clear();
-    aborted = true;
     bufferRetainedSizeInBytes -= localMemoryManager.getQueryPool().tryCancel(blocked);
     if (bufferRetainedSizeInBytes > 0) {
       localMemoryManager
@@ -237,15 +222,18 @@ public class SinkHandle implements ISinkHandle {
         .getQueryPool()
         .clearMemoryReservationMap(
             localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
-    sinkHandleListener.onAborted(this);
-    logger.debug("[EndAbortSinkHandle]");
+    sinkListener.onAborted(this);
+    aborted = true;
+    LOGGER.debug("[EndAbortSinkChannel]");
   }
 
   @Override
   public synchronized void close() {
-    logger.debug("[StartCloseSinkHandle]");
+    LOGGER.debug("[StartCloseSinkChannel]");
+    if (closed) {
+      return;
+    }
     sequenceIdToTsBlock.clear();
-    closed = true;
     bufferRetainedSizeInBytes -= localMemoryManager.getQueryPool().tryComplete(blocked);
     if (bufferRetainedSizeInBytes > 0) {
       localMemoryManager
@@ -261,8 +249,9 @@ public class SinkHandle implements ISinkHandle {
         .getQueryPool()
         .clearMemoryReservationMap(
             localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
-    sinkHandleListener.onFinish(this);
-    logger.debug("[EndCloseSinkHandle]");
+    sinkListener.onFinish(this);
+    closed = true;
+    LOGGER.debug("[EndCloseSinkChannel]");
   }
 
   @Override
@@ -280,25 +269,21 @@ public class SinkHandle implements ISinkHandle {
     return bufferRetainedSizeInBytes;
   }
 
-  public int getNumOfBufferedTsBlocks() {
-    return sequenceIdToTsBlock.size();
-  }
-
-  ByteBuffer getSerializedTsBlock(int partition, int sequenceId) {
+  public ByteBuffer getSerializedTsBlock(int partition, int sequenceId) {
     throw new UnsupportedOperationException();
   }
 
-  synchronized ByteBuffer getSerializedTsBlock(int sequenceId) throws IOException {
+  public synchronized ByteBuffer getSerializedTsBlock(int sequenceId) throws IOException {
     if (aborted || closed) {
-      logger.warn(
-          "SinkHandle still receive getting TsBlock request after being aborted={} or closed={}",
+      LOGGER.warn(
+          "SinkChannel still receive getting TsBlock request after being aborted={} or closed={}",
           aborted,
           closed);
-      throw new IllegalStateException("Sink handle is aborted or closed. ");
+      throw new IllegalStateException("SinkChannel is aborted or closed. ");
     }
     Pair<TsBlock, Long> pair = sequenceIdToTsBlock.get(sequenceId);
     if (pair == null || pair.left == null) {
-      logger.error(
+      LOGGER.warn(
           "The TsBlock doesn't exist. Sequence ID is {}, remaining map is {}",
           sequenceId,
           sequenceIdToTsBlock.entrySet());
@@ -307,7 +292,7 @@ public class SinkHandle implements ISinkHandle {
     return serde.serialize(pair.left);
   }
 
-  void acknowledgeTsBlock(int startSequenceId, int endSequenceId) {
+  public void acknowledgeTsBlock(int startSequenceId, int endSequenceId) {
     long freedBytes = 0L;
     synchronized (this) {
       if (aborted || closed) {
@@ -327,11 +312,11 @@ public class SinkHandle implements ISinkHandle {
         freedBytes += entry.getValue().right;
         bufferRetainedSizeInBytes -= entry.getValue().right;
         iterator.remove();
-        logger.debug("[ACKTsBlock] {}.", entry.getKey());
+        LOGGER.debug("[ACKTsBlock] {}.", entry.getKey());
       }
     }
     if (isFinished()) {
-      sinkHandleListener.onFinish(this);
+      sinkListener.onFinish(this);
     }
     // there may exist duplicate ack message in network caused by caller retrying, if so duplicate
     // ack message's freedBytes may be zero
@@ -346,18 +331,7 @@ public class SinkHandle implements ISinkHandle {
     }
   }
 
-  public TEndPoint getRemoteEndpoint() {
-    return remoteEndpoint;
-  }
-
-  public TFragmentInstanceId getRemoteFragmentInstanceId() {
-    return remoteFragmentInstanceId;
-  }
-
-  public String getRemotePlanNodeId() {
-    return remotePlanNodeId;
-  }
-
+  @Override
   public TFragmentInstanceId getLocalFragmentInstanceId() {
     return localFragmentInstanceId;
   }
@@ -370,7 +344,7 @@ public class SinkHandle implements ISinkHandle {
   @Override
   public String toString() {
     return String.format(
-        "Query[%s]-[%s-%s-SinkHandle]:",
+        "Query[%s]-[%s-%s-SinkChannel]:",
         localFragmentInstanceId.queryId,
         localFragmentInstanceId.fragmentId,
         localFragmentInstanceId.instanceId);
@@ -378,17 +352,50 @@ public class SinkHandle implements ISinkHandle {
 
   private void checkState() {
     if (aborted) {
-      throw new IllegalStateException("Sink handle is aborted.");
+      throw new IllegalStateException("SinkChannel is aborted.");
     } else if (closed) {
-      throw new IllegalStateException("SinkHandle is closed.");
+      throw new IllegalStateException("SinkChannel is closed.");
     }
   }
 
+  // region ============ ISinkChannel related ============
+
+  public void open() {
+    // SinkChannel is opened when ShuffleSinkHandle chooses it as the next channel
+    this.blocked =
+        localMemoryManager
+            .getQueryPool()
+            .reserve(
+                localFragmentInstanceId.getQueryId(),
+                fullFragmentInstanceId,
+                localPlanNodeId,
+                DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
+                maxBytesCanReserve) // reservation is deferred from the constructor to open(), so
+            // the current maxBytesCanReserve is used instead of DEFAULT. Reserving DEFAULT bytes
+            // first is ok because this SinkChannel has not reserved any memory yet.
+            .left;
+  }
+
+  @Override
+  public boolean isNoMoreTsBlocks() {
+    return noMoreTsBlocks;
+  }
+
+  @Override
+  public int getNumOfBufferedTsBlocks() {
+    return sequenceIdToTsBlock.size();
+  }
+
+  // endregion
+
+  // region ============ TestOnly ============
   @TestOnly
   public void setRetryIntervalInMs(long retryIntervalInMs) {
     this.retryIntervalInMs = retryIntervalInMs;
   }
+  // endregion
 
+  // region ============ inner class ============
   /**
    * Send a {@link org.apache.iotdb.mpp.rpc.thrift.TNewDataBlockEvent} to downstream fragment
    * instance.
@@ -410,8 +417,8 @@ public class SinkHandle implements ISinkHandle {
 
     @Override
     public void run() {
-      try (SetThreadName sinkHandleName = new SetThreadName(threadName)) {
-        logger.debug(
+      try (SetThreadName sinkChannelName = new SetThreadName(threadName)) {
+        LOGGER.debug(
             "[NotifyNewTsBlock] [{}, {})", startSequenceId, startSequenceId + blockSizes.size());
         int attempt = 0;
         TNewDataBlockEvent newDataBlockEvent =
@@ -429,15 +436,15 @@ public class SinkHandle implements ISinkHandle {
             client.onNewDataBlockEvent(newDataBlockEvent);
             break;
           } catch (Exception e) {
-            logger.warn("Failed to send new data block event, attempt times: {}", attempt, e);
+            LOGGER.warn("Failed to send new data block event, attempt times: {}", attempt, e);
             if (attempt == MAX_ATTEMPT_TIMES) {
-              sinkHandleListener.onFailure(SinkHandle.this, e);
+              sinkListener.onFailure(SinkChannel.this, e);
             }
             try {
               Thread.sleep(retryIntervalInMs);
             } catch (InterruptedException ex) {
               Thread.currentThread().interrupt();
-              sinkHandleListener.onFailure(SinkHandle.this, e);
+              sinkListener.onFailure(SinkChannel.this, e);
             }
           } finally {
             QUERY_METRICS.recordDataExchangeCost(
@@ -457,8 +464,8 @@ public class SinkHandle implements ISinkHandle {
 
     @Override
     public void run() {
-      try (SetThreadName sinkHandleName = new SetThreadName(threadName)) {
-        logger.debug("[NotifyNoMoreTsBlock]");
+      try (SetThreadName sinkChannelName = new SetThreadName(threadName)) {
+        LOGGER.debug("[NotifyNoMoreTsBlock]");
         int attempt = 0;
         TEndOfDataBlockEvent endOfDataBlockEvent =
             new TEndOfDataBlockEvent(
@@ -473,26 +480,28 @@ public class SinkHandle implements ISinkHandle {
             client.onEndOfDataBlockEvent(endOfDataBlockEvent);
             break;
           } catch (Exception e) {
-            logger.warn("Failed to send end of data block event, attempt times: {}", attempt, e);
+            LOGGER.warn("Failed to send end of data block event, attempt times: {}", attempt, e);
             if (attempt == MAX_ATTEMPT_TIMES) {
-              logger.warn("Failed to send end of data block event after all retry", e);
-              sinkHandleListener.onFailure(SinkHandle.this, e);
+              LOGGER.warn("Failed to send end of data block event after all retry", e);
+              sinkListener.onFailure(SinkChannel.this, e);
               return;
             }
             try {
               Thread.sleep(retryIntervalInMs);
             } catch (InterruptedException ex) {
               Thread.currentThread().interrupt();
-              sinkHandleListener.onFailure(SinkHandle.this, e);
+              sinkListener.onFailure(SinkChannel.this, e);
             }
           }
         }
         noMoreTsBlocks = true;
         if (isFinished()) {
-          sinkHandleListener.onFinish(SinkHandle.this);
+          sinkListener.onFinish(SinkChannel.this);
         }
-        sinkHandleListener.onEndOfBlocks(SinkHandle.this);
+        sinkListener.onEndOfBlocks(SinkChannel.this);
       }
     }
   }
+  // endregion
+
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/ISourceHandle.java
similarity index 98%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/ISourceHandle.java
index d056717060..14ac3429e6 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/ISourceHandle.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.source;
 
 import org.apache.iotdb.commons.exception.IoTDBException;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/LocalSourceHandle.java
similarity index 92%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/LocalSourceHandle.java
index f47dea04d6..7dc6ad2983 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/LocalSourceHandle.java
@@ -17,10 +17,11 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.source;
 
 import org.apache.iotdb.commons.exception.IoTDBException;
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SourceHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.SharedTsBlockQueue;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
 import org.apache.iotdb.db.utils.SetThreadName;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
@@ -42,7 +43,7 @@ import static org.apache.iotdb.db.mpp.metric.DataExchangeCostMetricSet.SOURCE_HA
 
 public class LocalSourceHandle implements ISourceHandle {
 
-  private static final Logger logger = LoggerFactory.getLogger(LocalSourceHandle.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(LocalSourceHandle.class);
 
   private TFragmentInstanceId localFragmentInstanceId;
   private String localPlanNodeId;
@@ -111,7 +112,7 @@ public class LocalSourceHandle implements ISourceHandle {
         tsBlock = queue.remove();
       }
       if (tsBlock != null) {
-        logger.debug(
+        LOGGER.debug(
             "[GetTsBlockFromQueue] TsBlock:{} size:{}",
             currSequenceId,
             tsBlock.getRetainedSizeInBytes());
@@ -180,7 +181,7 @@ public class LocalSourceHandle implements ISourceHandle {
       return;
     }
     try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-      logger.debug("[StartAbortLocalSourceHandle]");
+      LOGGER.debug("[StartAbortLocalSourceHandle]");
       synchronized (queue) {
         synchronized (this) {
           if (aborted || closed) {
@@ -191,7 +192,7 @@ public class LocalSourceHandle implements ISourceHandle {
           sourceHandleListener.onAborted(this);
         }
       }
-      logger.debug("[EndAbortLocalSourceHandle]");
+      LOGGER.debug("[EndAbortLocalSourceHandle]");
     }
   }
 
@@ -201,7 +202,7 @@ public class LocalSourceHandle implements ISourceHandle {
       return;
     }
     try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-      logger.debug("[StartAbortLocalSourceHandle]");
+      LOGGER.debug("[StartAbortLocalSourceHandle]");
       synchronized (queue) {
         synchronized (this) {
           if (aborted || closed) {
@@ -212,7 +213,7 @@ public class LocalSourceHandle implements ISourceHandle {
           sourceHandleListener.onAborted(this);
         }
       }
-      logger.debug("[EndAbortLocalSourceHandle]");
+      LOGGER.debug("[EndAbortLocalSourceHandle]");
     }
   }
 
@@ -222,7 +223,7 @@ public class LocalSourceHandle implements ISourceHandle {
       return;
     }
     try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-      logger.debug("[StartCloseLocalSourceHandle]");
+      LOGGER.debug("[StartCloseLocalSourceHandle]");
       synchronized (queue) {
         synchronized (this) {
           if (aborted || closed) {
@@ -233,7 +234,7 @@ public class LocalSourceHandle implements ISourceHandle {
           sourceHandleListener.onFinished(this);
         }
       }
-      logger.debug("[EndCloseLocalSourceHandle]");
+      LOGGER.debug("[EndCloseLocalSourceHandle]");
     }
   }
 
@@ -251,8 +252,7 @@ public class LocalSourceHandle implements ISourceHandle {
 
   @Override
   public void setMaxBytesCanReserve(long maxBytesCanReserve) {
-    if (maxBytesCanReserve < queue.getMaxBytesCanReserve()) {
-      queue.setMaxBytesCanReserve(maxBytesCanReserve);
-    }
+    // Intentionally a no-op: the maxBytesCanReserve of SharedTsBlockQueue should be set by the
+    // corresponding LocalSinkChannel.
   }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
similarity index 94%
rename from server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
index 386fdd40b8..10821ca1ac 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.mpp.execution.exchange;
+package org.apache.iotdb.db.mpp.execution.exchange.source;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.client.IClientManager;
@@ -61,7 +61,7 @@ import static org.apache.iotdb.db.mpp.metric.DataExchangeCountMetricSet.ON_ACKNO
 
 public class SourceHandle implements ISourceHandle {
 
-  private static final Logger logger = LoggerFactory.getLogger(SourceHandle.class);
+  private static final Logger LOGGER = LoggerFactory.getLogger(SourceHandle.class);
 
   public static final int MAX_ATTEMPT_TIMES = 3;
   private static final long DEFAULT_RETRY_INTERVAL_IN_MS = 1000;
@@ -72,6 +72,8 @@ public class SourceHandle implements ISourceHandle {
 
   private final String fullFragmentInstanceId;
   private final String localPlanNodeId;
+
+  private final int indexOfUpstreamSinkHandle;
   private final LocalMemoryManager localMemoryManager;
   private final ExecutorService executorService;
   private final TsBlockSerde serde;
@@ -117,6 +119,7 @@ public class SourceHandle implements ISourceHandle {
       TFragmentInstanceId remoteFragmentInstanceId,
       TFragmentInstanceId localFragmentInstanceId,
       String localPlanNodeId,
+      int indexOfUpstreamSinkHandle,
       LocalMemoryManager localMemoryManager,
       ExecutorService executorService,
       TsBlockSerde serde,
@@ -129,6 +132,7 @@ public class SourceHandle implements ISourceHandle {
     this.fullFragmentInstanceId =
         FragmentInstanceId.createFragmentInstanceIdFromTFragmentInstanceId(localFragmentInstanceId);
     this.localPlanNodeId = Validate.notNull(localPlanNodeId);
+    this.indexOfUpstreamSinkHandle = indexOfUpstreamSinkHandle;
     this.localMemoryManager = Validate.notNull(localMemoryManager);
     this.executorService = Validate.notNull(executorService);
     this.serde = Validate.notNull(serde);
@@ -170,7 +174,7 @@ public class SourceHandle implements ISourceHandle {
         return null;
       }
       long retainedSize = sequenceIdToDataBlockSize.remove(currSequenceId);
-      logger.debug("[GetTsBlockFromBuffer] sequenceId:{}, size:{}", currSequenceId, retainedSize);
+      LOGGER.debug("[GetTsBlockFromBuffer] sequenceId:{}, size:{}", currSequenceId, retainedSize);
       currSequenceId += 1;
       bufferRetainedSizeInBytes -= retainedSize;
       localMemoryManager
@@ -182,7 +186,7 @@ public class SourceHandle implements ISourceHandle {
               retainedSize);
 
       if (sequenceIdToTsBlock.isEmpty() && !isFinished()) {
-        logger.debug("[WaitForMoreTsBlock]");
+        LOGGER.debug("[WaitForMoreTsBlock]");
         blocked = SettableFuture.create();
       }
       if (isFinished()) {
@@ -278,8 +282,8 @@ public class SourceHandle implements ISourceHandle {
     return nonCancellationPropagating(blocked);
   }
 
-  synchronized void setNoMoreTsBlocks(int lastSequenceId) {
-    logger.debug("[ReceiveNoMoreTsBlockEvent]");
+  public synchronized void setNoMoreTsBlocks(int lastSequenceId) {
+    LOGGER.debug("[ReceiveNoMoreTsBlockEvent]");
     this.lastSequenceId = lastSequenceId;
     if (!blocked.isDone() && remoteTsBlockedConsumedUp()) {
       blocked.set(null);
@@ -289,8 +293,9 @@ public class SourceHandle implements ISourceHandle {
     }
   }
 
-  synchronized void updatePendingDataBlockInfo(int startSequenceId, List<Long> dataBlockSizes) {
-    logger.debug(
+  public synchronized void updatePendingDataBlockInfo(
+      int startSequenceId, List<Long> dataBlockSizes) {
+    LOGGER.debug(
         "[ReceiveNewTsBlockNotification] [{}, {}), each size is: {}",
         startSequenceId,
         startSequenceId + dataBlockSizes.size(),
@@ -465,9 +470,13 @@ public class SourceHandle implements ISourceHandle {
     @Override
     public void run() {
       try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-        logger.debug("[StartPullTsBlocksFromRemote] [{}, {}) ", startSequenceId, endSequenceId);
+        LOGGER.debug("[StartPullTsBlocksFromRemote] [{}, {}) ", startSequenceId, endSequenceId);
         TGetDataBlockRequest req =
-            new TGetDataBlockRequest(remoteFragmentInstanceId, startSequenceId, endSequenceId);
+            new TGetDataBlockRequest(
+                remoteFragmentInstanceId,
+                startSequenceId,
+                endSequenceId,
+                indexOfUpstreamSinkHandle);
         int attempt = 0;
         while (attempt < MAX_ATTEMPT_TIMES) {
           attempt += 1;
@@ -481,7 +490,7 @@ public class SourceHandle implements ISourceHandle {
             List<ByteBuffer> tsBlocks = new ArrayList<>(tsBlockNum);
             tsBlocks.addAll(resp.getTsBlocks());
 
-            logger.debug("[EndPullTsBlocksFromRemote] Count:{}", tsBlockNum);
+            LOGGER.debug("[EndPullTsBlocksFromRemote] Count:{}", tsBlockNum);
             QUERY_METRICS.recordDataBlockNum(GET_DATA_BLOCK_NUM_CALLER, tsBlockNum);
             executorService.submit(
                 new SendAcknowledgeDataBlockEventTask(startSequenceId, endSequenceId));
@@ -492,7 +501,7 @@ public class SourceHandle implements ISourceHandle {
               for (int i = startSequenceId; i < endSequenceId; i++) {
                 sequenceIdToTsBlock.put(i, tsBlocks.get(i - startSequenceId));
               }
-              logger.debug("[PutTsBlocksIntoBuffer]");
+              LOGGER.debug("[PutTsBlocksIntoBuffer]");
               if (!blocked.isDone()) {
                 blocked.set(null);
               }
@@ -500,7 +509,7 @@ public class SourceHandle implements ISourceHandle {
             break;
           } catch (Throwable e) {
 
-            logger.warn(
+            LOGGER.warn(
                 "failed to get data block [{}, {}), attempt times: {}",
                 startSequenceId,
                 endSequenceId,
@@ -561,11 +570,14 @@ public class SourceHandle implements ISourceHandle {
     @Override
     public void run() {
       try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-        logger.debug("[SendACKTsBlock] [{}, {}).", startSequenceId, endSequenceId);
+        LOGGER.debug("[SendACKTsBlock] [{}, {}).", startSequenceId, endSequenceId);
         int attempt = 0;
         TAcknowledgeDataBlockEvent acknowledgeDataBlockEvent =
             new TAcknowledgeDataBlockEvent(
-                remoteFragmentInstanceId, startSequenceId, endSequenceId);
+                remoteFragmentInstanceId,
+                startSequenceId,
+                endSequenceId,
+                indexOfUpstreamSinkHandle);
         while (attempt < MAX_ATTEMPT_TIMES) {
           attempt += 1;
           long startTime = System.nanoTime();
@@ -574,7 +586,7 @@ public class SourceHandle implements ISourceHandle {
             client.onAcknowledgeDataBlockEvent(acknowledgeDataBlockEvent);
             break;
           } catch (Throwable e) {
-            logger.warn(
+            LOGGER.warn(
                 "failed to send ack data block event [{}, {}), attempt times: {}",
                 startSequenceId,
                 endSequenceId,
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
index 51c783da86..2415bed270 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
@@ -20,7 +20,7 @@ package org.apache.iotdb.db.mpp.execution.fragment;
 
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.execution.driver.IDriver;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.schedule.IDriverScheduler;
 import org.apache.iotdb.db.utils.SetThreadName;
 
@@ -43,7 +43,7 @@ public class FragmentInstanceExecution {
   private List<IDriver> drivers;
 
   // it will be set to null while this FI is FINISHED
-  private ISinkHandle sinkHandle;
+  private ISink sink;
 
   private final FragmentInstanceStateMachine stateMachine;
 
@@ -54,7 +54,7 @@ public class FragmentInstanceExecution {
       FragmentInstanceId instanceId,
       FragmentInstanceContext context,
       List<IDriver> drivers,
-      ISinkHandle sinkHandle,
+      ISink sinkHandle,
       FragmentInstanceStateMachine stateMachine,
       CounterStat failedInstances,
       long timeOut) {
@@ -70,12 +70,12 @@ public class FragmentInstanceExecution {
       FragmentInstanceId instanceId,
       FragmentInstanceContext context,
       List<IDriver> drivers,
-      ISinkHandle sinkHandle,
+      ISink sink,
       FragmentInstanceStateMachine stateMachine) {
     this.instanceId = instanceId;
     this.context = context;
     this.drivers = drivers;
-    this.sinkHandle = sinkHandle;
+    this.sink = sink;
     this.stateMachine = stateMachine;
   }
 
@@ -119,14 +119,14 @@ public class FragmentInstanceExecution {
             }
 
             if (newState.isFailed()) {
-              sinkHandle.abort();
+              sink.abort();
             } else {
-              sinkHandle.close();
+              sink.close();
             }
             // help for gc
-            sinkHandle = null;
-            // close the driver after sinkHandle is aborted or closed because in driver.close() it
-            // will try to call ISinkHandle.setNoMoreTsBlocks()
+            sink = null;
+            // close the driver after sink is aborted or closed because in driver.close() it
+            // will try to call ISink.setNoMoreTsBlocks()
             for (IDriver driver : drivers) {
               driver.close();
             }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
index 794174b939..4c4dd4187b 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceManager.java
@@ -25,7 +25,7 @@ import org.apache.iotdb.db.engine.storagegroup.IDataRegionForQuery;
 import org.apache.iotdb.db.metadata.schemaregion.ISchemaRegion;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.execution.driver.IDriver;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.schedule.DriverScheduler;
 import org.apache.iotdb.db.mpp.execution.schedule.IDriverScheduler;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
@@ -138,15 +138,15 @@ public class FragmentInstanceManager {
 
                   List<IDriver> drivers = new ArrayList<>();
                   driverFactories.forEach(factory -> drivers.add(factory.createDriver()));
-                  // get the sinkHandle of last driver
-                  ISinkHandle sinkHandle = drivers.get(drivers.size() - 1).getSinkHandle();
+                  // get the sink of last driver
+                  ISink sink = drivers.get(drivers.size() - 1).getSink();
 
                   return createFragmentInstanceExecution(
                       scheduler,
                       instanceId,
                       context,
                       drivers,
-                      sinkHandle,
+                      sink,
                       stateMachine,
                       failedInstances,
                       instance.getTimeOut());
@@ -198,15 +198,15 @@ public class FragmentInstanceManager {
 
                 List<IDriver> drivers = new ArrayList<>();
                 driverFactories.forEach(factory -> drivers.add(factory.createDriver()));
-                // get the sinkHandle of last driver
-                ISinkHandle sinkHandle = drivers.get(drivers.size() - 1).getSinkHandle();
+                // get the sink of last driver
+                ISink sink = drivers.get(drivers.size() - 1).getSink();
 
                 return createFragmentInstanceExecution(
                     scheduler,
                     instanceId,
                     context,
                     drivers,
-                    sinkHandle,
+                    sink,
                     stateMachine,
                     failedInstances,
                     instance.getTimeOut());
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/IdentitySinkOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/IdentitySinkOperator.java
new file mode 100644
index 0000000000..01ecb7f71b
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/IdentitySinkOperator.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.operator.sink;
+
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.operator.Operator;
+import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import com.google.common.util.concurrent.ListenableFuture;
+
+import java.util.List;
+
+public class IdentitySinkOperator implements Operator {
+  // Drains children sequentially by advancing the shared index; child i feeds sink channel i.
+  private final OperatorContext operatorContext;
+  private final List<Operator> children;
+  // Index of the child being read, which is also the index of the sink channel it feeds.
+  private final DownStreamChannelIndex downStreamChannelIndex;
+  // Told when a channel has no more data and when the next channel should be opened.
+  private final ISinkHandle sinkHandle;
+  // Set after a child switch so that the following next() call returns null exactly once.
+  private boolean needToReturnNull = false;
+  // Becomes true when hasNext() moves past the last child.
+  private boolean isFinished = false;
+  // Only stores its collaborators; no channel is touched here.
+  public IdentitySinkOperator(
+      OperatorContext operatorContext,
+      List<Operator> children,
+      DownStreamChannelIndex downStreamChannelIndex,
+      ISinkHandle sinkHandle) {
+    this.operatorContext = operatorContext;
+    this.children = children;
+    this.downStreamChannelIndex = downStreamChannelIndex;
+    this.sinkHandle = sinkHandle;
+  }
+  // True while the current child has data or a further child remains; false after the last.
+  @Override
+  public boolean hasNext() {
+    if (children.get(downStreamChannelIndex.getCurrentIndex()).hasNext()) {
+      return true;
+    }
+    int currentIndex = downStreamChannelIndex.getCurrentIndex();
+    // the current child has no more data, so tell the sink that its channel has no more TsBlocks
+    sinkHandle.setNoMoreTsBlocksOfOneChannel(downStreamChannelIndex.getCurrentIndex());
+    currentIndex++;
+    if (currentIndex >= children.size()) { // every child has been drained
+      isFinished = true;
+      return false;
+    }
+    downStreamChannelIndex.setCurrentIndex(currentIndex);
+    // The Driver's last isBlocked() call was answered by the previous child, not this one.
+    // Returning null from the next call to next() lets the Driver start another loop and check
+    // isBlocked() on the newly selected child before reading from it.
+    needToReturnNull = true;
+    // ask the sink to open the newly selected channel ("try": presumably a no-op if already open)
+    sinkHandle.tryOpenChannel(currentIndex);
+    return true;
+  }
+  // Yields null once right after a child switch, otherwise the current child's next block.
+  @Override
+  public TsBlock next() {
+    if (needToReturnNull) {
+      needToReturnNull = false;
+      return null;
+    }
+    return children.get(downStreamChannelIndex.getCurrentIndex()).next();
+  }
+  // Blocked state is that of the current child only.
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    return children.get(downStreamChannelIndex.getCurrentIndex()).isBlocked();
+  }
+  // Finished only after hasNext() has moved past the last child.
+  @Override
+  public boolean isFinished() {
+    return isFinished;
+  }
+  // Exposes the context supplied at construction.
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+  // Closes all children in order; an exception from one child stops closing the rest.
+  @Override
+  public void close() throws Exception {
+    for (Operator child : children) {
+      child.close();
+    }
+  }
+  // Max (not sum) over children: they are drained one after another, not concurrently.
+  @Override
+  public long calculateMaxPeekMemory() {
+    long maxPeekMemory = 0;
+    for (Operator child : children) {
+      maxPeekMemory = Math.max(maxPeekMemory, child.calculateMaxPeekMemory());
+    }
+    return maxPeekMemory;
+  }
+  // next() returns a single child's block, so the bound is the largest child's bound.
+  @Override
+  public long calculateMaxReturnSize() {
+    long maxReturnSize = 0;
+    for (Operator child : children) {
+      maxReturnSize = Math.max(maxReturnSize, child.calculateMaxReturnSize());
+    }
+    return maxReturnSize;
+  }
+  // No TsBlock is buffered by this operator between next() calls.
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return 0L;
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/ShuffleHelperOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/ShuffleHelperOperator.java
new file mode 100644
index 0000000000..2e7863eb0f
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/sink/ShuffleHelperOperator.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.execution.operator.sink;
+
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.operator.Operator;
+import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import com.google.common.util.concurrent.ListenableFuture;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class ShuffleHelperOperator implements Operator { // child i feeds sink channel i
+  private final OperatorContext operatorContext;
+  private final List<Operator> children;
+  // Current child == current sink channel. NOTE(review): shared with the sink; it may move it too.
+  private final DownStreamChannelIndex downStreamChannelIndex;
+  // Told when a channel has no more data and when the next channel should be opened.
+  private final ISinkHandle sinkHandle;
+  // Indices of children not yet observed exhausted; the operator finishes when this is empty.
+  private final Set<Integer> unfinishedChildren;
+  // Set after a child switch so that the following next() call returns null exactly once.
+  private boolean needToReturnNull = false;
+  // Marks every child index as unfinished; the starting index is whatever the caller set.
+  public ShuffleHelperOperator(
+      OperatorContext operatorContext,
+      List<Operator> children,
+      DownStreamChannelIndex downStreamChannelIndex,
+      ISinkHandle sinkHandle) {
+    this.operatorContext = operatorContext;
+    this.children = children;
+    this.downStreamChannelIndex = downStreamChannelIndex;
+    this.sinkHandle = sinkHandle;
+    this.unfinishedChildren = new HashSet<>(children.size());
+    for (int i = 0; i < children.size(); i++) {
+      unfinishedChildren.add(i);
+    }
+  }
+  // True if the current child has data; otherwise finishes its channel and selects the next one.
+  @Override
+  public boolean hasNext() {
+    int currentIndex = downStreamChannelIndex.getCurrentIndex();
+    if (children.get(currentIndex).hasNext()) {
+      return true;
+    }
+    // the current child has no more data, so tell the sink that its channel has no more TsBlocks
+    sinkHandle.setNoMoreTsBlocksOfOneChannel(currentIndex);
+    unfinishedChildren.remove(currentIndex); // Set.remove(Object): boxed, removes by value
+    currentIndex = (currentIndex + 1) % children.size(); // wraps to 0 after the last child
+    downStreamChannelIndex.setCurrentIndex(currentIndex);
+    // The Driver's last isBlocked() call was answered by the previous child, not this one.
+    // Returning null from the next call to next() lets the Driver start another loop and check
+    // isBlocked() on the newly selected child before reading from it.
+    needToReturnNull = true;
+    // ask the sink to open the newly selected channel ("try": presumably a no-op if already open)
+    sinkHandle.tryOpenChannel(currentIndex);
+    return true; // even after the last child is drained; isFinished() reports completion
+  }
+  // Yields null once right after a child switch, otherwise the current child's next block.
+  @Override
+  public TsBlock next() {
+    if (needToReturnNull) {
+      needToReturnNull = false;
+      return null;
+    }
+    return children.get(downStreamChannelIndex.getCurrentIndex()).next();
+  }
+  // Blocked state is that of the current child only.
+  @Override
+  public ListenableFuture<?> isBlocked() {
+    return children.get(downStreamChannelIndex.getCurrentIndex()).isBlocked();
+  }
+  // Finished once every child has been observed exhausted by hasNext().
+  @Override
+  public boolean isFinished() {
+    return unfinishedChildren.isEmpty();
+  }
+  // Exposes the context supplied at construction.
+  @Override
+  public OperatorContext getOperatorContext() {
+    return operatorContext;
+  }
+  // Closes all children in order; an exception from one child stops closing the rest.
+  @Override
+  public void close() throws Exception {
+    for (Operator child : children) {
+      child.close();
+    }
+  }
+  // Max (not sum) over children. NOTE(review): assumes one active child at a time — confirm.
+  @Override
+  public long calculateMaxPeekMemory() {
+    long maxPeekMemory = 0;
+    for (Operator child : children) {
+      maxPeekMemory = Math.max(maxPeekMemory, child.calculateMaxPeekMemory());
+    }
+    return maxPeekMemory;
+  }
+  // next() returns a single child's block, so the bound is the largest child's bound.
+  @Override
+  public long calculateMaxReturnSize() {
+    long maxReturnSize = 0;
+    for (Operator child : children) {
+      maxReturnSize = Math.max(maxReturnSize, child.calculateMaxReturnSize());
+    }
+    return maxReturnSize;
+  }
+  // No TsBlock is buffered by this operator between next() calls.
+  @Override
+  public long calculateRetainedSizeAfterCallingNext() {
+    return 0L;
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/source/ExchangeOperator.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/source/ExchangeOperator.java
index 7e7498fbc1..ba57e1ff07 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/source/ExchangeOperator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/operator/source/ExchangeOperator.java
@@ -18,7 +18,7 @@
  */
 package org.apache.iotdb.db.mpp.execution.operator.source;
 
-import org.apache.iotdb.db.mpp.execution.exchange.ISourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
 import org.apache.iotdb.db.mpp.execution.operator.OperatorContext;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/schedule/task/DriverTask.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/schedule/task/DriverTask.java
index c436b180f8..20334d4c44 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/schedule/task/DriverTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/schedule/task/DriverTask.java
@@ -22,7 +22,7 @@ import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
 import org.apache.iotdb.db.mpp.common.QueryId;
 import org.apache.iotdb.db.mpp.execution.driver.IDriver;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.schedule.DriverTaskThread;
 import org.apache.iotdb.db.mpp.execution.schedule.ExecutionContext;
 import org.apache.iotdb.db.mpp.execution.schedule.queue.ID;
@@ -264,7 +264,7 @@ public class DriverTask implements IDIndexedAccessible {
     public void failed(Throwable t) {}
 
     @Override
-    public ISinkHandle getSinkHandle() {
+    public ISink getSink() {
       return null;
     }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/QueryExecution.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/QueryExecution.java
index ae8722f1d7..3084cdbc6d 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/QueryExecution.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/QueryExecution.java
@@ -32,8 +32,8 @@ import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 import org.apache.iotdb.db.mpp.common.header.DatasetHeader;
 import org.apache.iotdb.db.mpp.execution.QueryState;
 import org.apache.iotdb.db.mpp.execution.QueryStateMachine;
-import org.apache.iotdb.db.mpp.execution.exchange.ISourceHandle;
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeService;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
 import org.apache.iotdb.db.mpp.metric.PerformanceOverviewMetricsManager;
 import org.apache.iotdb.db.mpp.metric.QueryMetricsManager;
 import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
@@ -562,13 +562,16 @@ public class QueryExecution implements IQueryExecution {
                 .createLocalSourceHandleForFragment(
                     context.getResultNodeContext().getVirtualFragmentInstanceId().toThrift(),
                     context.getResultNodeContext().getVirtualResultNodeId().getId(),
+                    context.getResultNodeContext().getUpStreamPlanNodeId().getId(),
                     context.getResultNodeContext().getUpStreamFragmentInstanceId().toThrift(),
+                    0, // Upstream of result ExchangeNode will only have one child.
                     stateMachine::transitionToFailed)
             : MPPDataExchangeService.getInstance()
                 .getMPPDataExchangeManager()
                 .createSourceHandle(
                     context.getResultNodeContext().getVirtualFragmentInstanceId().toThrift(),
                     context.getResultNodeContext().getVirtualResultNodeId().getId(),
+                    0,
                     upstreamEndPoint,
                     context.getResultNodeContext().getUpStreamFragmentInstanceId().toThrift(),
                     stateMachine::transitionToFailed);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
index e2b60ef441..a071b4b83c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
@@ -20,7 +20,7 @@
 package org.apache.iotdb.db.mpp.plan.execution.memory;
 
 import org.apache.iotdb.commons.exception.IoTDBException;
-import org.apache.iotdb.db.mpp.execution.exchange.ISourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 import org.apache.iotdb.rpc.TSStatusCode;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanContext.java
index 0404be866a..3dfde552ca 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanContext.java
@@ -24,7 +24,7 @@ import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.execution.driver.DataDriverContext;
 import org.apache.iotdb.db.mpp.execution.driver.DriverContext;
 import org.apache.iotdb.db.mpp.execution.driver.SchemaDriverContext;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.db.mpp.execution.operator.Operator;
 import org.apache.iotdb.db.mpp.execution.operator.source.ExchangeOperator;
@@ -229,10 +229,10 @@ public class LocalExecutionPlanContext {
     return cachedLastValueAndPathList;
   }
 
-  public void setSinkHandle(ISinkHandle sinkHandle) {
-    requireNonNull(sinkHandle, "sinkHandle is null");
-    checkArgument(driverContext.getSinkHandle() == null, "There must be at most one SinkNode");
-    driverContext.setSinkHandle(sinkHandle);
+  public void setISink(ISink sink) {
+    requireNonNull(sink, "sink is null");
+    checkArgument(driverContext.getSink() == null, "There must be at most one SinkNode");
+    driverContext.setSink(sink);
   }
 
   public void setCachedDataTypes(List<TSDataType> cachedDataTypes) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index dd4803d6bf..f8ef1e525b 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -33,11 +33,14 @@ import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.NodeRef;
 import org.apache.iotdb.db.mpp.execution.driver.DataDriverContext;
 import org.apache.iotdb.db.mpp.execution.driver.SchemaDriverContext;
-import org.apache.iotdb.db.mpp.execution.exchange.ISinkHandle;
-import org.apache.iotdb.db.mpp.execution.exchange.ISourceHandle;
-import org.apache.iotdb.db.mpp.execution.exchange.LocalSinkHandle;
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager;
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeService;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.LocalSinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ShuffleSinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceManager;
 import org.apache.iotdb.db.mpp.execution.operator.AggregationUtil;
 import org.apache.iotdb.db.mpp.execution.operator.Operator;
@@ -112,6 +115,8 @@ import org.apache.iotdb.db.mpp.execution.operator.schema.SchemaQueryMergeOperato
 import org.apache.iotdb.db.mpp.execution.operator.schema.SchemaQueryOrderByHeatOperator;
 import org.apache.iotdb.db.mpp.execution.operator.schema.SchemaQueryScanOperator;
 import org.apache.iotdb.db.mpp.execution.operator.schema.source.SchemaSourceFactory;
+import org.apache.iotdb.db.mpp.execution.operator.sink.IdentitySinkOperator;
+import org.apache.iotdb.db.mpp.execution.operator.sink.ShuffleHelperOperator;
 import org.apache.iotdb.db.mpp.execution.operator.source.AlignedSeriesAggregationScanOperator;
 import org.apache.iotdb.db.mpp.execution.operator.source.AlignedSeriesScanOperator;
 import org.apache.iotdb.db.mpp.execution.operator.source.ExchangeOperator;
@@ -171,7 +176,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -1840,20 +1846,26 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
     FragmentInstanceId remoteInstanceId = node.getUpstreamInstanceId();
 
     TEndPoint upstreamEndPoint = node.getUpstreamEndpoint();
+    boolean isSameNode = isSameNode(upstreamEndPoint);
     ISourceHandle sourceHandle =
-        isSameNode(upstreamEndPoint)
+        isSameNode
             ? MPP_DATA_EXCHANGE_MANAGER.createLocalSourceHandleForFragment(
                 localInstanceId.toThrift(),
                 node.getPlanNodeId().getId(),
+                node.getUpstreamPlanNodeId().getId(),
                 remoteInstanceId.toThrift(),
+                node.getIndexOfUpstreamSinkHandle(),
                 context.getInstanceContext()::failed)
             : MPP_DATA_EXCHANGE_MANAGER.createSourceHandle(
                 localInstanceId.toThrift(),
                 node.getPlanNodeId().getId(),
+                node.getIndexOfUpstreamSinkHandle(),
                 upstreamEndPoint,
                 remoteInstanceId.toThrift(),
                 context.getInstanceContext()::failed);
-    context.addExchangeSumNum(1);
+    if (!isSameNode) {
+      context.addExchangeSumNum(1);
+    }
     sourceHandle.setMaxBytesCanReserve(context.getMaxBytesOneHandleCanReserve());
     ExchangeOperator exchangeOperator =
         new ExchangeOperator(operatorContext, sourceHandle, node.getUpstreamPlanNodeId());
@@ -1862,36 +1874,73 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
   }
 
   @Override
-  public Operator visitFragmentSink(FragmentSinkNode node, LocalExecutionPlanContext context) {
-    if (!isSameNode(node.getDownStreamEndpoint())) {
-      context.addExchangeSumNum(1);
-    }
-    Operator child = node.getChild().accept(this, context);
+  public Operator visitIdentitySink(IdentitySinkNode node, LocalExecutionPlanContext context) {
+    context.addExchangeSumNum(1);
+    OperatorContext operatorContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                IdentitySinkOperator.class.getSimpleName());
+    context.getTimeSliceAllocator().recordExecutionWeight(operatorContext, 1);
 
-    FragmentInstanceId localInstanceId = context.getInstanceContext().getId();
-    FragmentInstanceId targetInstanceId = node.getDownStreamInstanceId();
-    TEndPoint downStreamEndPoint = node.getDownStreamEndpoint();
+    List<Operator> children =
+        node.getChildren().stream()
+            .map(child -> child.accept(this, context))
+            .collect(Collectors.toList());
 
     checkArgument(
         MPP_DATA_EXCHANGE_MANAGER != null, "MPP_DATA_EXCHANGE_MANAGER should not be null");
+    FragmentInstanceId localInstanceId = context.getInstanceContext().getId();
+    DownStreamChannelIndex downStreamChannelIndex = new DownStreamChannelIndex(0);
+    ISinkHandle sinkHandle =
+        MPP_DATA_EXCHANGE_MANAGER.createShuffleSinkHandle(
+            node.getDownStreamChannelLocationList(),
+            downStreamChannelIndex,
+            ShuffleSinkHandle.ShuffleStrategyEnum.PLAIN,
+            localInstanceId.toThrift(),
+            node.getPlanNodeId().getId(),
+            context.getInstanceContext());
+    sinkHandle.setMaxBytesCanReserve(context.getMaxBytesOneHandleCanReserve());
+    context.getDriverContext().setSink(sinkHandle);
+
+    return new IdentitySinkOperator(operatorContext, children, downStreamChannelIndex, sinkHandle);
+  }
+
+  @Override
+  public Operator visitShuffleSink(ShuffleSinkNode node, LocalExecutionPlanContext context) {
+    context.addExchangeSumNum(1);
+    OperatorContext operatorContext =
+        context
+            .getDriverContext()
+            .addOperatorContext(
+                context.getNextOperatorId(),
+                node.getPlanNodeId(),
+                ShuffleHelperOperator.class.getSimpleName());
+    context.getTimeSliceAllocator().recordExecutionWeight(operatorContext, 1);
 
+    List<Operator> children =
+        node.getChildren().stream()
+            .map(child -> child.accept(this, context))
+            .collect(Collectors.toList());
+
+    checkArgument(
+        MPP_DATA_EXCHANGE_MANAGER != null, "MPP_DATA_EXCHANGE_MANAGER should not be null");
+    FragmentInstanceId localInstanceId = context.getInstanceContext().getId();
+    DownStreamChannelIndex downStreamChannelIndex = new DownStreamChannelIndex(0);
     ISinkHandle sinkHandle =
-        isSameNode(downStreamEndPoint)
-            ? MPP_DATA_EXCHANGE_MANAGER.createLocalSinkHandleForFragment(
-                localInstanceId.toThrift(),
-                targetInstanceId.toThrift(),
-                node.getDownStreamPlanNodeId().getId(),
-                context.getInstanceContext())
-            : MPP_DATA_EXCHANGE_MANAGER.createSinkHandle(
-                localInstanceId.toThrift(),
-                downStreamEndPoint,
-                targetInstanceId.toThrift(),
-                node.getDownStreamPlanNodeId().getId(),
-                node.getPlanNodeId().getId(),
-                context.getInstanceContext());
+        MPP_DATA_EXCHANGE_MANAGER.createShuffleSinkHandle(
+            node.getDownStreamChannelLocationList(),
+            downStreamChannelIndex,
+            ShuffleSinkHandle.ShuffleStrategyEnum.SIMPLE_ROUND_ROBIN,
+            localInstanceId.toThrift(),
+            node.getPlanNodeId().getId(),
+            context.getInstanceContext());
     sinkHandle.setMaxBytesCanReserve(context.getMaxBytesOneHandleCanReserve());
-    context.getDriverContext().setSinkHandle(sinkHandle);
-    return child;
+    context.getDriverContext().setSink(sinkHandle);
+
+    return new ShuffleHelperOperator(operatorContext, children, downStreamChannelIndex, sinkHandle);
   }
 
   @Override
@@ -2437,11 +2486,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
   private Operator createNewPipelineForChildNode(
       LocalExecutionPlanContext context, LocalExecutionPlanContext subContext, PlanNode childNode) {
     Operator childOperation = childNode.accept(this, subContext);
-    ISinkHandle localSinkHandle =
-        MPP_DATA_EXCHANGE_MANAGER.createLocalSinkHandleForPipeline(
+    ISinkChannel localSinkChannel =
+        MPP_DATA_EXCHANGE_MANAGER.createLocalSinkChannelForPipeline(
             // Attention, there is no parent node, use first child node instead
             subContext.getDriverContext(), childNode.getPlanNodeId().getId());
-    subContext.setSinkHandle(localSinkHandle);
+    subContext.setISink(localSinkChannel);
     subContext.addPipelineDriverFactory(childOperation, subContext.getDriverContext());
 
     ExchangeOperator sourceOperator =
@@ -2451,7 +2500,7 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
                 .addOperatorContext(
                     context.getNextOperatorId(), null, ExchangeOperator.class.getSimpleName()),
             MPP_DATA_EXCHANGE_MANAGER.createLocalSourceHandleForPipeline(
-                ((LocalSinkHandle) localSinkHandle).getSharedTsBlockQueue(),
+                ((LocalSinkChannel) localSinkChannel).getSharedTsBlockQueue(),
                 context.getDriverContext()),
             childNode.getPlanNodeId());
 
@@ -2493,11 +2542,11 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
           subContext.setDegreeOfParallelism(dopForChild);
           int originPipeNum = context.getPipelineNumber();
           Operator childOperation = childNode.accept(this, subContext);
-          ISinkHandle localSinkHandle =
-              MPP_DATA_EXCHANGE_MANAGER.createLocalSinkHandleForPipeline(
+          ISinkChannel localSinkChannel =
+              MPP_DATA_EXCHANGE_MANAGER.createLocalSinkChannelForPipeline(
                   // Attention, there is no parent node, use first child node instead
                   context.getDriverContext(), childNode.getPlanNodeId().getId());
-          subContext.setSinkHandle(localSinkHandle);
+          subContext.setISink(localSinkChannel);
           subContext.addPipelineDriverFactory(childOperation, subContext.getDriverContext());
 
           int curChildPipelineNum = subContext.getPipelineNumber() - originPipeNum;
@@ -2529,7 +2578,7 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
                           null,
                           ExchangeOperator.class.getSimpleName()),
                   MPP_DATA_EXCHANGE_MANAGER.createLocalSourceHandleForPipeline(
-                      ((LocalSinkHandle) localSinkHandle).getSharedTsBlockQueue(),
+                      ((LocalSinkChannel) localSinkChannel).getSharedTsBlockQueue(),
                       context.getDriverContext()),
                   childNode.getPlanNodeId());
           context
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
index 0c564a3fa1..422759a807 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/DistributionPlanner.java
@@ -18,8 +18,10 @@
  */
 package org.apache.iotdb.db.mpp.plan.planner.distribution;
 
+import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
 import org.apache.iotdb.db.mpp.common.MPPQueryContext;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.analyze.Analysis;
 import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
 import org.apache.iotdb.db.mpp.plan.planner.IFragmentParallelPlaner;
@@ -29,21 +31,32 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.PlanFragment;
 import org.apache.iotdb.db.mpp.plan.planner.plan.SubPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.WritePlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.MultiChildrenSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
+import org.apache.iotdb.db.mpp.plan.statement.component.OrderByComponent;
+import org.apache.iotdb.db.mpp.plan.statement.component.SortKey;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
 import org.apache.iotdb.db.mpp.plan.statement.sys.ShowQueriesStatement;
 
+import org.apache.commons.lang3.Validate;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 public class DistributionPlanner {
   private Analysis analysis;
   private MPPQueryContext context;
   private LogicalQueryPlan logicalPlan;
 
-  private int planFragmentIndex = 0;
-
   public DistributionPlanner(Analysis analysis, LogicalQueryPlan logicalPlan) {
     this.analysis = analysis;
     this.logicalPlan = logicalPlan;
@@ -63,7 +76,80 @@ public class DistributionPlanner {
 
   public PlanNode addExchangeNode(PlanNode root) {
     ExchangeNodeAdder adder = new ExchangeNodeAdder(this.analysis);
-    return adder.visit(root, new NodeGroupContext(context));
+    NodeGroupContext nodeGroupContext =
+        new NodeGroupContext(
+            context,
+            analysis.getStatement() instanceof QueryStatement
+                && (((QueryStatement) analysis.getStatement()).isAlignByDevice()),
+            root);
+    PlanNode newRoot = adder.visit(root, nodeGroupContext);
+    adjustUpStream(nodeGroupContext);
+    return newRoot;
+  }
+
+  /**
+   * Adjust upStream of the ExchangeNodes collected in {@code context}: the children of
+   * ExchangeNodes located in the same DataRegion are grouped under one {@link IdentitySinkNode}
+   * (or {@link ShuffleSinkNode} when a shuffle is needed), which then replaces the original child
+   * of each of those ExchangeNodes.
+   *
+   * <p>ExchangeNodes sharing a SinkNode get consecutive upstream channel indexes (0, 1, ...) in
+   * the order they appear in {@code context.exchangeNodes}.
+   */
+  private void adjustUpStream(NodeGroupContext context) {
+    if (context.exchangeNodes.isEmpty()) {
+      return;
+    }
+
+    final boolean needShuffleSinkNode =
+        analysis.getStatement() instanceof QueryStatement
+            && needShuffleSinkNode((QueryStatement) analysis.getStatement(), context);
+
+    // step1: group children of ExchangeNodes by the DataRegion they are located in
+    Map<TRegionReplicaSet, List<PlanNode>> nodeGroups = new HashMap<>();
+    for (ExchangeNode exchangeNode : context.exchangeNodes) {
+      nodeGroups
+          .computeIfAbsent(regionOfChild(exchangeNode, context), region -> new ArrayList<>())
+          .add(exchangeNode.getChild());
+    }
+
+    // step2: create one IdentitySinkNode/ShuffleSinkNode per group as the parent of its nodes
+    Map<TRegionReplicaSet, MultiChildrenSinkNode> sinkNodes = new HashMap<>();
+    nodeGroups.forEach(
+        (region, planNodeList) -> {
+          PlanNodeId sinkNodeId = context.queryContext.getQueryId().genPlanNodeId();
+          MultiChildrenSinkNode sinkNode =
+              needShuffleSinkNode
+                  ? new ShuffleSinkNode(sinkNodeId)
+                  : new IdentitySinkNode(sinkNodeId);
+          sinkNode.addChildren(planNodeList);
+          sinkNodes.put(region, sinkNode);
+        });
+
+    // step3: make each ExchangeNode read from the SinkNode of its child's DataRegion
+    Map<TRegionReplicaSet, Integer> visitedCount = new HashMap<>();
+    for (ExchangeNode exchangeNode : context.exchangeNodes) {
+      TRegionReplicaSet region = regionOfChild(exchangeNode, context);
+      int channelIndex = visitedCount.merge(region, 1, Integer::sum) - 1;
+      exchangeNode.setChild(sinkNodes.get(region));
+      exchangeNode.setIndexOfUpstreamSinkHandle(channelIndex);
+    }
+  }
+
+  /** Return the DataRegion recorded in {@code context} for the child of {@code exchangeNode}. */
+  private static TRegionReplicaSet regionOfChild(
+      ExchangeNode exchangeNode, NodeGroupContext context) {
+    return context.getNodeDistribution(exchangeNode.getChild().getPlanNodeId()).region;
+  }
+
+  /** Return true for ALIGN BY DEVICE queries ordered by a first sort key other than DEVICE. */
+  private boolean needShuffleSinkNode(
+      QueryStatement queryStatement, NodeGroupContext nodeGroupContext) {
+    OrderByComponent orderByComponent = queryStatement.getOrderByComponent();
+    return nodeGroupContext.isAlignByDevice()
+        && orderByComponent != null
+        && !(orderByComponent.getSortItemList().isEmpty()
+            || orderByComponent.getSortItemList().get(0).getSortKey().equals(SortKey.DEVICE));
   }
 
   public SubPlan splitFragment(PlanNode root) {
@@ -116,12 +202,15 @@ public class DistributionPlanner {
       return;
     }
 
-    FragmentSinkNode sinkNode = new FragmentSinkNode(context.getQueryId().genPlanNodeId());
-    sinkNode.setDownStream(
-        context.getLocalDataBlockEndpoint(),
-        context.getResultNodeContext().getVirtualFragmentInstanceId(),
-        context.getResultNodeContext().getVirtualResultNodeId());
-    sinkNode.setChild(rootInstance.getFragment().getPlanNodeTree());
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            context.getQueryId().genPlanNodeId(),
+            Collections.singletonList(rootInstance.getFragment().getPlanNodeTree()),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    context.getLocalDataBlockEndpoint(),
+                    context.getResultNodeContext().getVirtualFragmentInstanceId().toThrift(),
+                    context.getResultNodeContext().getVirtualResultNodeId().getId())));
     context
         .getResultNodeContext()
         .setUpStream(
@@ -144,11 +233,12 @@ public class DistributionPlanner {
 
     public SubPlan splitToSubPlan(PlanNode root) {
       SubPlan rootSubPlan = createSubPlan(root);
-      splitToSubPlan(root, rootSubPlan);
+      Set<PlanNodeId> visitedSinkNode = new HashSet<>();
+      splitToSubPlan(root, rootSubPlan, visitedSinkNode);
       return rootSubPlan;
     }
 
-    private void splitToSubPlan(PlanNode root, SubPlan subPlan) {
+    private void splitToSubPlan(PlanNode root, SubPlan subPlan, Set<PlanNodeId> visitedSinkNode) {
       // TODO: (xingtanzjr) we apply no action for IWritePlanNode currently
       if (root instanceof WritePlanNode) {
         return;
@@ -156,25 +246,27 @@ public class DistributionPlanner {
       if (root instanceof ExchangeNode) {
-        // We add a FragmentSinkNode for newly created PlanFragment
+        // Register this ExchangeNode as a downstream channel of the SinkNode that feeds it
         ExchangeNode exchangeNode = (ExchangeNode) root;
-        FragmentSinkNode sinkNode = new FragmentSinkNode(context.getQueryId().genPlanNodeId());
-        sinkNode.setChild(exchangeNode.getChild());
-        sinkNode.setDownStreamPlanNodeId(exchangeNode.getPlanNodeId());
+        Validate.isTrue(
+            exchangeNode.getChild() instanceof MultiChildrenSinkNode,
+            "child of ExchangeNode must be MultiChildrenSinkNode");
+        MultiChildrenSinkNode sinkNode = (MultiChildrenSinkNode) exchangeNode.getChild();
+        sinkNode.addDownStreamChannelLocation(
+            new DownStreamChannelLocation(exchangeNode.getPlanNodeId().toString()));
 
-        // Record the source node info in the ExchangeNode so that we can keep the connection of
-        // these nodes/fragments
-        exchangeNode.setRemoteSourceNode(sinkNode);
         // We cut off the subtree to make the ExchangeNode as the leaf node of current PlanFragment
         exchangeNode.cleanChildren();
 
-        // Build the child SubPlan Tree
-        SubPlan childSubPlan = createSubPlan(sinkNode);
-        splitToSubPlan(sinkNode, childSubPlan);
-
-        subPlan.addChild(childSubPlan);
+        // Several ExchangeNodes may share one SinkNode: build its child SubPlan Tree only on the
+        // first visit (Set#add returns false when the SinkNode has been visited already)
+        if (visitedSinkNode.add(sinkNode.getPlanNodeId())) {
+          SubPlan childSubPlan = createSubPlan(sinkNode);
+          splitToSubPlan(sinkNode, childSubPlan, visitedSinkNode);
+          subPlan.addChild(childSubPlan);
+        }
         return;
       }
       for (PlanNode child : root.getChildren()) {
-        splitToSubPlan(child, subPlan);
+        splitToSubPlan(child, subPlan, visitedSinkNode);
       }
     }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
index 03f9d7d2c4..14a5ec1c00 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/ExchangeNodeAdder.java
@@ -56,13 +56,9 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.LastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SourceNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.OrderByParameter;
-import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -128,6 +124,7 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
                     new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
                 exchangeNode.setChild(child);
                 exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
+                context.exchangeNodes.add(exchangeNode);
                 newNode.addChild(exchangeNode);
               } else {
                 newNode.addChild(child);
@@ -198,10 +195,6 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
 
   @Override
   public PlanNode visitDeviceView(DeviceViewNode node, NodeGroupContext context) {
-    // A temporary way to decrease the FragmentInstance for aggregation with device view.
-    if (isAggregationQuery()) {
-      return processDeviceViewWithAggregation(node, context);
-    }
     return processMultiChildNode(node, context);
   }
 
@@ -217,81 +210,7 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
 
   @Override
   public PlanNode visitMergeSort(MergeSortNode node, NodeGroupContext context) {
-    if (analysis.isVirtualSource()) {
-      return processMultiChildNodeByLocation(node, context);
-    }
-    // 1. Group children by dataRegion
-    Map<TRegionReplicaSet, List<PlanNode>> childrenGroupMap = new HashMap<>();
-    for (int i = 0; i < node.getChildren().size(); i++) {
-      PlanNode rawChildNode = node.getChildren().get(i);
-      PlanNode visitedChild = visit(rawChildNode, context);
-      TRegionReplicaSet region = context.getNodeDistribution(visitedChild.getPlanNodeId()).region;
-      childrenGroupMap.computeIfAbsent(region, k -> new ArrayList<>()).add(visitedChild);
-    }
-
-    // 2.add mergeSortNode for each group
-    List<PlanNode> mergeSortNodeList = new ArrayList<>();
-    for (List<PlanNode> group : childrenGroupMap.values()) {
-      if (group.size() == 1) {
-        PlanNode planNode = group.get(0);
-        if (planNode instanceof SingleDeviceViewNode) {
-          ((SingleDeviceViewNode) planNode).setCacheOutputColumnNames(true);
-        }
-        mergeSortNodeList.add(planNode);
-        continue;
-      }
-      MergeSortNode mergeSortNode =
-          new MergeSortNode(
-              context.queryContext.getQueryId().genPlanNodeId(),
-              node.getMergeOrderParameter(),
-              node.getOutputColumnNames());
-      group.forEach(mergeSortNode::addChild);
-      context.putNodeDistribution(
-          mergeSortNode.getPlanNodeId(),
-          new NodeDistribution(
-              NodeDistributionType.SAME_WITH_ALL_CHILDREN,
-              context.getNodeDistribution(mergeSortNode.getChildren().get(0).getPlanNodeId())
-                  .region));
-      mergeSortNodeList.add(mergeSortNode);
-    }
-
-    return groupPlanNodeByMergeSortNode(
-        mergeSortNodeList, node.getOutputColumnNames(), node.getMergeOrderParameter(), context);
-  }
-
-  private PlanNode groupPlanNodeByMergeSortNode(
-      List<PlanNode> mergeSortNodeList,
-      List<String> outputColumns,
-      OrderByParameter orderByParameter,
-      NodeGroupContext context) {
-    if (mergeSortNodeList.size() == 1) {
-      return mergeSortNodeList.get(0);
-    }
-
-    MergeSortNode mergeSortNode =
-        new MergeSortNode(
-            context.queryContext.getQueryId().genPlanNodeId(), orderByParameter, outputColumns);
-
-    // Each child has different TRegionReplicaSet, so we can select any one from
-    // its child
-    mergeSortNode.addChild(mergeSortNodeList.get(0));
-    context.putNodeDistribution(
-        mergeSortNode.getPlanNodeId(),
-        new NodeDistribution(
-            NodeDistributionType.SAME_WITH_SOME_CHILD,
-            context.getNodeDistribution(mergeSortNodeList.get(0).getPlanNodeId()).region));
-
-    // add ExchangeNode for other child
-    for (int i = 1; i < mergeSortNodeList.size(); i++) {
-      PlanNode child = mergeSortNodeList.get(i);
-      ExchangeNode exchangeNode =
-          new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
-      exchangeNode.setChild(child);
-      exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
-      mergeSortNode.addChild(exchangeNode);
-    }
-
-    return mergeSortNode;
+    return processMultiChildNode(node, context);
   }
 
   @Override
@@ -350,70 +269,6 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     return processMultiChildNode(node, context);
   }
 
-  private PlanNode processDeviceViewWithAggregation(DeviceViewNode node, NodeGroupContext context) {
-    // group all the children by DataRegion distribution
-    Map<TRegionReplicaSet, DeviceViewGroup> deviceViewGroupMap = new HashMap<>();
-    for (int i = 0; i < node.getDevices().size(); i++) {
-      String device = node.getDevices().get(i);
-      PlanNode rawChildNode = node.getChildren().get(i);
-      PlanNode visitedChild = visit(rawChildNode, context);
-      TRegionReplicaSet region = context.getNodeDistribution(visitedChild.getPlanNodeId()).region;
-      DeviceViewGroup group = deviceViewGroupMap.computeIfAbsent(region, DeviceViewGroup::new);
-      group.addChild(device, visitedChild);
-    }
-    // Generate DeviceViewNode for each group
-    List<PlanNode> deviceViewNodeList = new ArrayList<>();
-    for (DeviceViewGroup group : deviceViewGroupMap.values()) {
-      DeviceViewNode deviceViewNode =
-          new DeviceViewNode(
-              context.queryContext.getQueryId().genPlanNodeId(),
-              node.getMergeOrderParameter(),
-              node.getOutputColumnNames(),
-              node.getDeviceToMeasurementIndexesMap());
-      for (int i = 0; i < group.devices.size(); i++) {
-        deviceViewNode.addChildDeviceNode(group.devices.get(i), group.children.get(i));
-      }
-      context.putNodeDistribution(
-          deviceViewNode.getPlanNodeId(),
-          new NodeDistribution(
-              NodeDistributionType.SAME_WITH_ALL_CHILDREN,
-              context.getNodeDistribution(deviceViewNode.getChildren().get(0).getPlanNodeId())
-                  .region));
-      deviceViewNodeList.add(deviceViewNode);
-    }
-
-    return groupPlanNodeByMergeSortNode(
-        deviceViewNodeList, node.getOutputColumnNames(), node.getMergeOrderParameter(), context);
-  }
-
-  private static class DeviceViewGroup {
-    public TRegionReplicaSet regionReplicaSet;
-    public List<PlanNode> children;
-    public List<String> devices;
-
-    public DeviceViewGroup(TRegionReplicaSet regionReplicaSet) {
-      this.regionReplicaSet = regionReplicaSet;
-      this.children = new LinkedList<>();
-      this.devices = new LinkedList<>();
-    }
-
-    public void addChild(String device, PlanNode child) {
-      devices.add(device);
-      children.add(child);
-    }
-
-    public int hashCode() {
-      return regionReplicaSet.hashCode();
-    }
-
-    public boolean equals(Object o) {
-      if (o instanceof DeviceViewGroup) {
-        return regionReplicaSet.equals(((DeviceViewGroup) o).regionReplicaSet);
-      }
-      return false;
-    }
-  }
-
   private PlanNode processMultiChildNode(MultiChildProcessNode node, NodeGroupContext context) {
     if (analysis.isVirtualSource()) {
       return processMultiChildNodeByLocation(node, context);
@@ -427,18 +282,34 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
               visitedChildren.add(visit(child, context));
             });
 
-    TRegionReplicaSet dataRegion = calculateDataRegionByChildren(visitedChildren, context);
-    NodeDistributionType distributionType =
-        nodeDistributionIsSame(visitedChildren, context)
-            ? NodeDistributionType.SAME_WITH_ALL_CHILDREN
-            : NodeDistributionType.SAME_WITH_SOME_CHILD;
-    context.putNodeDistribution(
-        newNode.getPlanNodeId(), new NodeDistribution(distributionType, dataRegion));
+    TRegionReplicaSet dataRegion;
+    NodeDistributionType distributionType;
+    if (context.isAlignByDevice()) {
+      // For ALIGN BY DEVICE: if all children share one DataRegion, this node uses it; otherwise
+      // it uses the query's mostly used DataRegion and each child located elsewhere gets an
+      // ExchangeNode below. The distribution is recorded as SAME_WITH_ALL_CHILDREN in both cases.
+      dataRegion =
+          nodeDistributionIsSame(visitedChildren, context)
+              ? context.getNodeDistribution(visitedChildren.get(0).getPlanNodeId()).region
+              : context.getMostlyUsedDataRegion();
+      context.putNodeDistribution(
+          newNode.getPlanNodeId(),
+          new NodeDistribution(NodeDistributionType.SAME_WITH_ALL_CHILDREN, dataRegion));
+    } else {
+      // TODO: ALIGN BY TIME keeps the old logic for now
+      dataRegion = calculateDataRegionByChildren(visitedChildren, context);
+      distributionType =
+          nodeDistributionIsSame(visitedChildren, context)
+              ? NodeDistributionType.SAME_WITH_ALL_CHILDREN
+              : NodeDistributionType.SAME_WITH_SOME_CHILD;
+      context.putNodeDistribution(
+          newNode.getPlanNodeId(), new NodeDistribution(distributionType, dataRegion));
 
-    // If the distributionType of all the children are same, no ExchangeNode need to be added.
-    if (distributionType == NodeDistributionType.SAME_WITH_ALL_CHILDREN) {
-      newNode.setChildren(visitedChildren);
-      return newNode;
+      // If all the children have the same distribution, no ExchangeNode needs to be added.
+      if (distributionType == NodeDistributionType.SAME_WITH_ALL_CHILDREN) {
+        newNode.setChildren(visitedChildren);
+        return newNode;
+      }
     }
 
     // Otherwise, we need to add ExchangeNode for the child whose DataRegion is different from the
@@ -446,10 +317,14 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     visitedChildren.forEach(
         child -> {
           if (!dataRegion.equals(context.getNodeDistribution(child.getPlanNodeId()).region)) {
+            if (child instanceof SingleDeviceViewNode) {
+              ((SingleDeviceViewNode) child).setCacheOutputColumnNames(true);
+            }
             ExchangeNode exchangeNode =
                 new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
             exchangeNode.setChild(child);
             exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
+            context.exchangeNodes.add(exchangeNode);
             newNode.addChild(exchangeNode);
           } else {
             newNode.addChild(child);
@@ -470,6 +345,7 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
           new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
       exchangeNode.setChild(child);
       exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
+      context.exchangeNodes.add(exchangeNode);
       newNode.addChild(exchangeNode);
     }
     return newNode;
@@ -540,10 +416,6 @@ public class ExchangeNodeAdder extends PlanVisitor<PlanNode, NodeGroupContext> {
     return true;
   }
 
-  private boolean isAggregationQuery() {
-    return ((QueryStatement) analysis.getStatement()).isAggregationQuery();
-  }
-
   public PlanNode visit(PlanNode node, NodeGroupContext context) {
     return node.accept(this, context);
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
index ba3ddb400c..389e30f0eb 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/NodeGroupContext.java
@@ -19,19 +19,54 @@
 
 package org.apache.iotdb.db.mpp.plan.planner.distribution;
 
+import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
+import org.apache.iotdb.commons.partition.DataPartition;
 import org.apache.iotdb.db.mpp.common.MPPQueryContext;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SourceNode;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 public class NodeGroupContext {
-  protected MPPQueryContext queryContext;
-  protected Map<PlanNodeId, NodeDistribution> nodeDistributionMap;
+  protected final MPPQueryContext queryContext;
+  private final Map<PlanNodeId, NodeDistribution> nodeDistributionMap;
+  private final boolean isAlignByDevice;
+  private final TRegionReplicaSet mostlyUsedDataRegion;
+  protected final List<ExchangeNode> exchangeNodes;
 
-  public NodeGroupContext(MPPQueryContext queryContext) {
+  public NodeGroupContext(MPPQueryContext queryContext, boolean isAlignByDevice, PlanNode root) {
     this.queryContext = queryContext;
     this.nodeDistributionMap = new HashMap<>();
+    this.isAlignByDevice = isAlignByDevice;
+    this.mostlyUsedDataRegion = isAlignByDevice ? getMostlyUsedDataRegion(root) : null;
+    this.exchangeNodes = new ArrayList<>();
+  }
+
+  /**
+   * Return the assigned DataRegion used by the most SourceNodes under {@code root}; falls back to
+   * {@link DataPartition#NOT_ASSIGNED} (instead of failing) when no SourceNode has one.
+   */
+  private TRegionReplicaSet getMostlyUsedDataRegion(PlanNode root) {
+    Map<TRegionReplicaSet, Long> regionCount = new HashMap<>();
+    countRegionOfSourceNodes(root, regionCount);
+    return regionCount.entrySet().stream()
+        .filter(e -> e.getKey() != DataPartition.NOT_ASSIGNED)
+        .max(Map.Entry.comparingByValue())
+        .map(Map.Entry::getKey)
+        .orElse(DataPartition.NOT_ASSIGNED);
+  }
+
+  /** Count, per DataRegion, the SourceNodes in the subtree rooted at {@code root}. */
+  private void countRegionOfSourceNodes(PlanNode root, Map<TRegionReplicaSet, Long> result) {
+    root.getChildren().forEach(child -> countRegionOfSourceNodes(child, result));
+    if (root instanceof SourceNode) {
+      result.merge(((SourceNode) root).getRegionReplicaSet(), 1L, Long::sum);
+    }
   }
 
   public void putNodeDistribution(PlanNodeId nodeId, NodeDistribution distribution) {
@@ -41,4 +76,17 @@ public class NodeGroupContext {
   public NodeDistribution getNodeDistribution(PlanNodeId nodeId) {
     return this.nodeDistributionMap.get(nodeId);
   }
+
+  /** Whether this context was created for an ALIGN BY DEVICE query. */
+  public boolean isAlignByDevice() {
+    return isAlignByDevice;
+  }
+
+  /**
+   * Return the DataRegion used by the most SourceNodes of the plan (computed at construction);
+   * {@code null} unless this context was created for an ALIGN BY DEVICE query.
+   */
+  public TRegionReplicaSet getMostlyUsedDataRegion() {
+    return mostlyUsedDataRegion;
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SimpleFragmentParallelPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SimpleFragmentParallelPlanner.java
index 13c2e84d4f..d86c96d452 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SimpleFragmentParallelPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/distribution/SimpleFragmentParallelPlanner.java
@@ -36,7 +36,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeUtil;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.MultiChildrenSinkNode;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
 import org.apache.iotdb.db.mpp.plan.statement.sys.ShowQueriesStatement;
 import org.apache.iotdb.tsfile.read.filter.basic.Filter;
@@ -201,24 +201,30 @@ public class SimpleFragmentParallelPlanner implements IFragmentParallelPlaner {
   private void calculateNodeTopologyBetweenInstance() {
     for (FragmentInstance instance : fragmentInstanceList) {
       PlanNode rootNode = instance.getFragment().getPlanNodeTree();
-      if (rootNode instanceof FragmentSinkNode) {
-        // Set target Endpoint for FragmentSinkNode
-        FragmentSinkNode sinkNode = (FragmentSinkNode) rootNode;
-        PlanNodeId downStreamNodeId = sinkNode.getDownStreamPlanNodeId();
-        FragmentInstance downStreamInstance = findDownStreamInstance(downStreamNodeId);
-        sinkNode.setDownStream(
-            downStreamInstance.getHostDataNode().getMPPDataExchangeEndPoint(),
-            downStreamInstance.getId(),
-            downStreamNodeId);
-
-        // Set upstream info for corresponding ExchangeNode in downstream FragmentInstance
-        PlanNode downStreamExchangeNode =
-            downStreamInstance.getFragment().getPlanNodeById(downStreamNodeId);
-        ((ExchangeNode) downStreamExchangeNode)
-            .setUpstream(
-                instance.getHostDataNode().getMPPDataExchangeEndPoint(),
-                instance.getId(),
-                sinkNode.getPlanNodeId());
+      if (rootNode instanceof MultiChildrenSinkNode) {
+        MultiChildrenSinkNode sinkNode = (MultiChildrenSinkNode) rootNode;
+        sinkNode
+            .getDownStreamChannelLocationList()
+            .forEach(
+                downStreamChannelLocation -> {
+                  // Target the FragmentInstance that hosts this channel's ExchangeNode
+                  PlanNodeId downStreamNodeId =
+                      new PlanNodeId(downStreamChannelLocation.getRemotePlanNodeId());
+                  FragmentInstance downStreamInstance = findDownStreamInstance(downStreamNodeId);
+                  downStreamChannelLocation.setRemoteEndpoint(
+                      downStreamInstance.getHostDataNode().getMPPDataExchangeEndPoint());
+                  downStreamChannelLocation.setRemoteFragmentInstanceId(
+                      downStreamInstance.getId().toThrift());
+
+                  // Set upstream info for corresponding ExchangeNode in downstream FragmentInstance
+                  PlanNode downStreamExchangeNode =
+                      downStreamInstance.getFragment().getPlanNodeById(downStreamNodeId);
+                  ((ExchangeNode) downStreamExchangeNode)
+                      .setUpstream(
+                          instance.getHostDataNode().getMPPDataExchangeEndPoint(),
+                          instance.getId(),
+                          sinkNode.getPlanNodeId());
+                });
       }
     }
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/FragmentInstance.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/FragmentInstance.java
index 77db9b673d..97bd7e1748 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/FragmentInstance.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/FragmentInstance.java
@@ -30,9 +30,7 @@ import org.apache.iotdb.db.conf.IoTDBDescriptor;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.SessionInfo;
 import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeUtil;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
 import org.apache.iotdb.tsfile.read.filter.basic.Filter;
 import org.apache.iotdb.tsfile.read.filter.factory.FilterFactory;
 import org.apache.iotdb.tsfile.utils.PublicBAOS;
@@ -140,19 +138,6 @@ public class FragmentInstance implements IConsensusRequest {
     return isRoot;
   }
 
-  public String getDownstreamInfo() {
-    PlanNode root = getFragment().getPlanNodeTree();
-    if (root instanceof FragmentSinkNode) {
-      FragmentSinkNode sink = (FragmentSinkNode) root;
-      return String.format(
-          "(%s, %s, %s)",
-          sink.getDownStreamEndpoint(),
-          sink.getDownStreamInstanceId(),
-          sink.getDownStreamPlanNodeId());
-    }
-    return "<No downstream>";
-  }
-
   public void setTimeFilter(Filter timeFilter) {
     this.timeFilter = timeFilter;
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
index 03e5028e26..c0a839cc6e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanGraphPrinter.java
@@ -45,7 +45,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -305,14 +306,6 @@ public class PlanGraphPrinter extends PlanVisitor<List<String>, PlanGraphPrinter
     return render(node, boxValue, context);
   }
 
-  @Override
-  public List<String> visitFragmentSink(FragmentSinkNode node, GraphContext context) {
-    List<String> boxValue = new ArrayList<>();
-    boxValue.add(String.format("FragmentSink-%s", node.getPlanNodeId().getId()));
-    boxValue.add(String.format("Destination: %s", node.getDownStreamPlanNodeId()));
-    return render(node, boxValue, context);
-  }
-
   @Override
   public List<String> visitTransform(TransformNode node, GraphContext context) {
     List<String> boxValue = new ArrayList<>();
@@ -416,7 +409,21 @@ public class PlanGraphPrinter extends PlanVisitor<List<String>, PlanGraphPrinter
   @Override
   public List<String> visitHorizontallyConcat(HorizontallyConcatNode node, GraphContext context) {
     List<String> boxValue = new ArrayList<>();
-    boxValue.add(String.format("VerticallyConcat-%s", node.getPlanNodeId().getId()));
+    boxValue.add(String.format("HorizontallyConcat-%s", node.getPlanNodeId().getId()));
+    return render(node, boxValue, context);
+  }
+
+  @Override
+  public List<String> visitIdentitySink(IdentitySinkNode node, GraphContext context) {
+    List<String> boxValue = new ArrayList<>();
+    boxValue.add(String.format("IdentitySink-%s", node.getPlanNodeId().getId()));
+    return render(node, boxValue, context);
+  }
+
+  @Override
+  public List<String> visitShuffleSink(ShuffleSinkNode node, GraphContext context) {
+    List<String> boxValue = new ArrayList<>();
+    boxValue.add(String.format("ShuffleSink-%s", node.getPlanNodeId().getId()));
     return render(node, boxValue, context);
   }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
index 2527f5a74f..9b75d9e4a1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanNodeType.java
@@ -72,7 +72,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -162,7 +163,9 @@ public enum PlanNodeType {
   MERGE_SORT((short) 66),
   SHOW_QUERIES((short) 67),
   INTERNAL_BATCH_ACTIVATE_TEMPLATE((short) 68),
-  INTERNAL_CREATE_MULTI_TIMESERIES((short) 69);
+  INTERNAL_CREATE_MULTI_TIMESERIES((short) 69),
+  IDENTITY_SINK((short) 70),
+  SHUFFLE_SINK((short) 71);
 
   public static final int BYTES = Short.BYTES;
 
@@ -233,8 +236,6 @@ public enum PlanNodeType {
         return SortNode.deserialize(buffer);
       case 9:
         return TimeJoinNode.deserialize(buffer);
-      case 10:
-        return FragmentSinkNode.deserialize(buffer);
       case 11:
         return SeriesScanNode.deserialize(buffer);
       case 12:
@@ -351,6 +352,10 @@ public enum PlanNodeType {
         return InternalBatchActivateTemplateNode.deserialize(buffer);
       case 69:
         return InternalCreateMultiTimeSeriesNode.deserialize(buffer);
+      case 70:
+        return IdentitySinkNode.deserialize(buffer);
+      case 71:
+        return ShuffleSinkNode.deserialize(buffer);
       default:
         throw new IllegalArgumentException("Invalid node type: " + nodeType);
     }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
index 8181efcfa3..9356ddc742 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/PlanVisitor.java
@@ -70,7 +70,8 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
@@ -201,10 +202,6 @@ public abstract class PlanVisitor<R, C> {
     return visitPlan(node, context);
   }
 
-  public R visitFragmentSink(FragmentSinkNode node, C context) {
-    return visitPlan(node, context);
-  }
-
   public R visitCreateTimeSeries(CreateTimeSeriesNode node, C context) {
     return visitPlan(node, context);
   }
@@ -356,4 +353,12 @@ public abstract class PlanVisitor<R, C> {
   public R visitInternalCreateMultiTimeSeries(InternalCreateMultiTimeSeriesNode node, C context) {
     return visitPlan(node, context);
   }
+
+  public R visitIdentitySink(IdentitySinkNode node, C context) {
+    return visitPlan(node, context);
+  }
+
+  public R visitShuffleSink(ShuffleSinkNode node, C context) {
+    return visitPlan(node, context);
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ExchangeNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ExchangeNode.java
index 06310be926..84622d1d5c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ExchangeNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/ExchangeNode.java
@@ -25,7 +25,6 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
 import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import java.io.DataOutputStream;
@@ -36,11 +35,6 @@ import java.util.List;
 import java.util.Objects;
 
 public class ExchangeNode extends SingleChildProcessNode {
-
-  // The remoteSourceNode is used to record the remote source info for current ExchangeNode
-  // It is not the child of current ExchangeNode
-  private FragmentSinkNode remoteSourceNode;
-
   // In current version, one ExchangeNode will only have one source.
   // And the fragment which the sourceNode belongs to will only have one instance.
   // Thus, by nodeId and endpoint, the ExchangeNode can know where its source from.
@@ -48,7 +42,10 @@ public class ExchangeNode extends SingleChildProcessNode {
   private FragmentInstanceId upstreamInstanceId;
   private PlanNodeId upstreamPlanNodeId;
 
-  private List<String> outputColumnNames;
+  private List<String> outputColumnNames = new ArrayList<>();
+
+  /** Exchange needs to know which child of IdentitySinkNode/ShuffleSinkNode it matches */
+  private int indexOfUpstreamSinkHandle = 0;
 
   public ExchangeNode(PlanNodeId id) {
     super(id);
@@ -67,11 +64,8 @@ public class ExchangeNode extends SingleChildProcessNode {
   @Override
   public PlanNode clone() {
     ExchangeNode node = new ExchangeNode(getPlanNodeId());
-    if (remoteSourceNode != null) {
-      FragmentSinkNode remoteSourceNodeClone = (FragmentSinkNode) remoteSourceNode.clone();
-      remoteSourceNodeClone.setDownStreamPlanNodeId(node.getPlanNodeId());
-      node.setRemoteSourceNode(remoteSourceNode);
-    }
+    node.setOutputColumnNames(outputColumnNames);
+    node.setIndexOfUpstreamSinkHandle(indexOfUpstreamSinkHandle);
     return node;
   }
 
@@ -102,10 +96,12 @@ public class ExchangeNode extends SingleChildProcessNode {
       outputColumnNames.add(ReadWriteIOUtils.readString(byteBuffer));
       outputColumnNamesSize--;
     }
+    int index = ReadWriteIOUtils.readInt(byteBuffer);
     PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
     ExchangeNode exchangeNode = new ExchangeNode(planNodeId);
     exchangeNode.setUpstream(endPoint, fragmentInstanceId, upstreamPlanNodeId);
     exchangeNode.setOutputColumnNames(outputColumnNames);
+    exchangeNode.setIndexOfUpstreamSinkHandle(index);
     return exchangeNode;
   }
 
@@ -120,6 +116,7 @@ public class ExchangeNode extends SingleChildProcessNode {
     for (String outputColumnName : outputColumnNames) {
       ReadWriteIOUtils.write(outputColumnName, byteBuffer);
     }
+    ReadWriteIOUtils.write(indexOfUpstreamSinkHandle, byteBuffer);
   }
 
   @Override
@@ -133,6 +130,7 @@ public class ExchangeNode extends SingleChildProcessNode {
     for (String outputColumnName : outputColumnNames) {
       ReadWriteIOUtils.write(outputColumnName, stream);
     }
+    ReadWriteIOUtils.write(indexOfUpstreamSinkHandle, stream);
   }
 
   @Override
@@ -150,13 +148,12 @@ public class ExchangeNode extends SingleChildProcessNode {
         getUpstreamEndpoint().getIp(), getUpstreamInstanceId(), getUpstreamPlanNodeId());
   }
 
-  public FragmentSinkNode getRemoteSourceNode() {
-    return remoteSourceNode;
+  public int getIndexOfUpstreamSinkHandle() {
+    return indexOfUpstreamSinkHandle;
   }
 
-  public void setRemoteSourceNode(FragmentSinkNode remoteSourceNode) {
-    this.remoteSourceNode = remoteSourceNode;
-    this.setOutputColumnNames(remoteSourceNode.getOutputColumnNames());
+  public void setIndexOfUpstreamSinkHandle(int indexOfUpstreamSinkHandle) {
+    this.indexOfUpstreamSinkHandle = indexOfUpstreamSinkHandle;
   }
 
   public TEndPoint getUpstreamEndpoint() {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/HorizontallyConcatNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/HorizontallyConcatNode.java
index 00d1b3c4f9..3eebf02601 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/HorizontallyConcatNode.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/process/HorizontallyConcatNode.java
@@ -88,7 +88,7 @@ public class HorizontallyConcatNode extends MultiChildProcessNode {
 
   @Override
   public String toString() {
-    return "VerticallyConcatNode-" + this.getPlanNodeId();
+    return "HorizontallyConcatNode-" + this.getPlanNodeId();
   }
 
   @Override
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/FragmentSinkNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/FragmentSinkNode.java
deleted file mode 100644
index 249ab7fb95..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/FragmentSinkNode.java
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.iotdb.db.mpp.plan.planner.plan.node.sink;
-
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
-import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
-
-import com.google.common.collect.ImmutableList;
-import org.apache.commons.lang.Validate;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.Objects;
-
-public class FragmentSinkNode extends SinkNode {
-  private PlanNode child;
-
-  private TEndPoint downStreamEndpoint;
-  private FragmentInstanceId downStreamInstanceId;
-  private PlanNodeId downStreamPlanNodeId;
-
-  public FragmentSinkNode(PlanNodeId id) {
-    super(id);
-  }
-
-  @Override
-  public List<PlanNode> getChildren() {
-    return ImmutableList.of(child);
-  }
-
-  @Override
-  public PlanNode clone() {
-    FragmentSinkNode sinkNode = new FragmentSinkNode(getPlanNodeId());
-    sinkNode.setDownStream(downStreamEndpoint, downStreamInstanceId, downStreamPlanNodeId);
-    return sinkNode;
-  }
-
-  @Override
-  public PlanNode cloneWithChildren(List<PlanNode> children) {
-    Validate.isTrue(
-        children == null || children.size() == 1,
-        "Children size of FragmentSinkNode should be 0 or 1");
-    FragmentSinkNode sinkNode = (FragmentSinkNode) clone();
-    if (children != null) {
-      sinkNode.setChild(children.get(0));
-    }
-    return sinkNode;
-  }
-
-  @Override
-  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
-    return visitor.visitFragmentSink(this, context);
-  }
-
-  @Override
-  public void addChild(PlanNode child) {
-    this.child = child;
-  }
-
-  @Override
-  public int allowedChildCount() {
-    return ONE_CHILD;
-  }
-
-  @Override
-  public List<String> getOutputColumnNames() {
-    return child.getOutputColumnNames();
-  }
-
-  public static FragmentSinkNode deserialize(ByteBuffer byteBuffer) {
-    TEndPoint downStreamEndpoint =
-        new TEndPoint(
-            ReadWriteIOUtils.readString(byteBuffer), ReadWriteIOUtils.readInt(byteBuffer));
-    FragmentInstanceId downStreamInstanceId = FragmentInstanceId.deserialize(byteBuffer);
-    PlanNodeId downStreamPlanNodeId = PlanNodeId.deserialize(byteBuffer);
-    PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
-
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(planNodeId);
-    fragmentSinkNode.downStreamEndpoint = downStreamEndpoint;
-    fragmentSinkNode.downStreamInstanceId = downStreamInstanceId;
-    fragmentSinkNode.downStreamPlanNodeId = downStreamPlanNodeId;
-    return fragmentSinkNode;
-  }
-
-  @Override
-  protected void serializeAttributes(ByteBuffer byteBuffer) {
-    PlanNodeType.FRAGMENT_SINK.serialize(byteBuffer);
-    ReadWriteIOUtils.write(downStreamEndpoint.getIp(), byteBuffer);
-    ReadWriteIOUtils.write(downStreamEndpoint.getPort(), byteBuffer);
-    downStreamInstanceId.serialize(byteBuffer);
-    downStreamPlanNodeId.serialize(byteBuffer);
-  }
-
-  @Override
-  protected void serializeAttributes(DataOutputStream stream) throws IOException {
-    PlanNodeType.FRAGMENT_SINK.serialize(stream);
-    ReadWriteIOUtils.write(downStreamEndpoint.getIp(), stream);
-    ReadWriteIOUtils.write(downStreamEndpoint.getPort(), stream);
-    downStreamInstanceId.serialize(stream);
-    downStreamPlanNodeId.serialize(stream);
-  }
-
-  @Override
-  public void send() {}
-
-  @Override
-  public void close() throws Exception {}
-
-  public PlanNode getChild() {
-    return child;
-  }
-
-  public void setChild(PlanNode child) {
-    this.child = child;
-  }
-
-  public String toString() {
-    return String.format(
-        "FragmentSinkNode-%s:[SendTo: (%s)]", getPlanNodeId(), getDownStreamAddress());
-  }
-
-  public String getDownStreamAddress() {
-    if (getDownStreamEndpoint() == null) {
-      return "Not assigned";
-    }
-    return String.format(
-        "%s:%d/%s/%s",
-        getDownStreamEndpoint().getIp(),
-        getDownStreamEndpoint().getPort(),
-        getDownStreamInstanceId(),
-        getDownStreamPlanNodeId());
-  }
-
-  public void setDownStream(TEndPoint endPoint, FragmentInstanceId instanceId, PlanNodeId nodeId) {
-    this.downStreamEndpoint = endPoint;
-    this.downStreamInstanceId = instanceId;
-    this.downStreamPlanNodeId = nodeId;
-  }
-
-  public void setDownStreamPlanNodeId(PlanNodeId downStreamPlanNodeId) {
-    this.downStreamPlanNodeId = downStreamPlanNodeId;
-  }
-
-  public TEndPoint getDownStreamEndpoint() {
-    return downStreamEndpoint;
-  }
-
-  public FragmentInstanceId getDownStreamInstanceId() {
-    return downStreamInstanceId;
-  }
-
-  public PlanNodeId getDownStreamPlanNodeId() {
-    return downStreamPlanNodeId;
-  }
-
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) {
-      return true;
-    }
-    if (o == null || getClass() != o.getClass()) {
-      return false;
-    }
-    if (!super.equals(o)) {
-      return false;
-    }
-    FragmentSinkNode that = (FragmentSinkNode) o;
-    return Objects.equals(child, that.child)
-        && Objects.equals(downStreamEndpoint, that.downStreamEndpoint)
-        && Objects.equals(downStreamInstanceId, that.downStreamInstanceId)
-        && Objects.equals(downStreamPlanNodeId, that.downStreamPlanNodeId);
-  }
-
-  @Override
-  public int hashCode() {
-    return Objects.hash(
-        super.hashCode(), child, downStreamEndpoint, downStreamInstanceId, downStreamPlanNodeId);
-  }
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/IdentitySinkNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/IdentitySinkNode.java
new file mode 100644
index 0000000000..30f0cab31c
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/IdentitySinkNode.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.node.sink;
+
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Sink node that sends the data produced by its children to the downstream channels listed in
+ * {@link #getDownStreamChannelLocationList()}. A downstream {@code ExchangeNode} records which
+ * child of this node it matches via its {@code indexOfUpstreamSinkHandle}.
+ */
+public class IdentitySinkNode extends MultiChildrenSinkNode {
+
+  public IdentitySinkNode(PlanNodeId id) {
+    super(id);
+  }
+
+  public IdentitySinkNode(
+      PlanNodeId id, List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id, downStreamChannelLocationList);
+  }
+
+  public IdentitySinkNode(
+      PlanNodeId id,
+      List<PlanNode> children,
+      List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id, children, downStreamChannelLocationList);
+  }
+
+  /**
+   * Returns a copy of this node without children. The list of downstream channel locations is
+   * copied, so adding a location to the clone does not also add it to this node; the location
+   * objects themselves are shared.
+   */
+  @Override
+  public PlanNode clone() {
+    return new IdentitySinkNode(
+        getPlanNodeId(), new ArrayList<>(getDownStreamChannelLocationList()));
+  }
+
+  /** Returns the output column names of all children, concatenated in child order. */
+  @Override
+  public List<String> getOutputColumnNames() {
+    return children.stream()
+        .map(PlanNode::getOutputColumnNames)
+        .flatMap(List::stream)
+        .collect(Collectors.toList());
+  }
+
+  @Override
+  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+    return visitor.visitIdentitySink(this, context);
+  }
+
+  // Writes the node type, the number of downstream locations, then each location. The
+  // PlanNodeId is not written here; deserialize() expects it right after the locations.
+  @Override
+  protected void serializeAttributes(ByteBuffer byteBuffer) {
+    PlanNodeType.IDENTITY_SINK.serialize(byteBuffer);
+    ReadWriteIOUtils.write(downStreamChannelLocationList.size(), byteBuffer);
+    for (DownStreamChannelLocation downStreamChannelLocation : downStreamChannelLocationList) {
+      downStreamChannelLocation.serialize(byteBuffer);
+    }
+  }
+
+  // Same layout as serializeAttributes(ByteBuffer), written to a stream.
+  @Override
+  protected void serializeAttributes(DataOutputStream stream) throws IOException {
+    PlanNodeType.IDENTITY_SINK.serialize(stream);
+    ReadWriteIOUtils.write(downStreamChannelLocationList.size(), stream);
+    for (DownStreamChannelLocation downStreamChannelLocation : downStreamChannelLocationList) {
+      downStreamChannelLocation.serialize(stream);
+    }
+  }
+
+  /**
+   * Rebuilds a node from the bytes after its type tag: the location count, each location, and
+   * finally the PlanNodeId. Children are not read here.
+   */
+  public static IdentitySinkNode deserialize(ByteBuffer byteBuffer) {
+    int size = ReadWriteIOUtils.readInt(byteBuffer);
+    List<DownStreamChannelLocation> downStreamChannelLocationList = new ArrayList<>();
+    for (int i = 0; i < size; i++) {
+      downStreamChannelLocationList.add(DownStreamChannelLocation.deserialize(byteBuffer));
+    }
+    PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
+    return new IdentitySinkNode(planNodeId, downStreamChannelLocationList);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/MultiChildrenSinkNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/MultiChildrenSinkNode.java
new file mode 100644
index 0000000000..686dae93e9
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/MultiChildrenSinkNode.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.node.sink;
+
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Base class for sink nodes that accept any number of children and carry a list of {@link
+ * DownStreamChannelLocation}s.
+ *
+ * <p>The node owns both lists: lists passed to the constructors or to {@link #setChildren(List)}
+ * are copied, so later changes to the caller's list do not leak into this node.
+ *
+ * <p>NOTE(review): {@link #equals(Object)} and {@link #hashCode()} compare children but not the
+ * downstream channel locations; confirm this is intended.
+ */
+public abstract class MultiChildrenSinkNode extends SinkNode {
+
+  protected List<PlanNode> children;
+
+  protected final List<DownStreamChannelLocation> downStreamChannelLocationList;
+
+  /** Creates a node with no children and no downstream channel locations. */
+  public MultiChildrenSinkNode(PlanNodeId id) {
+    super(id);
+    this.children = new ArrayList<>();
+    this.downStreamChannelLocationList = new ArrayList<>();
+  }
+
+  /**
+   * Creates a node with copies of the given children and downstream channel locations;
+   * {@code null} lists are treated as empty.
+   */
+  protected MultiChildrenSinkNode(
+      PlanNodeId id,
+      List<PlanNode> children,
+      List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id);
+    this.children = mutableCopy(children);
+    this.downStreamChannelLocationList = mutableCopy(downStreamChannelLocationList);
+  }
+
+  /**
+   * Creates a node with no children and a copy of the given downstream channel locations;
+   * {@code null} is treated as an empty list.
+   */
+  protected MultiChildrenSinkNode(
+      PlanNodeId id, List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id);
+    this.children = new ArrayList<>();
+    this.downStreamChannelLocationList = mutableCopy(downStreamChannelLocationList);
+  }
+
+  /** Replaces the children with a copy of {@code children}; {@code null} means no children. */
+  public void setChildren(List<PlanNode> children) {
+    this.children = mutableCopy(children);
+  }
+
+  @Override
+  public abstract PlanNode clone();
+
+  /** Returns this node's own child list (not a copy). */
+  @Override
+  public List<PlanNode> getChildren() {
+    return children;
+  }
+
+  @Override
+  public void addChild(PlanNode child) {
+    this.children.add(child);
+  }
+
+  /** Appends every element of {@code children} to this node's children. */
+  public void addChildren(List<PlanNode> children) {
+    this.children.addAll(children);
+  }
+
+  @Override
+  public int allowedChildCount() {
+    return CHILD_COUNT_NO_LIMIT;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    if (!super.equals(o)) {
+      return false;
+    }
+    MultiChildrenSinkNode that = (MultiChildrenSinkNode) o;
+    return children.equals(that.children);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), children);
+  }
+
+  // send() and close() are intentionally empty in this base class.
+  @Override
+  public void send() {}
+
+  @Override
+  public void close() throws Exception {}
+
+  /** Returns this node's own list of downstream channel locations (not a copy). */
+  public List<DownStreamChannelLocation> getDownStreamChannelLocationList() {
+    return downStreamChannelLocationList;
+  }
+
+  public void addDownStreamChannelLocation(DownStreamChannelLocation downStreamChannelLocation) {
+    downStreamChannelLocationList.add(downStreamChannelLocation);
+  }
+
+  /** Returns a mutable copy of {@code source}, or an empty list if it is null. */
+  private static <T> List<T> mutableCopy(List<T> source) {
+    return source == null ? new ArrayList<>() : new ArrayList<>(source);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/ShuffleSinkNode.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/ShuffleSinkNode.java
new file mode 100644
index 0000000000..3352202649
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/plan/node/sink/ShuffleSinkNode.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner.plan.node.sink;
+
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeType;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Sink node responsible for creating the ShuffleHelperOperator and the corresponding SinkHandle.
+ *
+ * <p>A downstream {@code ExchangeNode} records which child of this node it matches via its
+ * {@code indexOfUpstreamSinkHandle}.
+ */
+public class ShuffleSinkNode extends MultiChildrenSinkNode {
+
+  public ShuffleSinkNode(PlanNodeId id) {
+    super(id);
+  }
+
+  public ShuffleSinkNode(
+      PlanNodeId id, List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id, downStreamChannelLocationList);
+  }
+
+  public ShuffleSinkNode(
+      PlanNodeId id,
+      List<PlanNode> children,
+      List<DownStreamChannelLocation> downStreamChannelLocationList) {
+    super(id, children, downStreamChannelLocationList);
+  }
+
+  /**
+   * Returns a copy of this node without children. The list of downstream channel locations is
+   * copied, so adding a location to the clone does not also add it to this node; the location
+   * objects themselves are shared.
+   */
+  @Override
+  public PlanNode clone() {
+    return new ShuffleSinkNode(
+        getPlanNodeId(), new ArrayList<>(getDownStreamChannelLocationList()));
+  }
+
+  /** Returns the output column names of all children, concatenated in child order. */
+  @Override
+  public List<String> getOutputColumnNames() {
+    return children.stream()
+        .map(PlanNode::getOutputColumnNames)
+        .flatMap(List::stream)
+        .collect(Collectors.toList());
+  }
+
+  @Override
+  public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
+    return visitor.visitShuffleSink(this, context);
+  }
+
+  // Writes the node type, the number of downstream locations, then each location. The
+  // PlanNodeId is not written here; deserialize() expects it right after the locations.
+  @Override
+  protected void serializeAttributes(ByteBuffer byteBuffer) {
+    PlanNodeType.SHUFFLE_SINK.serialize(byteBuffer);
+    ReadWriteIOUtils.write(downStreamChannelLocationList.size(), byteBuffer);
+    for (DownStreamChannelLocation downStreamChannelLocation : downStreamChannelLocationList) {
+      downStreamChannelLocation.serialize(byteBuffer);
+    }
+  }
+
+  // Same layout as serializeAttributes(ByteBuffer), written to a stream.
+  @Override
+  protected void serializeAttributes(DataOutputStream stream) throws IOException {
+    PlanNodeType.SHUFFLE_SINK.serialize(stream);
+    ReadWriteIOUtils.write(downStreamChannelLocationList.size(), stream);
+    for (DownStreamChannelLocation downStreamChannelLocation : downStreamChannelLocationList) {
+      downStreamChannelLocation.serialize(stream);
+    }
+  }
+
+  /**
+   * Rebuilds a node from the bytes after its type tag: the location count, each location, and
+   * finally the PlanNodeId. Children are not read here.
+   */
+  public static ShuffleSinkNode deserialize(ByteBuffer byteBuffer) {
+    int size = ReadWriteIOUtils.readInt(byteBuffer);
+    List<DownStreamChannelLocation> downStreamChannelLocationList = new ArrayList<>();
+    for (int i = 0; i < size; i++) {
+      downStreamChannelLocationList.add(DownStreamChannelLocation.deserialize(byteBuffer));
+    }
+    PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
+    return new ShuffleSinkNode(planNodeId, downStreamChannelLocationList);
+  }
+}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/DataDriverTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/DataDriverTest.java
index de1846f338..2c801a9982 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/DataDriverTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/DataDriverTest.java
@@ -32,7 +32,7 @@ import org.apache.iotdb.db.mpp.common.QueryId;
 import org.apache.iotdb.db.mpp.execution.driver.DataDriver;
 import org.apache.iotdb.db.mpp.execution.driver.DataDriverContext;
 import org.apache.iotdb.db.mpp.execution.driver.IDriver;
-import org.apache.iotdb.db.mpp.execution.exchange.StubSinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.StubSink;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceState;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceStateMachine;
@@ -174,8 +174,8 @@ public class DataDriverTest {
           .thenReturn(new QueryDataSource(seqResources, unSeqResources));
       fragmentInstanceContext.initQueryDataSource(driverContext.getPaths());
 
-      StubSinkHandle sinkHandle = new StubSinkHandle(fragmentInstanceContext);
-      driverContext.setSinkHandle(sinkHandle);
+      StubSink stubSink = new StubSink(fragmentInstanceContext);
+      driverContext.setSink(stubSink);
       IDriver dataDriver = null;
       try {
         dataDriver = new DataDriver(limitOperator, driverContext);
@@ -192,7 +192,7 @@ public class DataDriverTest {
 
         assertEquals(FragmentInstanceState.FLUSHING, stateMachine.getState());
 
-        List<TsBlock> result = sinkHandle.getTsBlocks();
+        List<TsBlock> result = stubSink.getTsBlocks();
 
         int row = 0;
         for (TsBlock tsBlock : result) {
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkChannelTest.java
similarity index 64%
rename from server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java
rename to server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkChannelTest.java
index 8bcb1b300b..dac0ec75d8 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkChannelTest.java
@@ -20,7 +20,9 @@
 package org.apache.iotdb.db.mpp.execution.exchange;
 
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkListener;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.LocalSinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.source.LocalSourceHandle;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
@@ -30,7 +32,7 @@ import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-public class LocalSinkHandleTest {
+public class LocalSinkChannelTest {
   @Test
   public void testSend() {
     final String queryId = "q0";
@@ -44,16 +46,15 @@ public class LocalSinkHandleTest {
     MemoryPool spyMemoryPool =
         Mockito.spy(new MemoryPool("test", 10 * mockTsBlockSize, 5 * mockTsBlockSize));
     Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(spyMemoryPool);
-    // Construct a mock SinkHandleListener.
-    SinkHandleListener mockSinkHandleListener =
-        Mockito.mock(MPPDataExchangeManager.SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct a shared TsBlock queue.
     SharedTsBlockQueue queue =
         new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, mockLocalMemoryManager);
 
-    // Construct SinkHandle.
-    LocalSinkHandle localSinkHandle =
-        new LocalSinkHandle(localFragmentInstanceId, queue, mockSinkHandleListener);
+    // Construct Sink.
+    LocalSinkChannel localSinkChannel =
+        new LocalSinkChannel(localFragmentInstanceId, queue, mockSinkListener);
 
     queue.setMaxBytesCanReserve(Long.MAX_VALUE);
 
@@ -65,25 +66,25 @@ public class LocalSinkHandleTest {
             queue,
             Mockito.mock(MPPDataExchangeManager.SourceHandleListener.class));
 
-    Assert.assertFalse(localSinkHandle.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFull().isDone());
     localSourceHandle.isBlocked();
-    // blocked of LocalSinkHandle should be completed after calling isBlocked() of corresponding
+    // blocked of LocalSinkChannel should be completed after calling isBlocked() of corresponding
     // LocalSourceHandle
-    Assert.assertTrue(localSinkHandle.isFull().isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertFalse(localSinkHandle.isAborted());
-    Assert.assertEquals(0L, localSinkHandle.getBufferRetainedSizeInBytes());
+    Assert.assertTrue(localSinkChannel.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertFalse(localSinkChannel.isAborted());
+    Assert.assertEquals(0L, localSinkChannel.getBufferRetainedSizeInBytes());
 
     // Send TsBlocks.
     int numOfSentTsblocks = 0;
-    while (localSinkHandle.isFull().isDone()) {
-      localSinkHandle.send(Utils.createMockTsBlock(mockTsBlockSize));
+    while (localSinkChannel.isFull().isDone()) {
+      localSinkChannel.send(Utils.createMockTsBlock(mockTsBlockSize));
       numOfSentTsblocks += 1;
     }
     Assert.assertEquals(11, numOfSentTsblocks);
-    Assert.assertFalse(localSinkHandle.isFull().isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertEquals(11 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
+    Assert.assertFalse(localSinkChannel.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertEquals(11 * mockTsBlockSize, localSinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(spyMemoryPool, Mockito.times(11))
         .reserve(
             queryId,
@@ -100,9 +101,9 @@ public class LocalSinkHandleTest {
       numOfReceivedTsblocks += 1;
     }
     Assert.assertEquals(11, numOfReceivedTsblocks);
-    Assert.assertTrue(localSinkHandle.isFull().isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertEquals(0L, localSinkHandle.getBufferRetainedSizeInBytes());
+    Assert.assertTrue(localSinkChannel.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertEquals(0L, localSinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(spyMemoryPool, Mockito.times(11))
         .free(
             queryId,
@@ -112,11 +113,11 @@ public class LocalSinkHandleTest {
             mockTsBlockSize);
 
     // Set no-more-TsBlocks.
-    localSinkHandle.setNoMoreTsBlocks();
-    Assert.assertTrue(localSinkHandle.isFull().isDone());
-    Assert.assertTrue(localSinkHandle.isFinished());
-    Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onEndOfBlocks(localSinkHandle);
-    Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onFinish(localSinkHandle);
+    localSinkChannel.setNoMoreTsBlocks();
+    Assert.assertTrue(localSinkChannel.isFull().isDone());
+    Assert.assertTrue(localSinkChannel.isFinished());
+    Mockito.verify(mockSinkListener, Mockito.times(1)).onEndOfBlocks(localSinkChannel);
+    Mockito.verify(mockSinkListener, Mockito.times(1)).onFinish(localSinkChannel);
   }
 
   @Test
@@ -132,16 +133,15 @@ public class LocalSinkHandleTest {
     MemoryPool spyMemoryPool =
         Mockito.spy(new MemoryPool("test", 10 * mockTsBlockSize, 5 * mockTsBlockSize));
     Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(spyMemoryPool);
-    // Construct a mock SinkHandleListener.
-    MPPDataExchangeManager.SinkHandleListener mockSinkHandleListener =
-        Mockito.mock(MPPDataExchangeManager.SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct a shared tsblock queue.
     SharedTsBlockQueue queue =
         new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, mockLocalMemoryManager);
 
-    // Construct SinkHandle.
-    LocalSinkHandle localSinkHandle =
-        new LocalSinkHandle(localFragmentInstanceId, queue, mockSinkHandleListener);
+    // Construct SinkChannel.
+    LocalSinkChannel localSinkChannel =
+        new LocalSinkChannel(localFragmentInstanceId, queue, mockSinkListener);
 
     queue.setMaxBytesCanReserve(Long.MAX_VALUE);
 
@@ -153,26 +153,26 @@ public class LocalSinkHandleTest {
             queue,
             Mockito.mock(MPPDataExchangeManager.SourceHandleListener.class));
 
-    Assert.assertFalse(localSinkHandle.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFull().isDone());
     localSourceHandle.isBlocked();
-    // blocked of LocalSinkHandle should be completed after calling isBlocked() of corresponding
+    // blocked of LocalSinkChannel should be completed after calling isBlocked() of corresponding
     // LocalSourceHandle
-    Assert.assertTrue(localSinkHandle.isFull().isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertFalse(localSinkHandle.isAborted());
-    Assert.assertEquals(0L, localSinkHandle.getBufferRetainedSizeInBytes());
+    Assert.assertTrue(localSinkChannel.isFull().isDone());
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertFalse(localSinkChannel.isAborted());
+    Assert.assertEquals(0L, localSinkChannel.getBufferRetainedSizeInBytes());
 
     // Send TsBlocks.
     int numOfSentTsblocks = 0;
-    while (localSinkHandle.isFull().isDone()) {
-      localSinkHandle.send(Utils.createMockTsBlock(mockTsBlockSize));
+    while (localSinkChannel.isFull().isDone()) {
+      localSinkChannel.send(Utils.createMockTsBlock(mockTsBlockSize));
       numOfSentTsblocks += 1;
     }
     Assert.assertEquals(11, numOfSentTsblocks);
-    ListenableFuture<?> blocked = localSinkHandle.isFull();
+    ListenableFuture<?> blocked = localSinkChannel.isFull();
     Assert.assertFalse(blocked.isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertEquals(11 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertEquals(11 * mockTsBlockSize, localSinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(spyMemoryPool, Mockito.times(11))
         .reserve(
             queryId,
@@ -183,10 +183,10 @@ public class LocalSinkHandleTest {
             Long.MAX_VALUE);
 
     // Abort.
-    localSinkHandle.abort();
+    localSinkChannel.abort();
     Assert.assertTrue(blocked.isDone());
-    Assert.assertFalse(localSinkHandle.isFinished());
-    Assert.assertTrue(localSinkHandle.isAborted());
-    Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onAborted(localSinkHandle);
+    Assert.assertFalse(localSinkChannel.isFinished());
+    Assert.assertTrue(localSinkChannel.isAborted());
+    Mockito.verify(mockSinkListener, Mockito.times(1)).onAborted(localSinkChannel);
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
index 674a80ce3e..aa15f19967 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
@@ -20,6 +20,7 @@
 package org.apache.iotdb.db.mpp.execution.exchange;
 
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SourceHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.source.LocalSourceHandle;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManagerTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManagerTest.java
index 786d5b1830..f66baa559d 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManagerTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManagerTest.java
@@ -23,6 +23,14 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.client.ClientPoolFactory;
 import org.apache.iotdb.commons.client.IClientManager;
 import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceClient;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelIndex;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.LocalSinkChannel;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ShuffleSinkHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.ISourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.source.LocalSourceHandle;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
@@ -32,6 +40,7 @@ import org.junit.Assert;
 import org.junit.Test;
 import org.mockito.Mockito;
 
+import java.util.Collections;
 import java.util.concurrent.Executors;
 
 public class MPPDataExchangeManagerTest {
@@ -40,6 +49,7 @@ public class MPPDataExchangeManagerTest {
     final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId("q0", 1, "0");
     final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId("q0", 0, "0");
     final String remotePlanNodeId = "exchange_0";
+    final String localPlanNodeId = "shuffleSink_0";
     final FragmentInstanceContext mockFragmentInstanceContext =
         Mockito.mock(FragmentInstanceContext.class);
 
@@ -57,23 +67,36 @@ public class MPPDataExchangeManagerTest {
                 .createClientManager(
                     new ClientPoolFactory.SyncDataNodeMPPDataExchangeServiceClientPoolFactory()));
 
-    ISinkHandle localSinkHandle =
-        mppDataExchangeManager.createLocalSinkHandleForFragment(
+    ISinkHandle shuffleSinkHandle =
+        mppDataExchangeManager.createShuffleSinkHandle(
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint(
+                        IoTDBDescriptor.getInstance().getConfig().getInternalAddress(),
+                        IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort()),
+                    remoteFragmentInstanceId,
+                    remotePlanNodeId)),
+            new DownStreamChannelIndex(0),
+            ShuffleSinkHandle.ShuffleStrategyEnum.PLAIN,
             localFragmentInstanceId,
-            remoteFragmentInstanceId,
-            remotePlanNodeId,
+            localPlanNodeId,
             mockFragmentInstanceContext);
 
-    Assert.assertTrue(localSinkHandle instanceof LocalSinkHandle);
+    Assert.assertTrue(shuffleSinkHandle instanceof ShuffleSinkHandle);
 
     ISourceHandle localSourceHandle =
         mppDataExchangeManager.createLocalSourceHandleForFragment(
-            remoteFragmentInstanceId, remotePlanNodeId, localFragmentInstanceId, t -> {});
+            remoteFragmentInstanceId,
+            remotePlanNodeId,
+            localPlanNodeId,
+            localFragmentInstanceId,
+            0,
+            t -> {});
 
     Assert.assertTrue(localSourceHandle instanceof LocalSourceHandle);
 
     Assert.assertEquals(
-        ((LocalSinkHandle) localSinkHandle).getSharedTsBlockQueue(),
+        ((LocalSinkChannel) shuffleSinkHandle.getChannel(0)).getSharedTsBlockQueue(),
         ((LocalSourceHandle) localSourceHandle).getSharedTsBlockQueue());
   }
 
@@ -81,7 +104,8 @@ public class MPPDataExchangeManagerTest {
   public void testCreateLocalSourceHandle() {
     final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId("q0", 1, "0");
     final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId("q0", 0, "0");
-    final String localPlanNodeId = "exchange_0";
+    final String remotePlanNodeId = "exchange_0";
+    final String localPlanNodeId = "shuffleSink_0";
     final FragmentInstanceContext mockFragmentInstanceContext =
         Mockito.mock(FragmentInstanceContext.class);
 
@@ -101,21 +125,34 @@ public class MPPDataExchangeManagerTest {
 
     ISourceHandle localSourceHandle =
         mppDataExchangeManager.createLocalSourceHandleForFragment(
-            localFragmentInstanceId, localPlanNodeId, remoteFragmentInstanceId, t -> {});
+            remoteFragmentInstanceId,
+            remotePlanNodeId,
+            localPlanNodeId,
+            localFragmentInstanceId,
+            0,
+            t -> {});
 
     Assert.assertTrue(localSourceHandle instanceof LocalSourceHandle);
 
-    ISinkHandle localSinkHandle =
-        mppDataExchangeManager.createLocalSinkHandleForFragment(
-            remoteFragmentInstanceId,
+    ISinkHandle shuffleSinkHandle =
+        mppDataExchangeManager.createShuffleSinkHandle(
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint(
+                        IoTDBDescriptor.getInstance().getConfig().getInternalAddress(),
+                        IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort()),
+                    remoteFragmentInstanceId,
+                    remotePlanNodeId)),
+            new DownStreamChannelIndex(0),
+            ShuffleSinkHandle.ShuffleStrategyEnum.PLAIN,
             localFragmentInstanceId,
             localPlanNodeId,
             mockFragmentInstanceContext);
 
-    Assert.assertTrue(localSinkHandle instanceof LocalSinkHandle);
+    Assert.assertTrue(shuffleSinkHandle instanceof ShuffleSinkHandle);
 
     Assert.assertEquals(
-        ((LocalSinkHandle) localSinkHandle).getSharedTsBlockQueue(),
+        ((LocalSinkChannel) shuffleSinkHandle.getChannel(0)).getSharedTsBlockQueue(),
         ((LocalSourceHandle) localSourceHandle).getSharedTsBlockQueue());
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkChannelTest.java
similarity index 73%
rename from server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java
rename to server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkChannelTest.java
index d50e69dd2a..a1aff6d8e0 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkChannelTest.java
@@ -25,7 +25,8 @@ import org.apache.iotdb.commons.client.exception.ClientManagerException;
 import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceClient;
 import org.apache.iotdb.db.conf.IoTDBDescriptor;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkListener;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.SinkChannel;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
 import org.apache.iotdb.mpp.rpc.thrift.TEndOfDataBlockEvent;
@@ -45,7 +46,7 @@ import java.util.concurrent.Future;
 
 import static org.apache.iotdb.tsfile.read.common.block.TsBlockBuilderStatus.DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
 
-public class SinkHandleTest {
+public class SinkChannelTest {
 
   @Test
   public void testOneTimeNotBlockedSend() {
@@ -80,14 +81,14 @@ public class SinkHandleTest {
       e.printStackTrace();
       Assert.fail();
     }
-    // Construct a mock SinkHandleListener.
-    SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct several mock TsBlock(s).
     List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
 
-    // Construct SinkHandle.
-    SinkHandle sinkHandle =
-        new SinkHandle(
+    // Construct SinkChannel.
+    SinkChannel sinkChannel =
+        new SinkChannel(
             remoteEndpoint,
             remoteFragmentInstanceId,
             remotePlanNodeId,
@@ -96,24 +97,25 @@ public class SinkHandleTest {
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             Utils.createMockTsBlockSerde(mockTsBlockSize),
-            mockSinkHandleListener,
+            mockSinkListener,
             mockClientManager);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.open();
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(0, sinkChannel.getNumOfBufferedTsBlocks());
 
     // Send tsblocks.
-    sinkHandle.send(mockTsBlocks.get(0));
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.send(mockTsBlocks.get(0));
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
         mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-        sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+        sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(numOfMockTsBlock, sinkChannel.getNumOfBufferedTsBlocks());
     //    Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
     //        .reserve(
     //            queryId,
@@ -140,29 +142,28 @@ public class SinkHandleTest {
     // Get tsblocks.
     for (int i = 0; i < numOfMockTsBlock; i++) {
       try {
-        sinkHandle.getSerializedTsBlock(i);
+        sinkChannel.getSerializedTsBlock(i);
       } catch (IOException e) {
         e.printStackTrace();
         Assert.fail();
       }
-      Assert.assertTrue(sinkHandle.isFull().isDone());
+      Assert.assertTrue(sinkChannel.isFull().isDone());
     }
-    Assert.assertFalse(sinkHandle.isFinished());
+    Assert.assertFalse(sinkChannel.isFinished());
 
     // Set no-more-tsblocks.
-    sinkHandle.setNoMoreTsBlocks();
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_000).times(1))
-        .onEndOfBlocks(sinkHandle);
+    sinkChannel.setNoMoreTsBlocks();
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_000).times(1)).onEndOfBlocks(sinkChannel);
 
     // Ack tsblocks.
-    sinkHandle.acknowledgeTsBlock(0, numOfMockTsBlock);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertTrue(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
-    Assert.assertEquals(mockTsBlockSize, sinkHandle.getBufferRetainedSizeInBytes());
+    sinkChannel.acknowledgeTsBlock(0, numOfMockTsBlock);
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertTrue(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
+    Assert.assertEquals(mockTsBlockSize, sinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
         .free(
             queryId,
@@ -170,7 +171,7 @@ public class SinkHandleTest {
                 localFragmentInstanceId),
             localPlanNodeId,
             numOfMockTsBlock * mockTsBlockSize);
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkHandle);
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkChannel);
 
     try {
       Mockito.verify(mockClient, Mockito.timeout(10_000).times(1))
@@ -211,8 +212,8 @@ public class SinkHandleTest {
             mockTsBlockSize);
     Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
 
-    // Construct a mock SinkHandleListener.
-    SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct several mock TsBlock(s).
     List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
     IClientManager<TEndPoint, SyncDataNodeMPPDataExchangeServiceClient> mockClientManager =
@@ -233,9 +234,9 @@ public class SinkHandleTest {
       Assert.fail();
     }
 
-    // Construct SinkHandle.
-    SinkHandle sinkHandle =
-        new SinkHandle(
+    // Construct SinkChannel.
+    SinkChannel sinkChannel =
+        new SinkChannel(
             remoteEndpoint,
             remoteFragmentInstanceId,
             remotePlanNodeId,
@@ -244,24 +245,25 @@ public class SinkHandleTest {
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             Utils.createMockTsBlockSerde(mockTsBlockSize),
-            mockSinkHandleListener,
+            mockSinkListener,
             mockClientManager);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.open();
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(0, sinkChannel.getNumOfBufferedTsBlocks());
 
     // Send tsblocks.
-    sinkHandle.send(mockTsBlocks.get(0));
-    Assert.assertFalse(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.send(mockTsBlocks.get(0));
+    Assert.assertFalse(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
         mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-        sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+        sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(numOfMockTsBlock, sinkChannel.getNumOfBufferedTsBlocks());
     //    Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
     //        .reserve(
     //            queryId,
@@ -288,22 +290,22 @@ public class SinkHandleTest {
     // Get tsblocks.
     for (int i = 0; i < numOfMockTsBlock; i++) {
       try {
-        sinkHandle.getSerializedTsBlock(i);
+        sinkChannel.getSerializedTsBlock(i);
       } catch (IOException e) {
         e.printStackTrace();
         Assert.fail();
       }
-      Assert.assertFalse(sinkHandle.isFull().isDone());
+      Assert.assertFalse(sinkChannel.isFull().isDone());
     }
-    Assert.assertFalse(sinkHandle.isFinished());
+    Assert.assertFalse(sinkChannel.isFinished());
 
     // Ack tsblocks.
-    sinkHandle.acknowledgeTsBlock(0, numOfMockTsBlock);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.acknowledgeTsBlock(0, numOfMockTsBlock);
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
         .free(
             queryId,
@@ -313,14 +315,14 @@ public class SinkHandleTest {
             numOfMockTsBlock * mockTsBlockSize);
 
     // Send tsblocks.
-    sinkHandle.send(mockTsBlocks.get(0));
-    Assert.assertFalse(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.send(mockTsBlocks.get(0));
+    Assert.assertFalse(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
         mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-        sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+        sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(numOfMockTsBlock, sinkChannel.getNumOfBufferedTsBlocks());
     //    Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(3))
     //        .reserve(
     //            queryId,
@@ -345,11 +347,10 @@ public class SinkHandleTest {
     }
 
     // Set no-more-tsblocks.
-    sinkHandle.setNoMoreTsBlocks();
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_000).times(1))
-        .onEndOfBlocks(sinkHandle);
+    sinkChannel.setNoMoreTsBlocks();
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_000).times(1)).onEndOfBlocks(sinkChannel);
 
     try {
       Mockito.verify(mockClient, Mockito.timeout(10_000).times(1))
@@ -368,20 +369,20 @@ public class SinkHandleTest {
     // Get tsblocks after no-more-tsblocks is set.
     for (int i = numOfMockTsBlock; i < numOfMockTsBlock * 2; i++) {
       try {
-        sinkHandle.getSerializedTsBlock(i);
+        sinkChannel.getSerializedTsBlock(i);
       } catch (IOException e) {
         e.printStackTrace();
         Assert.fail();
       }
     }
-    Assert.assertFalse(sinkHandle.isFinished());
+    Assert.assertFalse(sinkChannel.isFinished());
 
     // Ack tsblocks.
-    sinkHandle.acknowledgeTsBlock(numOfMockTsBlock, numOfMockTsBlock * 2);
-    Assert.assertTrue(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.acknowledgeTsBlock(numOfMockTsBlock, numOfMockTsBlock * 2);
+    Assert.assertTrue(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
     Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(2))
         .free(
             queryId,
@@ -389,7 +390,7 @@ public class SinkHandleTest {
                 localFragmentInstanceId),
             localPlanNodeId,
             numOfMockTsBlock * mockTsBlockSize);
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkHandle);
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkChannel);
   }
 
   @Test
@@ -415,8 +416,8 @@ public class SinkHandleTest {
             numOfMockTsBlock,
             mockTsBlockSize);
     Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
-    // Construct a mock SinkHandleListener.
-    SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct several mock TsBlock(s).
     List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
     IClientManager<TEndPoint, SyncDataNodeMPPDataExchangeServiceClient> mockClientManager =
@@ -438,9 +439,9 @@ public class SinkHandleTest {
       Assert.fail();
     }
 
-    // Construct SinkHandle.
-    SinkHandle sinkHandle =
-        new SinkHandle(
+    // Construct SinkChannel.
+    SinkChannel sinkChannel =
+        new SinkChannel(
             remoteEndpoint,
             remoteFragmentInstanceId,
             remotePlanNodeId,
@@ -449,25 +450,26 @@ public class SinkHandleTest {
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             Utils.createMockTsBlockSerde(mockTsBlockSize),
-            mockSinkHandleListener,
+            mockSinkListener,
             mockClientManager);
-    sinkHandle.setRetryIntervalInMs(0L);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.setRetryIntervalInMs(0L);
+    sinkChannel.open();
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(0, sinkChannel.getNumOfBufferedTsBlocks());
 
     // Send tsblocks.
-    sinkHandle.send(mockTsBlocks.get(0));
-    Assert.assertFalse(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.send(mockTsBlocks.get(0));
+    Assert.assertFalse(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
         mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-        sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+        sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(numOfMockTsBlock, sinkChannel.getNumOfBufferedTsBlocks());
     //    Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
     //        .reserve(
     //            queryId,
@@ -477,7 +479,7 @@ public class SinkHandleTest {
     //            mockTsBlockSize * numOfMockTsBlock,
     //            Long.MAX_VALUE);
     try {
-      Mockito.verify(mockClient, Mockito.timeout(10_000).times(SinkHandle.MAX_ATTEMPT_TIMES))
+      Mockito.verify(mockClient, Mockito.timeout(10_000).times(SinkChannel.MAX_ATTEMPT_TIMES))
           .onNewDataBlockEvent(
               Mockito.argThat(
                   e ->
@@ -490,20 +492,19 @@ public class SinkHandleTest {
       e.printStackTrace();
       Assert.fail();
     }
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_000).times(1))
-        .onFailure(sinkHandle, mockException);
-
-    // Close the SinkHandle.
-    sinkHandle.setNoMoreTsBlocks();
-    Assert.assertFalse(sinkHandle.isAborted());
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_000).times(0))
-        .onEndOfBlocks(sinkHandle);
-
-    // Abort the SinkHandle.
-    sinkHandle.abort();
-    Assert.assertTrue(sinkHandle.isAborted());
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onAborted(sinkHandle);
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(0)).onFinish(sinkHandle);
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_000).times(1))
+        .onFailure(sinkChannel, mockException);
+
+    // Close the SinkChannel.
+    sinkChannel.setNoMoreTsBlocks();
+    Assert.assertFalse(sinkChannel.isAborted());
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_000).times(0)).onEndOfBlocks(sinkChannel);
+
+    // Abort the SinkChannel.
+    sinkChannel.abort();
+    Assert.assertTrue(sinkChannel.isAborted());
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_0000).times(1)).onAborted(sinkChannel);
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_0000).times(0)).onFinish(sinkChannel);
   }
 
   @Test
@@ -526,8 +527,8 @@ public class SinkHandleTest {
                 "test", numOfMockTsBlock * mockTsBlockSize, numOfMockTsBlock * mockTsBlockSize));
     Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(spyMemoryPool);
 
-    // Construct a mock SinkHandleListener.
-    SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+    // Construct a mock SinkListener.
+    SinkListener mockSinkListener = Mockito.mock(SinkListener.class);
     // Construct several mock TsBlock(s).
     List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
     IClientManager<TEndPoint, SyncDataNodeMPPDataExchangeServiceClient> mockClientManager =
@@ -548,9 +549,9 @@ public class SinkHandleTest {
       Assert.fail();
     }
 
-    // Construct SinkHandle.
-    SinkHandle sinkHandle =
-        new SinkHandle(
+    // Construct SinkChannel.
+    SinkChannel sinkChannel =
+        new SinkChannel(
             remoteEndpoint,
             remoteFragmentInstanceId,
             remotePlanNodeId,
@@ -559,36 +560,37 @@ public class SinkHandleTest {
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             Utils.createMockTsBlockSerde(mockTsBlockSize),
-            mockSinkHandleListener,
+            mockSinkListener,
             mockClientManager);
-    sinkHandle.setMaxBytesCanReserve(Long.MAX_VALUE);
-    Assert.assertTrue(sinkHandle.isFull().isDone());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    sinkChannel.setMaxBytesCanReserve(Long.MAX_VALUE);
+    sinkChannel.open();
+    Assert.assertTrue(sinkChannel.isFull().isDone());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
-        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+        DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(0, sinkChannel.getNumOfBufferedTsBlocks());
 
     // Send tsblocks.
-    sinkHandle.send(mockTsBlocks.get(0));
-    Future<?> blocked = sinkHandle.isFull();
+    sinkChannel.send(mockTsBlocks.get(0));
+    Future<?> blocked = sinkChannel.isFull();
     Assert.assertFalse(blocked.isDone());
     Assert.assertFalse(blocked.isCancelled());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertFalse(sinkHandle.isAborted());
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertFalse(sinkChannel.isAborted());
     Assert.assertEquals(
         mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
-        sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+        sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(numOfMockTsBlock, sinkChannel.getNumOfBufferedTsBlocks());
 
-    sinkHandle.abort();
+    sinkChannel.abort();
     Assert.assertTrue(blocked.isDone());
     Assert.assertTrue(blocked.isCancelled());
-    Assert.assertFalse(sinkHandle.isFinished());
-    Assert.assertTrue(sinkHandle.isAborted());
-    Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
-    Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
-    Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onAborted(sinkHandle);
+    Assert.assertFalse(sinkChannel.isFinished());
+    Assert.assertTrue(sinkChannel.isAborted());
+    Assert.assertEquals(0L, sinkChannel.getBufferRetainedSizeInBytes());
+    Assert.assertEquals(0, sinkChannel.getNumOfBufferedTsBlocks());
+    Mockito.verify(mockSinkListener, Mockito.timeout(10_0000).times(1)).onAborted(sinkChannel);
     Assert.assertEquals(0L, spyMemoryPool.getQueryMemoryReservedBytes(queryId));
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
index d0b8f49f29..8d6e7cd7d5 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
@@ -26,6 +26,7 @@ import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceCl
 import org.apache.iotdb.db.conf.IoTDBDescriptor;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SourceHandleListener;
+import org.apache.iotdb.db.mpp.execution.exchange.source.SourceHandle;
 import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
 import org.apache.iotdb.mpp.rpc.thrift.TAcknowledgeDataBlockEvent;
@@ -113,6 +114,7 @@ public class SourceHandleTest {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            0,
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             mockTsBlockSerde,
@@ -228,6 +230,7 @@ public class SourceHandleTest {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            0,
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             mockTsBlockSerde,
@@ -388,6 +391,7 @@ public class SourceHandleTest {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            0,
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             mockTsBlockSerde,
@@ -560,6 +564,7 @@ public class SourceHandleTest {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            0,
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             mockTsBlockSerde,
@@ -647,6 +652,7 @@ public class SourceHandleTest {
             remoteFragmentInstanceId,
             localFragmentInstanceId,
             localPlanNodeId,
+            0,
             mockLocalMemoryManager,
             Executors.newSingleThreadExecutor(),
             mockTsBlockSerde,
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSink.java
similarity index 91%
rename from server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java
rename to server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSink.java
index 0e7b5ffb96..2ccc9bf8ee 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSink.java
@@ -18,6 +18,7 @@
  */
 package org.apache.iotdb.db.mpp.execution.exchange;
 
+import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
 import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
@@ -29,7 +30,7 @@ import java.util.List;
 
 import static com.google.common.util.concurrent.Futures.immediateVoidFuture;
 
-public class StubSinkHandle implements ISinkHandle {
+public class StubSink implements ISink {
 
   private final ListenableFuture<Void> NOT_BLOCKED = immediateVoidFuture();
 
@@ -39,7 +40,7 @@ public class StubSinkHandle implements ISinkHandle {
 
   private boolean closed = false;
 
-  public StubSinkHandle(FragmentInstanceContext instanceContext) {
+  public StubSink(FragmentInstanceContext instanceContext) {
     this.instanceContext = instanceContext;
   }
 
@@ -63,11 +64,6 @@ public class StubSinkHandle implements ISinkHandle {
     this.tsBlocks.add(tsBlock);
   }
 
-  @Override
-  public void send(int partition, List<TsBlock> tsBlocks) {
-    this.tsBlocks.addAll(tsBlocks);
-  }
-
   @Override
   public void setNoMoreTsBlocks() {
     if (closed) {
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
index 3e6a6db0d0..add8f7f53a 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AggregationDistributionTest.java
@@ -37,13 +37,14 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.HorizontallyConcatNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MergeSortNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationSourceNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesSourceNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationDescriptor;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.AggregationStep;
 import org.apache.iotdb.db.mpp.plan.planner.plan.parameter.CrossSeriesAggregationDescriptor;
@@ -775,13 +776,13 @@ public class AggregationDistributionTest {
         plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0);
     PlanNode f3Root =
         plan.getInstances().get(2).getFragment().getPlanNodeTree().getChildren().get(0);
-    assertTrue(f1Root instanceof MergeSortNode);
+    assertTrue(f1Root instanceof DeviceViewNode);
     assertTrue(f2Root instanceof HorizontallyConcatNode);
-    assertTrue(f3Root instanceof DeviceViewNode);
-    assertTrue(f3Root.getChildren().get(0) instanceof HorizontallyConcatNode);
-    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
-    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
-    assertEquals(3, f1Root.getChildren().get(0).getChildren().get(0).getChildren().size());
+    assertTrue(f3Root instanceof HorizontallyConcatNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof AggregationNode);
+    assertTrue(f1Root.getChildren().get(1) instanceof ExchangeNode);
+    assertEquals(3, f1Root.getChildren().get(0).getChildren().size());
   }
 
   @Test
@@ -795,7 +796,7 @@ public class AggregationDistributionTest {
     DistributionPlanner planner =
         new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
     DistributedQueryPlan plan = planner.planFragments();
-    assertEquals(3, plan.getInstances().size());
+    assertEquals(2, plan.getInstances().size());
     PlanNode f1Root =
         plan.getInstances().get(0).getFragment().getPlanNodeTree().getChildren().get(0);
     PlanNode f2Root =
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AlignedByDeviceTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AlignedByDeviceTest.java
index 5d1002e505..0085facb7a 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AlignedByDeviceTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/distribution/AlignedByDeviceTest.java
@@ -27,24 +27,1332 @@ import org.apache.iotdb.db.mpp.plan.planner.distribution.DistributionPlanner;
 import org.apache.iotdb.db.mpp.plan.planner.plan.DistributedQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.FilterNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.HorizontallyConcatNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MergeSortNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SingleDeviceViewNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesSourceNode;
 
 import org.junit.Test;
 
-import java.util.List;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
 public class AlignedByDeviceTest {
+  @Test
+  public void testAggregation2Device2Region() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql = "select count(s1) from root.sg.d333,root.sg.d4444 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql = "select count(s1),count(s2) from root.sg.d333,root.sg.d4444 align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof HorizontallyConcatNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof HorizontallyConcatNode);
+  }
+
+  @Test
+  public void testAggregation2Device2RegionWithValueFilter() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql = "select count(s1) from root.sg.d333,root.sg.d4444 where s1 <= 4 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql =
+        "select count(s1),count(s2) from root.sg.d333,root.sg.d4444 where s1 <= 4 align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof TimeJoinNode);
+  }
+
+  @Test
+  public void testAggregation2Device2RegionOrderByTime() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql = "select count(s1) from root.sg.d333,root.sg.d4444 order by time align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql =
+        "select count(s1),count(s2) from root.sg.d333,root.sg.d4444 order by time align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof HorizontallyConcatNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof HorizontallyConcatNode);
+  }
+
+  @Test
+  public void testAggregation2Device2RegionWithValueFilterOrderByTime() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql =
+        "select count(s1) from root.sg.d333,root.sg.d4444 where s1 <= 4 order by time align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql =
+        "select count(s1),count(s2) from root.sg.d333,root.sg.d4444 where s1 <= 4 order by time align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f2Root.getChildren().get(1) instanceof TimeJoinNode);
+  }
+
+  /**
+   * Verifies the distributed plan of an aggregation over two devices with no value filter and no
+   * ordering. Three fragment instances are expected: the root fragment aggregates each device under
+   * a DeviceViewNode, and the other two fragments are rooted at IdentitySinkNodes.
+   */
+  @Test
+  public void testAggregation2Device3Region() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+
+    // SINGLE_SERIES: each AggregationNode reads one local series scan and one exchange.
+    String sql = "select count(s1) from root.sg.d1,root.sg.d333 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    PlanNode deviceView = f1Root.getChildren().get(0);
+    assertTrue(deviceView instanceof DeviceViewNode);
+    for (int device = 0; device < 2; device++) {
+      PlanNode aggregation = deviceView.getChildren().get(device);
+      assertTrue(aggregation instanceof AggregationNode);
+      assertTrue(aggregation.getChildren().get(0) instanceof SeriesSourceNode);
+      assertTrue(aggregation.getChildren().get(1) instanceof ExchangeNode);
+    }
+    // The two non-root fragments have the same shape; check them in order.
+    for (PlanNode sinkRoot : new PlanNode[] {f2Root, f3Root}) {
+      assertTrue(sinkRoot instanceof IdentitySinkNode);
+      assertTrue(sinkRoot.getChildren().get(0) instanceof SeriesSourceNode);
+    }
+
+    // MULTI_SERIES: each AggregationNode reads two local series scans and one exchange.
+    sql = "select count(s1),count(s2) from root.sg.d1,root.sg.d333 align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    deviceView = f1Root.getChildren().get(0);
+    assertTrue(deviceView instanceof DeviceViewNode);
+    for (int device = 0; device < 2; device++) {
+      PlanNode aggregation = deviceView.getChildren().get(device);
+      assertTrue(aggregation instanceof AggregationNode);
+      assertTrue(aggregation.getChildren().get(0) instanceof SeriesSourceNode);
+      assertTrue(aggregation.getChildren().get(1) instanceof SeriesSourceNode);
+      assertTrue(aggregation.getChildren().get(2) instanceof ExchangeNode);
+    }
+    for (PlanNode sinkRoot : new PlanNode[] {f2Root, f3Root}) {
+      assertTrue(sinkRoot instanceof IdentitySinkNode);
+      assertTrue(sinkRoot.getChildren().get(0) instanceof HorizontallyConcatNode);
+    }
+  }
 
   @Test
-  public void test1Device1Region() {}
+  public void testAggregation2Device3RegionWithValueFilter() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql = "select count(s1) from root.sg.d1,root.sg.d333 where s1 <= 4 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f3Root instanceof IdentitySinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql = "select count(s1),count(s2) from root.sg.d1,root.sg.d333 where s1 <= 4 align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f3Root instanceof IdentitySinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof TimeJoinNode);
+  }
+
+  @Test
+  public void testAggregation2Device3RegionOrderByTime() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql = "select count(s1) from root.sg.d1,root.sg.d333 order by time align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f3Root instanceof ShuffleSinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql = "select count(s1),count(s2) from root.sg.d1,root.sg.d333 order by time align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof HorizontallyConcatNode);
+    assertTrue(f3Root instanceof ShuffleSinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof HorizontallyConcatNode);
+  }
+
+  @Test
+  public void testAggregation2Device3RegionWithValueFilterOrderByTime() {
+    QueryId queryId = new QueryId("test");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    // test of SINGLE_SERIES
+    String sql =
+        "select count(s1) from root.sg.d1,root.sg.d333 where s1 <= 4 order by time align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof SeriesSourceNode);
+    assertTrue(f3Root instanceof ShuffleSinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof SeriesSourceNode);
+
+    // test of MULTI_SERIES
+    sql =
+        "select count(s1),count(s2) from root.sg.d1,root.sg.d333 where s1 <= 4 order by time align by device";
+    analysis = Util.analyze(sql, context);
+    logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    planner = new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    plan = planner.planFragments();
+    assertEquals(3, plan.getInstances().size());
+    f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(1) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0)
+            instanceof AggregationNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(1).getChildren().get(0).getChildren().get(0)
+            instanceof FilterNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof TimeJoinNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+            instanceof SeriesSourceNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(1)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f3Root instanceof ShuffleSinkNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof TimeJoinNode);
+  }
+
+  @Test
+  public void testDiffFunction2Device2Region() {
+    QueryId queryId = new QueryId("test_special_process_align_by_device_2_device_2_region");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    String sql = "select diff(s1), diff(s2) from root.sg.d333,root.sg.d4444 align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof TransformNode);
+    assertTrue(
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(2)
+            instanceof ExchangeNode);
+  }
+
+  @Test
+  public void testDiffFunctionWithOrderByTime2Device2Region() {
+    QueryId queryId =
+        new QueryId("test_special_process_align_by_device_with_order_by_time_2_device_2_region");
+    MPPQueryContext context =
+        new MPPQueryContext("", queryId, null, new TEndPoint(), new TEndPoint());
+    String sql =
+        "select diff(s1), diff(s2) from root.sg.d333,root.sg.d4444 order by time align by device";
+    Analysis analysis = Util.analyze(sql, context);
+    PlanNode logicalPlanNode = Util.genLogicalPlan(analysis, context);
+    DistributionPlanner planner =
+        new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
+    DistributedQueryPlan plan = planner.planFragments();
+    assertEquals(2, plan.getInstances().size());
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
+            instanceof ExchangeNode);
+  }
 
   @Test
   public void testDiffFunction2Device3Region() {
@@ -59,11 +1367,15 @@ public class AlignedByDeviceTest {
     //                                 /                      \
     //                           TransformNode                 Exchange
     //                                |                             |
+    //                                |                        IdentityNode
+    //                                |                             |
     //                           TimeJoinNode                  TransformNode
     //                           /     |      \                     |
     //                     d1.s1[1]  d1.s2[1]  Exchange         TimeJoinNode
     //                                            |               /      \
-    //                                        TimeJoinNode  d22.s1[3]   d22.s2[3]
+    //                                        IdentityNode  d22.s1[3]   d22.s2[3]
+    //                                            |
+    //                                        TimeJoinNode
     //                                        /      \
     //                                  d1.s1[2]    d1.s2[2]
     // ------------------------------------------------------------------------------------------------
@@ -74,20 +1386,20 @@ public class AlignedByDeviceTest {
         new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
     DistributedQueryPlan plan = planner.planFragments();
     assertEquals(3, plan.getInstances().size());
-    PlanNode f1Root =
-        plan.getInstances().get(0).getFragment().getPlanNodeTree().getChildren().get(0);
-    PlanNode f2Root =
-        plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0);
-    PlanNode f3Root =
-        plan.getInstances().get(2).getFragment().getPlanNodeTree().getChildren().get(0);
-    assertTrue(f1Root instanceof DeviceViewNode);
-    assertTrue(f2Root instanceof TimeJoinNode);
-    assertTrue(f3Root instanceof TransformNode);
-    assertTrue(f1Root.getChildren().get(0) instanceof TransformNode);
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f2Root instanceof IdentitySinkNode);
+    assertTrue(f3Root instanceof IdentitySinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof DeviceViewNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof TransformNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof TransformNode);
     assertTrue(
-        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(2)
+        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(2)
             instanceof ExchangeNode);
-    assertTrue(f3Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f3Root.getChildren().get(0).getChildren().get(0) instanceof TimeJoinNode);
   }
 
   @Test
@@ -104,13 +1416,17 @@ public class AlignedByDeviceTest {
     //                                 /                           \
     //                         SingleDeviceViewNode             Exchange
     //                                |                             |
+    //                                |                        ShuffleSinkNode
+    //                                |                             |
     //                           TransformNode              SingleDeviceViewNode
     //                                |                             |
     //                           TimeJoinNode                  TransformNode
     //                           /     |      \                     |
     //                     d1.s1[1]  d1.s2[1]  Exchange         TimeJoinNode
     //                                            |               /      \
-    //                                       TimeJoinNode  d22.s1[3]   d22.s2[3]
+    //                                        IdentityNode  d22.s1[3]   d22.s2[3]
+    //                                            |
+    //                                      ShuffleSinkNode
     //                                        /      \
     //                                  d1.s1[2]    d1.s2[2]
     // ------------------------------------------------------------------------------------------------
@@ -122,23 +1438,31 @@ public class AlignedByDeviceTest {
         new DistributionPlanner(analysis, new LogicalQueryPlan(context, logicalPlanNode));
     DistributedQueryPlan plan = planner.planFragments();
     assertEquals(3, plan.getInstances().size());
-    PlanNode f1Root =
-        plan.getInstances().get(0).getFragment().getPlanNodeTree().getChildren().get(0);
-    PlanNode f2Root =
-        plan.getInstances().get(1).getFragment().getPlanNodeTree().getChildren().get(0);
-    PlanNode f3Root =
-        plan.getInstances().get(2).getFragment().getPlanNodeTree().getChildren().get(0);
-    assertTrue(f1Root instanceof MergeSortNode);
-    assertTrue(f2Root instanceof TimeJoinNode);
-    assertTrue(f3Root instanceof SingleDeviceViewNode);
-    assertTrue(f1Root.getChildren().get(0) instanceof SingleDeviceViewNode);
+    PlanNode f1Root = plan.getInstances().get(0).getFragment().getPlanNodeTree();
+    PlanNode f2Root = plan.getInstances().get(1).getFragment().getPlanNodeTree();
+    PlanNode f3Root = plan.getInstances().get(2).getFragment().getPlanNodeTree();
+    assertTrue(f1Root instanceof IdentitySinkNode);
+    assertTrue(f2Root instanceof ShuffleSinkNode);
+    assertTrue(f3Root instanceof ShuffleSinkNode);
+    assertTrue(f1Root.getChildren().get(0) instanceof MergeSortNode);
+    assertTrue(f2Root.getChildren().get(0) instanceof TimeJoinNode);
+    assertTrue(f3Root.getChildren().get(0) instanceof SingleDeviceViewNode);
+    assertTrue(f1Root.getChildren().get(0).getChildren().get(0) instanceof SingleDeviceViewNode);
     assertTrue(
-        f1Root.getChildren().get(0).getChildren().get(0).getChildren().get(0).getChildren().get(2)
+        f1Root
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(0)
+                .getChildren()
+                .get(2)
             instanceof ExchangeNode);
-    assertTrue(f3Root.getChildren().get(0).getChildren().get(0) instanceof TimeJoinNode);
-  }
-
-  private LogicalQueryPlan constructLogicalPlan(List<String> series) {
-    return null;
+    assertTrue(
+        f3Root.getChildren().get(0).getChildren().get(0).getChildren().get(0)
+            instanceof TimeJoinNode);
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/DeviceSchemaScanNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/DeviceSchemaScanNodeSerdeTest.java
index a2afa4e721..a9fa05e23d 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/DeviceSchemaScanNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/DeviceSchemaScanNodeSerdeTest.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesSchemaScanNode;
@@ -30,12 +31,13 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryM
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LimitNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.OffsetNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 public class DeviceSchemaScanNodeSerdeTest {
 
@@ -53,14 +55,18 @@ public class DeviceSchemaScanNodeSerdeTest {
             10,
             false,
             false);
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(new PlanNodeId("fragmentSink"));
-    fragmentSinkNode.addChild(devicesSchemaScanNode);
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6667),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(devicesSchemaScanNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+    devicesSchemaScanNode.addChild(sinkNode);
     exchangeNode.addChild(schemaMergeNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6667),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/NodeManagementMemoryMergeNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/NodeManagementMemoryMergeNodeSerdeTest.java
index c885919c74..2a598b1f64 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/NodeManagementMemoryMergeNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/NodeManagementMemoryMergeNodeSerdeTest.java
@@ -26,6 +26,7 @@ import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.metadata.mnode.MNodeType;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodeManagementMemoryMergeNode;
@@ -34,12 +35,13 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodePathsCou
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodePathsSchemaScanNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
@@ -94,14 +96,18 @@ public class NodeManagementMemoryMergeNodeSerdeTest {
     NodePathsSchemaScanNode childPathsSchemaScanNode =
         new NodePathsSchemaScanNode(
             new PlanNodeId("NodePathsScan"), new PartialPath("root.ln"), -1);
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(new PlanNodeId("fragmentSink"));
-    fragmentSinkNode.addChild(childPathsSchemaScanNode);
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6667),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-    exchangeNode.addChild(schemaMergeNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(childPathsSchemaScanNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+    schemaMergeNode.addChild(exchangeNode);
+    exchangeNode.addChild(sinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6667),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/SchemaCountNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/SchemaCountNodeSerdeTest.java
index d79d959850..2b7b9e94e6 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/SchemaCountNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/SchemaCountNodeSerdeTest.java
@@ -23,18 +23,20 @@ import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.CountSchemaMergeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesCountNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.LevelTimeSeriesCountNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 public class SchemaCountNodeSerdeTest {
 
@@ -45,14 +47,18 @@ public class SchemaCountNodeSerdeTest {
     DevicesCountNode devicesCountNode =
         new DevicesCountNode(
             new PlanNodeId("devicesCount"), new PartialPath("root.sg.device0"), true);
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(new PlanNodeId("fragmentSink"));
-    fragmentSinkNode.addChild(devicesCountNode);
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6667),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-    exchangeNode.addChild(countMergeNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(devicesCountNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+    countMergeNode.addChild(sinkNode);
+    exchangeNode.addChild(sinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6667),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
@@ -77,14 +83,18 @@ public class SchemaCountNodeSerdeTest {
             null,
             null,
             false);
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(new PlanNodeId("fragmentSink"));
-    fragmentSinkNode.addChild(levelTimeSeriesCountNode);
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6667),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-    exchangeNode.addChild(countMergeNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(levelTimeSeriesCountNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+    countMergeNode.addChild(exchangeNode);
+    exchangeNode.addChild(sinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6667),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/TimeSeriesSchemaScanNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/TimeSeriesSchemaScanNodeSerdeTest.java
index 564ba64916..c8815caeb4 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/TimeSeriesSchemaScanNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/metadata/read/TimeSeriesSchemaScanNodeSerdeTest.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryMergeNode;
@@ -30,12 +31,13 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.TimeSeriesSc
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LimitNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.OffsetNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 public class TimeSeriesSchemaScanNodeSerdeTest {
 
@@ -56,14 +58,18 @@ public class TimeSeriesSchemaScanNodeSerdeTest {
             false,
             false,
             false);
-    FragmentSinkNode fragmentSinkNode = new FragmentSinkNode(new PlanNodeId("fragmentSink"));
-    fragmentSinkNode.addChild(timeSeriesSchemaScanNode);
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6667),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-    exchangeNode.addChild(schemaMergeNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(timeSeriesSchemaScanNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+    schemaMergeNode.addChild(exchangeNode);
+    exchangeNode.addChild(sinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6667),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/ExchangeNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/ExchangeNodeSerdeTest.java
index 5b6a411386..94a1c55e13 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/ExchangeNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/process/ExchangeNodeSerdeTest.java
@@ -22,16 +22,18 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
 import org.apache.iotdb.db.mpp.plan.statement.component.Ordering;
 
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 import static org.junit.Assert.assertEquals;
 
@@ -42,14 +44,18 @@ public class ExchangeNodeSerdeTest {
     TimeJoinNode timeJoinNode = new TimeJoinNode(new PlanNodeId("TestTimeJoinNode"), Ordering.ASC);
 
     ExchangeNode exchangeNode = new ExchangeNode(new PlanNodeId("TestExchangeNode"));
-    FragmentSinkNode fragmentSinkNode =
-        new FragmentSinkNode(new PlanNodeId("TestFragmentSinkNode"));
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6666),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-    fragmentSinkNode.addChild(timeJoinNode);
-    exchangeNode.setRemoteSourceNode(fragmentSinkNode);
+    IdentitySinkNode sinkNode =
+        new IdentitySinkNode(
+            new PlanNodeId("sink"),
+            Collections.singletonList(timeJoinNode),
+            Collections.singletonList(
+                new DownStreamChannelLocation(
+                    new TEndPoint("127.0.0.1", 6667),
+                    new FragmentInstanceId(new PlanFragmentId("q", 1), "ds").toThrift(),
+                    new PlanNodeId("test").getId())));
+
+    exchangeNode.addChild(sinkNode);
+    exchangeNode.setOutputColumnNames(exchangeNode.getChild().getOutputColumnNames());
     exchangeNode.setUpstream(
         new TEndPoint("127.0.0.1", 6666),
         new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/IdentitySinkNodeSerdeTest.java
similarity index 54%
copy from server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java
copy to server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/IdentitySinkNodeSerdeTest.java
index 9d16d18065..da3882dc30 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/IdentitySinkNodeSerdeTest.java
@@ -16,46 +16,45 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.iotdb.db.mpp.plan.plan.node.sink;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.exception.IllegalPathException;
-import org.apache.iotdb.commons.path.PartialPath;
-import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesSchemaScanNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.IdentitySinkNode;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 import static org.junit.Assert.assertEquals;
 
-public class FragmentSinkNodeSerdeTest {
+public class IdentitySinkNodeSerdeTest {
 
   @Test
   public void testSerializeAndDeserialize() throws IllegalPathException {
-    FragmentSinkNode fragmentSinkNode =
-        new FragmentSinkNode(new PlanNodeId("TestFragmentSinkNode"));
-    fragmentSinkNode.addChild(
-        new DevicesSchemaScanNode(
-            new PlanNodeId("deviceSchema"),
-            new PartialPath("root.sg.device0"),
-            0,
-            0,
-            false,
-            false));
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6666),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-
+    DownStreamChannelLocation downStreamChannelLocation =
+        new DownStreamChannelLocation(
+            new TEndPoint("test", 1), new TFragmentInstanceId("test", 1, "test"), "test");
+    IdentitySinkNode identitySinkNode1 =
+        new IdentitySinkNode(
+            new PlanNodeId("testIdentitySinkNode"),
+            Collections.singletonList(downStreamChannelLocation));
     ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
-    fragmentSinkNode.serialize(byteBuffer);
+    identitySinkNode1.serialize(byteBuffer);
     byteBuffer.flip();
-    assertEquals(PlanNodeDeserializeHelper.deserialize(byteBuffer), fragmentSinkNode);
+    assertEquals(identitySinkNode1, PlanNodeDeserializeHelper.deserialize(byteBuffer));
+
+    IdentitySinkNode identitySinkNode2 =
+        new IdentitySinkNode(new PlanNodeId("testIdentitySinkNode"), Collections.emptyList());
+    ByteBuffer byteBuffer2 = ByteBuffer.allocate(1024);
+    identitySinkNode2.serialize(byteBuffer2);
+    byteBuffer2.flip();
+    assertEquals(identitySinkNode2, PlanNodeDeserializeHelper.deserialize(byteBuffer2));
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/ShuffleSinkHandleNodeSerdeTest.java
similarity index 57%
rename from server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java
rename to server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/ShuffleSinkHandleNodeSerdeTest.java
index 9d16d18065..39c6633f2c 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/FragmentSinkNodeSerdeTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/plan/plan/node/sink/ShuffleSinkHandleNodeSerdeTest.java
@@ -16,46 +16,47 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
 package org.apache.iotdb.db.mpp.plan.plan.node.sink;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.commons.exception.IllegalPathException;
-import org.apache.iotdb.commons.path.PartialPath;
-import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
-import org.apache.iotdb.db.mpp.common.PlanFragmentId;
+import org.apache.iotdb.db.mpp.execution.exchange.sink.DownStreamChannelLocation;
 import org.apache.iotdb.db.mpp.plan.plan.node.PlanNodeDeserializeHelper;
 import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesSchemaScanNode;
-import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.ShuffleSinkNode;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
 
 import org.junit.Test;
 
 import java.nio.ByteBuffer;
+import java.util.Collections;
 
 import static org.junit.Assert.assertEquals;
 
-public class FragmentSinkNodeSerdeTest {
-
+/** Round-trip serialization tests for {@link ShuffleSinkNode}. */
+public class ShuffleSinkHandleNodeSerdeTest {
   @Test
   public void testSerializeAndDeserialize() throws IllegalPathException {
-    FragmentSinkNode fragmentSinkNode =
-        new FragmentSinkNode(new PlanNodeId("TestFragmentSinkNode"));
-    fragmentSinkNode.addChild(
-        new DevicesSchemaScanNode(
-            new PlanNodeId("deviceSchema"),
-            new PartialPath("root.sg.device0"),
-            0,
-            0,
-            false,
-            false));
-    fragmentSinkNode.setDownStream(
-        new TEndPoint("127.0.0.1", 6666),
-        new FragmentInstanceId(new PlanFragmentId("q", 1), "ds"),
-        new PlanNodeId("test"));
-
+    // A node with one downstream channel must survive serialize -> deserialize unchanged.
+    DownStreamChannelLocation downStreamChannelLocation =
+        new DownStreamChannelLocation(
+            new TEndPoint("test", 1), new TFragmentInstanceId("test", 1, "test"), "test");
+    ShuffleSinkNode shuffleSinkNode1 =
+        new ShuffleSinkNode(
+            new PlanNodeId("testShuffleSinkNode"),
+            Collections.singletonList(downStreamChannelLocation));
     ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
-    fragmentSinkNode.serialize(byteBuffer);
+    shuffleSinkNode1.serialize(byteBuffer);
     byteBuffer.flip();
-    assertEquals(PlanNodeDeserializeHelper.deserialize(byteBuffer), fragmentSinkNode);
+    assertEquals(shuffleSinkNode1, PlanNodeDeserializeHelper.deserialize(byteBuffer));
+
+    // A node with no downstream channels must round-trip as well.
+    ShuffleSinkNode shuffleSinkNode2 =
+        new ShuffleSinkNode(new PlanNodeId("testShuffleSinkNode"), Collections.emptyList());
+    ByteBuffer byteBuffer2 = ByteBuffer.allocate(1024);
+    shuffleSinkNode2.serialize(byteBuffer2);
+    byteBuffer2.flip();
+    assertEquals(shuffleSinkNode2, PlanNodeDeserializeHelper.deserialize(byteBuffer2));
   }
 }
diff --git a/thrift/src/main/thrift/datanode.thrift b/thrift/src/main/thrift/datanode.thrift
index 81ff5c1116..bf379d922a 100644
--- a/thrift/src/main/thrift/datanode.thrift
+++ b/thrift/src/main/thrift/datanode.thrift
@@ -63,6 +63,8 @@ struct TGetDataBlockRequest {
   1: required TFragmentInstanceId sourceFragmentInstanceId
   2: required i32 startSequenceId
   3: required i32 endSequenceId
+  // index of upstream SinkHandle
+  4: required i32 index
 }
 
 struct TGetDataBlockResponse {
@@ -73,6 +75,8 @@ struct TAcknowledgeDataBlockEvent {
   1: required TFragmentInstanceId sourceFragmentInstanceId
   2: required i32 startSequenceId
   3: required i32 endSequenceId
+  // index of upstream SinkHandle
+  4: required i32 index
 }
 
 struct TNewDataBlockEvent {