You are viewing a plain-text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2023/03/29 05:47:07 UTC

[iotdb] branch master updated: [IOTDB-5586] Reduce the scope of lock in MemoryPool Version2

This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 94620929bb [IOTDB-5586] Reduce the scope of lock in MemoryPool Version2
94620929bb is described below

commit 94620929bb7e619c1613d4a17039eec2031fca32
Author: Xiangwei Wei <34...@users.noreply.github.com>
AuthorDate: Wed Mar 29 13:46:59 2023 +0800

    [IOTDB-5586] Reduce the scope of lock in MemoryPool Version2
---
 .../db/exception/runtime/MemoryLeakException.java  |  27 +++
 .../iotdb/db/mpp/common/FragmentInstanceId.java    |   4 +
 .../execution/exchange/MPPDataExchangeManager.java |   6 +
 .../mpp/execution/exchange/SharedTsBlockQueue.java |  16 +-
 .../mpp/execution/exchange/sink/SinkChannel.java   |  12 +-
 .../execution/exchange/source/SourceHandle.java    |  12 +-
 .../fragment/FragmentInstanceExecution.java        |   5 +
 .../iotdb/db/mpp/execution/memory/MemoryPool.java  | 250 +++++++++++----------
 .../db/mpp/execution/memory/MemoryPoolTest.java    |  27 ++-
 9 files changed, 200 insertions(+), 159 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java b/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java
new file mode 100644
index 0000000000..beb9a83951
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.exception.runtime;
+
+public class MemoryLeakException extends RuntimeException {
+
+  public MemoryLeakException(String message) {
+    super(message);
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
index 793066b94a..24164c4e89 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
@@ -57,6 +57,10 @@ public class FragmentInstanceId {
     return instanceId;
   }
 
+  public String getFragmentInstanceId() {
+    return fragmentId + "." + instanceId;
+  }
+
   public String toString() {
     return fullId;
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
index 501781b408..9a58904c79 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
@@ -470,6 +470,12 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
     return mppDataExchangeService;
   }
 
+  public void deRegisterFragmentInstanceFromMemoryPool(String queryId, String fragmentInstanceId) {
+    localMemoryManager
+        .getQueryPool()
+        .deRegisterFragmentInstanceToQueryMemoryMap(queryId, fragmentInstanceId);
+  }
+
   private synchronized ISinkChannel createLocalSinkChannel(
       TFragmentInstanceId localFragmentInstanceId,
       TFragmentInstanceId remoteFragmentInstanceId,
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
index 55b7462922..34858dfaa4 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
@@ -91,6 +91,10 @@ public class SharedTsBlockQueue {
     this.localPlanNodeId = Validate.notNull(planNodeId, "PlanNode ID cannot be null");
     this.localMemoryManager =
         Validate.notNull(localMemoryManager, "local memory manager cannot be null");
+    localMemoryManager
+        .getQueryPool()
+        .registerPlanNodeIdToQueryMemoryMap(
+            fragmentInstanceId.queryId, fullFragmentInstanceId, planNodeId);
   }
 
   public boolean hasNoMoreTsBlocks() {
@@ -260,10 +264,6 @@ public class SharedTsBlockQueue {
               bufferRetainedSizeInBytes);
       bufferRetainedSizeInBytes = 0;
     }
-    localMemoryManager
-        .getQueryPool()
-        .clearMemoryReservationMap(
-            localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
   }
 
   /** Destroy the queue and cancel the future. Should only be called in abnormal case */
@@ -289,10 +289,6 @@ public class SharedTsBlockQueue {
               bufferRetainedSizeInBytes);
       bufferRetainedSizeInBytes = 0;
     }
-    localMemoryManager
-        .getQueryPool()
-        .clearMemoryReservationMap(
-            localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
   }
 
   /** Destroy the queue and cancel the future. Should only be called in abnormal case */
@@ -318,9 +314,5 @@ public class SharedTsBlockQueue {
               bufferRetainedSizeInBytes);
       bufferRetainedSizeInBytes = 0;
     }
-    localMemoryManager
-        .getQueryPool()
-        .clearMemoryReservationMap(
-            localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
   }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
index 5cd28de462..b32028bee1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
@@ -140,6 +140,10 @@ public class SinkChannel implements ISinkChannel {
             localFragmentInstanceId.instanceId);
     this.bufferRetainedSizeInBytes = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
     this.currentTsBlockSize = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
+    localMemoryManager
+        .getQueryPool()
+        .registerPlanNodeIdToQueryMemoryMap(
+            localFragmentInstanceId.queryId, fullFragmentInstanceId, localPlanNodeId);
   }
 
   @Override
@@ -222,10 +226,6 @@ public class SinkChannel implements ISinkChannel {
               bufferRetainedSizeInBytes);
       bufferRetainedSizeInBytes = 0;
     }
-    localMemoryManager
-        .getQueryPool()
-        .clearMemoryReservationMap(
-            localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
     sinkListener.onAborted(this);
     aborted = true;
     LOGGER.debug("[EndAbortSinkChannel]");
@@ -249,10 +249,6 @@ public class SinkChannel implements ISinkChannel {
               bufferRetainedSizeInBytes);
       bufferRetainedSizeInBytes = 0;
     }
-    localMemoryManager
-        .getQueryPool()
-        .clearMemoryReservationMap(
-            localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
     sinkListener.onFinish(this);
     closed = true;
     LOGGER.debug("[EndCloseSinkChannel]");
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
index b66dd4bce4..341b7e8321 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
@@ -142,6 +142,10 @@ public class SourceHandle implements ISourceHandle {
     this.mppDataExchangeServiceClientManager = mppDataExchangeServiceClientManager;
     this.retryIntervalInMs = DEFAULT_RETRY_INTERVAL_IN_MS;
     this.threadName = createFullIdFrom(localFragmentInstanceId, localPlanNodeId);
+    localMemoryManager
+        .getQueryPool()
+        .registerPlanNodeIdToQueryMemoryMap(
+            localFragmentInstanceId.queryId, fullFragmentInstanceId, localPlanNodeId);
   }
 
   @Override
@@ -332,10 +336,6 @@ public class SourceHandle implements ISourceHandle {
                 bufferRetainedSizeInBytes);
         bufferRetainedSizeInBytes = 0;
       }
-      localMemoryManager
-          .getQueryPool()
-          .clearMemoryReservationMap(
-              localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
       aborted = true;
       sourceHandleListener.onAborted(this);
     }
@@ -369,10 +369,6 @@ public class SourceHandle implements ISourceHandle {
                 bufferRetainedSizeInBytes);
         bufferRetainedSizeInBytes = 0;
       }
-      localMemoryManager
-          .getQueryPool()
-          .clearMemoryReservationMap(
-              localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
       closed = true;
       executorService.submit(new SendCloseSinkChannelEventTask());
       currSequenceId = lastSequenceId + 1;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
index f9c17c9963..b402f0892e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
@@ -20,6 +20,7 @@ package org.apache.iotdb.db.mpp.execution.fragment;
 
 import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
 import org.apache.iotdb.db.mpp.execution.driver.IDriver;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeService;
 import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
 import org.apache.iotdb.db.mpp.execution.schedule.IDriverScheduler;
 import org.apache.iotdb.db.utils.SetThreadName;
@@ -137,6 +138,10 @@ public class FragmentInstanceExecution {
             context.releaseResource();
             // help for gc
             drivers = null;
+            MPPDataExchangeService.getInstance()
+                .getMPPDataExchangeManager()
+                .deRegisterFragmentInstanceFromMemoryPool(
+                    instanceId.getQueryId().getId(), instanceId.getFragmentInstanceId());
             if (newState.isFailed()) {
               scheduler.abortFragmentInstance(instanceId);
             }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 3c9dbb10e3..f72d935514 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -20,6 +20,7 @@
 package org.apache.iotdb.db.mpp.execution.memory;
 
 import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.exception.runtime.MemoryLeakException;
 import org.apache.iotdb.tsfile.utils.Pair;
 
 import com.google.common.util.concurrent.AbstractFuture;
@@ -32,13 +33,13 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.Nullable;
 
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
 
 /** A thread-safe memory pool. */
 public class MemoryPool {
@@ -57,6 +58,8 @@ public class MemoryPool {
      */
     private final long maxBytesCanReserve;
 
+    private boolean isMarked = false;
+
     private MemoryReservationFuture(
         String queryId,
         String fragmentInstanceId,
@@ -77,6 +80,14 @@ public class MemoryPool {
       return queryId;
     }
 
+    public boolean isMarked() {
+      return isMarked;
+    }
+
+    public void setMarked(boolean marked) {
+      isMarked = marked;
+    }
+
     public String getFragmentInstanceId() {
       return fragmentInstanceId;
     }
@@ -113,12 +124,13 @@ public class MemoryPool {
   private final long maxBytes;
   private final long maxBytesPerFragmentInstance;
 
-  private long reservedBytes = 0L;
+  private final AtomicLong remainingBytes;
   /** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
   private final Map<String, Map<String, Map<String, Long>>> queryMemoryReservations =
-      new HashMap<>();
+      new ConcurrentHashMap<>();
 
-  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures = new LinkedList<>();
+  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures =
+      new ConcurrentLinkedQueue<>();
 
   public MemoryPool(String id, long maxBytes, long maxBytesPerFragmentInstance) {
     this.id = Validate.notNull(id);
@@ -130,14 +142,51 @@ public class MemoryPool {
         maxBytesPerFragmentInstance,
         maxBytes);
     this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
+    this.remainingBytes = new AtomicLong(maxBytes);
   }
 
   public String getId() {
     return id;
   }
 
-  public long getMaxBytes() {
-    return maxBytes;
+  /**
+   * Before executing, we register memory map which is related to queryId, fragmentInstanceId, and
+   * planNodeId to queryMemoryReservationsMap first.
+   */
+  public void registerPlanNodeIdToQueryMemoryMap(
+      String queryId, String fragmentInstanceId, String planNodeId) {
+    synchronized (queryMemoryReservations) {
+      queryMemoryReservations
+          .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
+          .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
+          .putIfAbsent(planNodeId, 0L);
+    }
+  }
+
+  /**
+   * If all fragmentInstanceIds related to one queryId have been registered, when the last fragment
+   * instance is deregister, the queryId can be cleared.
+   *
+   * <p>If some fragmentInstanceIds have not been registered when queryId is cleared, they will
+   * register queryId again with lock, so there is no concurrency problem.
+   */
+  public void deRegisterFragmentInstanceToQueryMemoryMap(
+      String queryId, String fragmentInstanceId) {
+    Map<String, Long> planNodeRelatedMemory =
+        queryMemoryReservations.get(queryId).get(fragmentInstanceId);
+    for (Long memoryReserved : planNodeRelatedMemory.values()) {
+      if (memoryReserved != 0) {
+        throw new MemoryLeakException(
+            "PlanNode related memory is not zero when deregister fragment instance from query memory pool.");
+      }
+    }
+    synchronized (queryMemoryReservations) {
+      Map<String, Map<String, Long>> queryRelatedMemory = queryMemoryReservations.get(queryId);
+      queryRelatedMemory.remove(fragmentInstanceId);
+      if (queryRelatedMemory.isEmpty()) {
+        queryMemoryReservations.remove(queryId);
+      }
+    }
   }
 
   /**
@@ -169,37 +218,23 @@ public class MemoryPool {
     }
 
     ListenableFuture<Void> result;
-    synchronized (this) {
-      if (maxBytes - reservedBytes < bytesToReserve
-          || maxBytesCanReserve
-                  - queryMemoryReservations
-                      .getOrDefault(queryId, Collections.emptyMap())
-                      .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-                      .getOrDefault(planNodeId, 0L)
-              < bytesToReserve) {
-        LOGGER.debug(
-            "Blocked reserve request: {} bytes memory for planNodeId{}",
-            bytesToReserve,
-            planNodeId);
-        result =
-            MemoryReservationFuture.create(
-                queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
-        memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
-        return new Pair<>(result, Boolean.FALSE);
-      } else {
-        reservedBytes += bytesToReserve;
-        queryMemoryReservations
-            .computeIfAbsent(queryId, x -> new HashMap<>())
-            .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
-            .merge(planNodeId, bytesToReserve, Long::sum);
-        result = Futures.immediateFuture(null);
-        return new Pair<>(result, Boolean.TRUE);
-      }
+    if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve)) {
+      result = Futures.immediateFuture(null);
+      return new Pair<>(result, Boolean.TRUE);
+    } else {
+      LOGGER.debug(
+          "Blocked reserve request: {} bytes memory for planNodeId{}", bytesToReserve, planNodeId);
+      rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
+      result =
+          MemoryReservationFuture.create(
+              queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
+      memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+      return new Pair<>(result, Boolean.FALSE);
     }
   }
 
   @TestOnly
-  public boolean tryReserve(
+  public boolean tryReserveForTest(
       String queryId,
       String fragmentInstanceId,
       String planNodeId,
@@ -213,32 +248,12 @@ public class MemoryPool {
         "bytes should be greater than zero while less than or equal to max bytes per fragment instance: %d",
         bytesToReserve);
 
-    if (maxBytes - reservedBytes < bytesToReserve
-        || maxBytesCanReserve
-                - queryMemoryReservations
-                    .getOrDefault(queryId, Collections.emptyMap())
-                    .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-                    .getOrDefault(planNodeId, 0L)
-            < bytesToReserve) {
+    if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve)) {
+      return true;
+    } else {
+      rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
       return false;
     }
-    synchronized (this) {
-      if (maxBytes - reservedBytes < bytesToReserve
-          || maxBytesCanReserve
-                  - queryMemoryReservations
-                      .getOrDefault(queryId, Collections.emptyMap())
-                      .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-                      .getOrDefault(planNodeId, 0L)
-              < bytesToReserve) {
-        return false;
-      }
-      reservedBytes += bytesToReserve;
-      queryMemoryReservations
-          .computeIfAbsent(queryId, x -> new HashMap<>())
-          .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
-          .merge(planNodeId, bytesToReserve, Long::sum);
-    }
-    return true;
   }
 
   /**
@@ -282,58 +297,50 @@ public class MemoryPool {
   }
 
   public void free(String queryId, String fragmentInstanceId, String planNodeId, long bytes) {
-    List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
-    synchronized (this) {
-      Validate.notNull(queryId);
-      Validate.isTrue(bytes > 0L);
-
-      Long queryReservedBytes =
-          queryMemoryReservations
-              .getOrDefault(queryId, Collections.emptyMap())
-              .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-              .get(planNodeId);
-      Validate.notNull(queryReservedBytes);
-      Validate.isTrue(bytes <= queryReservedBytes);
-
-      queryReservedBytes -= bytes;
+    Validate.notNull(queryId);
+    Validate.isTrue(bytes > 0L);
+
+    try {
       queryMemoryReservations
           .get(queryId)
           .get(fragmentInstanceId)
-          .put(planNodeId, queryReservedBytes);
+          .computeIfPresent(
+              planNodeId,
+              (k, reservedMemory) -> {
+                if (reservedMemory < bytes) {
+                  throw new IllegalArgumentException("Free more memory than has been reserved.");
+                }
+                return reservedMemory - bytes;
+              });
+    } catch (NullPointerException e) {
+      throw new IllegalArgumentException("RelatedMemoryReserved can't be null when freeing memory");
+    }
 
-      reservedBytes -= bytes;
+    remainingBytes.addAndGet(bytes);
 
-      if (memoryReservationFutures.isEmpty()) {
-        return;
-      }
-      Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
-      while (iterator.hasNext()) {
-        MemoryReservationFuture<Void> future = iterator.next();
-        if (future.isCancelled() || future.isDone()) {
+    List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
+    if (memoryReservationFutures.isEmpty()) {
+      return;
+    }
+    Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
+    while (iterator.hasNext()) {
+      MemoryReservationFuture<Void> future = iterator.next();
+      synchronized (future) {
+        if (future.isCancelled() || future.isDone() || future.isMarked()) {
           continue;
         }
         long bytesToReserve = future.getBytesToReserve();
         String curQueryId = future.getQueryId();
         String curFragmentInstanceId = future.getFragmentInstanceId();
         String curPlanNodeId = future.getPlanNodeId();
-        // check total reserved bytes in memory pool
-        if (maxBytes - reservedBytes < bytesToReserve) {
-          continue;
-        }
-        // check total reserved bytes of one Sink/Source handle
-        if (future.getMaxBytesCanReserve()
-                - queryMemoryReservations
-                    .getOrDefault(curQueryId, Collections.emptyMap())
-                    .getOrDefault(curFragmentInstanceId, Collections.emptyMap())
-                    .getOrDefault(curPlanNodeId, 0L)
-            >= bytesToReserve) {
-          reservedBytes += bytesToReserve;
-          queryMemoryReservations
-              .computeIfAbsent(curQueryId, x -> new HashMap<>())
-              .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
-              .merge(curPlanNodeId, bytesToReserve, Long::sum);
+        long maxBytesCanReserve = future.getMaxBytesCanReserve();
+        if (tryReserve(
+            curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve, maxBytesCanReserve)) {
           futureList.add(future);
+          future.setMarked(true);
           iterator.remove();
+        } else {
+          rollbackReserve(curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve);
         }
       }
     }
@@ -342,10 +349,10 @@ public class MemoryPool {
     // If we put this block inside the MemoryPool's lock, we will get deadlock case like the
     // following:
     // Assuming that thread-A: LocalSourceHandle.receive() -> A-SharedTsBlockQueue.remove() ->
-    // MemoryPool.free() (hold MemoryPool's lock) -> future.set(null) -> try to get
+    // MemoryPool.free() (hold future's lock) -> future.set(null) -> try to get
     // B-SharedTsBlockQueue's lock
     // thread-B: LocalSourceHandle.receive() -> B-SharedTsBlockQueue.remove() (hold
-    // B-SharedTsBlockQueue's lock) -> try to get MemoryPool's lock
+    // B-SharedTsBlockQueue's lock) -> try to get future's lock
     for (MemoryReservationFuture<Void> future : futureList) {
       try {
         future.set(null);
@@ -368,26 +375,31 @@ public class MemoryPool {
   }
 
   public long getReservedBytes() {
-    return reservedBytes;
+    return maxBytes - remainingBytes.get();
   }
 
-  public synchronized void clearMemoryReservationMap(
-      String queryId, String fragmentInstanceId, String planNodeId) {
-    if (queryMemoryReservations.get(queryId) == null
-        || queryMemoryReservations.get(queryId).get(fragmentInstanceId) == null) {
-      return;
-    }
-    Map<String, Long> planNodeIdToBytesReserved =
-        queryMemoryReservations.get(queryId).get(fragmentInstanceId);
-    if (planNodeIdToBytesReserved.get(planNodeId) == null
-        || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
-      planNodeIdToBytesReserved.remove(planNodeId);
-      if (planNodeIdToBytesReserved.isEmpty()) {
-        queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
-      }
-      if (queryMemoryReservations.get(queryId).isEmpty()) {
-        queryMemoryReservations.remove(queryId);
-      }
-    }
+  public boolean tryReserve(
+      String queryId,
+      String fragmentInstanceId,
+      String planNodeId,
+      long bytesToReserve,
+      long maxBytesCanReserve) {
+    long tryRemainingBytes = remainingBytes.addAndGet(-bytesToReserve);
+    long queryRemainingBytes =
+        maxBytesCanReserve
+            - queryMemoryReservations
+                .get(queryId)
+                .get(fragmentInstanceId)
+                .merge(planNodeId, bytesToReserve, Long::sum);
+    return tryRemainingBytes >= 0 && queryRemainingBytes >= 0;
+  }
+
+  private void rollbackReserve(
+      String queryId, String fragmentInstanceId, String planNodeId, long bytesToReserve) {
+    queryMemoryReservations
+        .get(queryId)
+        .get(fragmentInstanceId)
+        .merge(planNodeId, -bytesToReserve, Long::sum);
+    remainingBytes.addAndGet(bytesToReserve);
   }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index 94a7c81999..a9b924c418 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -36,13 +36,14 @@ public class MemoryPoolTest {
   @Before
   public void before() {
     pool = new MemoryPool("test", 1024L, 512L);
+    pool.registerPlanNodeIdToQueryMemoryMap(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID);
   }
 
   @Test
   public void testTryReserve() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
   }
@@ -51,7 +52,7 @@ public class MemoryPoolTest {
   public void testTryReserveZero() {
 
     try {
-      pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
+      pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
       Assert.fail("Expect IllegalArgumentException");
     } catch (IllegalArgumentException ignore) {
     }
@@ -61,7 +62,7 @@ public class MemoryPoolTest {
   public void testTryReserveNegative() {
 
     try {
-      pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
+      pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
       Assert.fail("Expect IllegalArgumentException");
     } catch (IllegalArgumentException ignore) {
     }
@@ -71,7 +72,7 @@ public class MemoryPoolTest {
   public void testTryReserveAll() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
   }
@@ -79,10 +80,12 @@ public class MemoryPoolTest {
   @Test
   public void testOverTryReserve() {
 
-    Assert.assertTrue(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L));
+    Assert.assertTrue(
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
-    Assert.assertFalse(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 511L));
+    Assert.assertFalse(
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 511L));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
   }
@@ -206,7 +209,7 @@ public class MemoryPoolTest {
   public void testFree() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -219,7 +222,7 @@ public class MemoryPoolTest {
   public void testFreeAll() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -232,7 +235,7 @@ public class MemoryPoolTest {
   public void testFreeZero() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -247,7 +250,7 @@ public class MemoryPoolTest {
   public void testFreeNegative() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -262,7 +265,7 @@ public class MemoryPoolTest {
   public void testOverFree() {
 
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -278,7 +281,7 @@ public class MemoryPoolTest {
 
     // Run out of memory.
     Assert.assertTrue(
-        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
 
     ListenableFuture<Void> f =
         pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L).left;