You are viewing a plain-text version of this content. The canonical link for it is here [hyperlink not preserved in this plain-text copy].
Posted to commits@iotdb.apache.org by xi...@apache.org on 2023/02/27 06:10:02 UTC

[iotdb] 01/02: reduce the lock range in memory pool

This is an automated email from the ASF dual-hosted git repository.

xiangweiwei pushed a commit to branch memoryPoolLock
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 5b5506a8e4a9d7c66ee222e7f85ef9fe5e6b3ced
Author: Alima777 <wx...@gmail.com>
AuthorDate: Mon Feb 27 13:46:06 2023 +0800

    reduce the lock range in memory pool
---
 .../iotdb/db/mpp/execution/memory/MemoryPool.java  | 187 +++++++++++----------
 1 file changed, 94 insertions(+), 93 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 3c9dbb10e3..2ff322e148 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -25,20 +25,21 @@ import org.apache.iotdb.tsfile.utils.Pair;
 import com.google.common.util.concurrent.AbstractFuture;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
+import javax.annotation.Nullable;
 import org.apache.commons.lang3.Validate;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import javax.annotation.Nullable;
-
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Iterator;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.ReentrantLock;
 
 /** A thread-safe memory pool. */
 public class MemoryPool {
@@ -113,12 +114,14 @@ public class MemoryPool {
   private final long maxBytes;
   private final long maxBytesPerFragmentInstance;
 
-  private long reservedBytes = 0L;
+  private final ReentrantLock memoryReserveLock = new ReentrantLock(false);
+  private final AtomicLong remainingBytes;
   /** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
   private final Map<String, Map<String, Map<String, Long>>> queryMemoryReservations =
-      new HashMap<>();
+      new ConcurrentHashMap<>();
 
-  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures = new LinkedList<>();
+  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures =
+      new ConcurrentLinkedQueue<>();
 
   public MemoryPool(String id, long maxBytes, long maxBytesPerFragmentInstance) {
     this.id = Validate.notNull(id);
@@ -130,16 +133,13 @@ public class MemoryPool {
         maxBytesPerFragmentInstance,
         maxBytes);
     this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
+    this.remainingBytes = new AtomicLong(maxBytes);
   }
 
   public String getId() {
     return id;
   }
 
-  public long getMaxBytes() {
-    return maxBytes;
-  }
-
   /**
    * Reserve memory with bytesToReserve.
    *
@@ -169,8 +169,9 @@ public class MemoryPool {
     }
 
     ListenableFuture<Void> result;
-    synchronized (this) {
-      if (maxBytes - reservedBytes < bytesToReserve
+    memoryReserveLock.lock();
+    try {
+      if (remainingBytes.get() < bytesToReserve
           || maxBytesCanReserve
                   - queryMemoryReservations
                       .getOrDefault(queryId, Collections.emptyMap())
@@ -187,14 +188,16 @@ public class MemoryPool {
         memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
         return new Pair<>(result, Boolean.FALSE);
       } else {
-        reservedBytes += bytesToReserve;
+        remainingBytes.addAndGet(-bytesToReserve);
         queryMemoryReservations
-            .computeIfAbsent(queryId, x -> new HashMap<>())
-            .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+            .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
+            .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
             .merge(planNodeId, bytesToReserve, Long::sum);
         result = Futures.immediateFuture(null);
         return new Pair<>(result, Boolean.TRUE);
       }
+    } finally {
+      memoryReserveLock.unlock();
     }
   }
 
@@ -213,17 +216,9 @@ public class MemoryPool {
         "bytes should be greater than zero while less than or equal to max bytes per fragment instance: %d",
         bytesToReserve);
 
-    if (maxBytes - reservedBytes < bytesToReserve
-        || maxBytesCanReserve
-                - queryMemoryReservations
-                    .getOrDefault(queryId, Collections.emptyMap())
-                    .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-                    .getOrDefault(planNodeId, 0L)
-            < bytesToReserve) {
-      return false;
-    }
-    synchronized (this) {
-      if (maxBytes - reservedBytes < bytesToReserve
+    memoryReserveLock.lock();
+    try {
+      if (remainingBytes.get() < bytesToReserve
           || maxBytesCanReserve
                   - queryMemoryReservations
                       .getOrDefault(queryId, Collections.emptyMap())
@@ -232,13 +227,15 @@ public class MemoryPool {
               < bytesToReserve) {
         return false;
       }
-      reservedBytes += bytesToReserve;
+      remainingBytes.addAndGet(-bytesToReserve);
       queryMemoryReservations
-          .computeIfAbsent(queryId, x -> new HashMap<>())
-          .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+          .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
+          .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
           .merge(planNodeId, bytesToReserve, Long::sum);
+      return true;
+    } finally {
+      memoryReserveLock.unlock();
     }
-    return true;
   }
 
   /**
@@ -282,59 +279,61 @@ public class MemoryPool {
   }
 
   public void free(String queryId, String fragmentInstanceId, String planNodeId, long bytes) {
-    List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
-    synchronized (this) {
-      Validate.notNull(queryId);
-      Validate.isTrue(bytes > 0L);
-
-      Long queryReservedBytes =
-          queryMemoryReservations
-              .getOrDefault(queryId, Collections.emptyMap())
-              .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-              .get(planNodeId);
-      Validate.notNull(queryReservedBytes);
-      Validate.isTrue(bytes <= queryReservedBytes);
-
-      queryReservedBytes -= bytes;
-      queryMemoryReservations
-          .get(queryId)
-          .get(fragmentInstanceId)
-          .put(planNodeId, queryReservedBytes);
+    Validate.notNull(queryId);
+    Validate.isTrue(bytes > 0L);
 
-      reservedBytes -= bytes;
+    Long queryReservedBytes =
+        queryMemoryReservations
+            .getOrDefault(queryId, Collections.emptyMap())
+            .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+            .computeIfPresent(
+                planNodeId,
+                (k, reservedMemory) -> {
+                  if (reservedMemory < bytes) {
+                    throw new IllegalArgumentException("Free more memory than has been reserved.");
+                  }
+                  return reservedMemory - bytes;
+                });
+
+    Validate.notNull(queryReservedBytes);
+    remainingBytes.addAndGet(bytes);
 
-      if (memoryReservationFutures.isEmpty()) {
-        return;
+    List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
+    if (memoryReservationFutures.isEmpty()) {
+      return;
+    }
+    Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
+    while (iterator.hasNext()) {
+      MemoryReservationFuture<Void> future = iterator.next();
+      if (future.isCancelled() || future.isDone()) {
+        continue;
       }
-      Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
-      while (iterator.hasNext()) {
-        MemoryReservationFuture<Void> future = iterator.next();
-        if (future.isCancelled() || future.isDone()) {
-          continue;
-        }
-        long bytesToReserve = future.getBytesToReserve();
-        String curQueryId = future.getQueryId();
-        String curFragmentInstanceId = future.getFragmentInstanceId();
-        String curPlanNodeId = future.getPlanNodeId();
-        // check total reserved bytes in memory pool
-        if (maxBytes - reservedBytes < bytesToReserve) {
+      long bytesToReserve = future.getBytesToReserve();
+      String curQueryId = future.getQueryId();
+      String curFragmentInstanceId = future.getFragmentInstanceId();
+      String curPlanNodeId = future.getPlanNodeId();
+      // check total reserved bytes in memory pool
+      memoryReserveLock.lock();
+      try {
+        if (remainingBytes.get() < bytesToReserve
+            || future.getMaxBytesCanReserve()
+                    - queryMemoryReservations
+                        .getOrDefault(curQueryId, Collections.emptyMap())
+                        .getOrDefault(curFragmentInstanceId, Collections.emptyMap())
+                        .getOrDefault(curPlanNodeId, 0L)
+                < bytesToReserve) {
           continue;
         }
         // check total reserved bytes of one Sink/Source handle
-        if (future.getMaxBytesCanReserve()
-                - queryMemoryReservations
-                    .getOrDefault(curQueryId, Collections.emptyMap())
-                    .getOrDefault(curFragmentInstanceId, Collections.emptyMap())
-                    .getOrDefault(curPlanNodeId, 0L)
-            >= bytesToReserve) {
-          reservedBytes += bytesToReserve;
-          queryMemoryReservations
-              .computeIfAbsent(curQueryId, x -> new HashMap<>())
-              .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
-              .merge(curPlanNodeId, bytesToReserve, Long::sum);
-          futureList.add(future);
-          iterator.remove();
-        }
+        remainingBytes.addAndGet(-bytesToReserve);
+        queryMemoryReservations
+            .computeIfAbsent(curQueryId, x -> new ConcurrentHashMap<>())
+            .computeIfAbsent(curFragmentInstanceId, x -> new ConcurrentHashMap<>())
+            .merge(curPlanNodeId, bytesToReserve, Long::sum);
+        futureList.add(future);
+        iterator.remove();
+      } finally {
+        memoryReserveLock.unlock();
       }
     }
 
@@ -368,25 +367,27 @@ public class MemoryPool {
   }
 
   public long getReservedBytes() {
-    return reservedBytes;
+    return maxBytes - remainingBytes.get();
   }
 
-  public synchronized void clearMemoryReservationMap(
+  public void clearMemoryReservationMap(
       String queryId, String fragmentInstanceId, String planNodeId) {
-    if (queryMemoryReservations.get(queryId) == null
-        || queryMemoryReservations.get(queryId).get(fragmentInstanceId) == null) {
-      return;
-    }
-    Map<String, Long> planNodeIdToBytesReserved =
-        queryMemoryReservations.get(queryId).get(fragmentInstanceId);
-    if (planNodeIdToBytesReserved.get(planNodeId) == null
-        || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
-      planNodeIdToBytesReserved.remove(planNodeId);
-      if (planNodeIdToBytesReserved.isEmpty()) {
-        queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
+    synchronized (queryMemoryReservations) {
+      if (queryMemoryReservations.get(queryId) == null
+          || queryMemoryReservations.get(queryId).get(fragmentInstanceId) == null) {
+        return;
       }
-      if (queryMemoryReservations.get(queryId).isEmpty()) {
-        queryMemoryReservations.remove(queryId);
+      Map<String, Long> planNodeIdToBytesReserved =
+          queryMemoryReservations.get(queryId).get(fragmentInstanceId);
+      if (planNodeIdToBytesReserved.get(planNodeId) == null
+          || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
+        planNodeIdToBytesReserved.remove(planNodeId);
+        if (planNodeIdToBytesReserved.isEmpty()) {
+          queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
+        }
+        if (queryMemoryReservations.get(queryId).isEmpty()) {
+          queryMemoryReservations.remove(queryId);
+        }
       }
     }
   }