You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original archive page) was not preserved in this text version.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2023/03/29 05:47:07 UTC
[iotdb] branch master updated: [IOTDB-5586] Reduce the scope of lock in MemoryPool Version2
This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 94620929bb [IOTDB-5586] Reduce the scope of lock in MemoryPool Version2
94620929bb is described below
commit 94620929bb7e619c1613d4a17039eec2031fca32
Author: Xiangwei Wei <34...@users.noreply.github.com>
AuthorDate: Wed Mar 29 13:46:59 2023 +0800
[IOTDB-5586] Reduce the scope of lock in MemoryPool Version2
---
.../db/exception/runtime/MemoryLeakException.java | 27 +++
.../iotdb/db/mpp/common/FragmentInstanceId.java | 4 +
.../execution/exchange/MPPDataExchangeManager.java | 6 +
.../mpp/execution/exchange/SharedTsBlockQueue.java | 16 +-
.../mpp/execution/exchange/sink/SinkChannel.java | 12 +-
.../execution/exchange/source/SourceHandle.java | 12 +-
.../fragment/FragmentInstanceExecution.java | 5 +
.../iotdb/db/mpp/execution/memory/MemoryPool.java | 250 +++++++++++----------
.../db/mpp/execution/memory/MemoryPoolTest.java | 27 ++-
9 files changed, 200 insertions(+), 159 deletions(-)
diff --git a/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java b/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java
new file mode 100644
index 0000000000..beb9a83951
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/exception/runtime/MemoryLeakException.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.exception.runtime;
+
+public class MemoryLeakException extends RuntimeException {
+
+ public MemoryLeakException(String message) {
+ super(message);
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
index 793066b94a..24164c4e89 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/FragmentInstanceId.java
@@ -57,6 +57,10 @@ public class FragmentInstanceId {
return instanceId;
}
+ public String getFragmentInstanceId() {
+ return fragmentId + "." + instanceId;
+ }
+
public String toString() {
return fullId;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
index 501781b408..9a58904c79 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
@@ -470,6 +470,12 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
return mppDataExchangeService;
}
+ public void deRegisterFragmentInstanceFromMemoryPool(String queryId, String fragmentInstanceId) {
+ localMemoryManager
+ .getQueryPool()
+ .deRegisterFragmentInstanceToQueryMemoryMap(queryId, fragmentInstanceId);
+ }
+
private synchronized ISinkChannel createLocalSinkChannel(
TFragmentInstanceId localFragmentInstanceId,
TFragmentInstanceId remoteFragmentInstanceId,
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
index 55b7462922..34858dfaa4 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
@@ -91,6 +91,10 @@ public class SharedTsBlockQueue {
this.localPlanNodeId = Validate.notNull(planNodeId, "PlanNode ID cannot be null");
this.localMemoryManager =
Validate.notNull(localMemoryManager, "local memory manager cannot be null");
+ localMemoryManager
+ .getQueryPool()
+ .registerPlanNodeIdToQueryMemoryMap(
+ fragmentInstanceId.queryId, fullFragmentInstanceId, planNodeId);
}
public boolean hasNoMoreTsBlocks() {
@@ -260,10 +264,6 @@ public class SharedTsBlockQueue {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
}
/** Destroy the queue and cancel the future. Should only be called in abnormal case */
@@ -289,10 +289,6 @@ public class SharedTsBlockQueue {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
}
/** Destroy the queue and cancel the future. Should only be called in abnormal case */
@@ -318,9 +314,5 @@ public class SharedTsBlockQueue {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
index 5cd28de462..b32028bee1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/sink/SinkChannel.java
@@ -140,6 +140,10 @@ public class SinkChannel implements ISinkChannel {
localFragmentInstanceId.instanceId);
this.bufferRetainedSizeInBytes = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
this.currentTsBlockSize = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
+ localMemoryManager
+ .getQueryPool()
+ .registerPlanNodeIdToQueryMemoryMap(
+ localFragmentInstanceId.queryId, fullFragmentInstanceId, localPlanNodeId);
}
@Override
@@ -222,10 +226,6 @@ public class SinkChannel implements ISinkChannel {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
sinkListener.onAborted(this);
aborted = true;
LOGGER.debug("[EndAbortSinkChannel]");
@@ -249,10 +249,6 @@ public class SinkChannel implements ISinkChannel {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
sinkListener.onFinish(this);
closed = true;
LOGGER.debug("[EndCloseSinkChannel]");
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
index b66dd4bce4..341b7e8321 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/source/SourceHandle.java
@@ -142,6 +142,10 @@ public class SourceHandle implements ISourceHandle {
this.mppDataExchangeServiceClientManager = mppDataExchangeServiceClientManager;
this.retryIntervalInMs = DEFAULT_RETRY_INTERVAL_IN_MS;
this.threadName = createFullIdFrom(localFragmentInstanceId, localPlanNodeId);
+ localMemoryManager
+ .getQueryPool()
+ .registerPlanNodeIdToQueryMemoryMap(
+ localFragmentInstanceId.queryId, fullFragmentInstanceId, localPlanNodeId);
}
@Override
@@ -332,10 +336,6 @@ public class SourceHandle implements ISourceHandle {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
aborted = true;
sourceHandleListener.onAborted(this);
}
@@ -369,10 +369,6 @@ public class SourceHandle implements ISourceHandle {
bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
- localMemoryManager
- .getQueryPool()
- .clearMemoryReservationMap(
- localFragmentInstanceId.getQueryId(), fullFragmentInstanceId, localPlanNodeId);
closed = true;
executorService.submit(new SendCloseSinkChannelEventTask());
currSequenceId = lastSequenceId + 1;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
index f9c17c9963..b402f0892e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/fragment/FragmentInstanceExecution.java
@@ -20,6 +20,7 @@ package org.apache.iotdb.db.mpp.execution.fragment;
import org.apache.iotdb.db.mpp.common.FragmentInstanceId;
import org.apache.iotdb.db.mpp.execution.driver.IDriver;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeService;
import org.apache.iotdb.db.mpp.execution.exchange.sink.ISink;
import org.apache.iotdb.db.mpp.execution.schedule.IDriverScheduler;
import org.apache.iotdb.db.utils.SetThreadName;
@@ -137,6 +138,10 @@ public class FragmentInstanceExecution {
context.releaseResource();
// help for gc
drivers = null;
+ MPPDataExchangeService.getInstance()
+ .getMPPDataExchangeManager()
+ .deRegisterFragmentInstanceFromMemoryPool(
+ instanceId.getQueryId().getId(), instanceId.getFragmentInstanceId());
if (newState.isFailed()) {
scheduler.abortFragmentInstance(instanceId);
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 3c9dbb10e3..f72d935514 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -20,6 +20,7 @@
package org.apache.iotdb.db.mpp.execution.memory;
import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.exception.runtime.MemoryLeakException;
import org.apache.iotdb.tsfile.utils.Pair;
import com.google.common.util.concurrent.AbstractFuture;
@@ -32,13 +33,13 @@ import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
import java.util.Iterator;
-import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicLong;
/** A thread-safe memory pool. */
public class MemoryPool {
@@ -57,6 +58,8 @@ public class MemoryPool {
*/
private final long maxBytesCanReserve;
+ private boolean isMarked = false;
+
private MemoryReservationFuture(
String queryId,
String fragmentInstanceId,
@@ -77,6 +80,14 @@ public class MemoryPool {
return queryId;
}
+ public boolean isMarked() {
+ return isMarked;
+ }
+
+ public void setMarked(boolean marked) {
+ isMarked = marked;
+ }
+
public String getFragmentInstanceId() {
return fragmentInstanceId;
}
@@ -113,12 +124,13 @@ public class MemoryPool {
private final long maxBytes;
private final long maxBytesPerFragmentInstance;
- private long reservedBytes = 0L;
+ private final AtomicLong remainingBytes;
/** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
private final Map<String, Map<String, Map<String, Long>>> queryMemoryReservations =
- new HashMap<>();
+ new ConcurrentHashMap<>();
- private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures = new LinkedList<>();
+ private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures =
+ new ConcurrentLinkedQueue<>();
public MemoryPool(String id, long maxBytes, long maxBytesPerFragmentInstance) {
this.id = Validate.notNull(id);
@@ -130,14 +142,51 @@ public class MemoryPool {
maxBytesPerFragmentInstance,
maxBytes);
this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
+ this.remainingBytes = new AtomicLong(maxBytes);
}
public String getId() {
return id;
}
- public long getMaxBytes() {
- return maxBytes;
+ /**
+ * Before executing, we register memory map which is related to queryId, fragmentInstanceId, and
+ * planNodeId to queryMemoryReservationsMap first.
+ */
+ public void registerPlanNodeIdToQueryMemoryMap(
+ String queryId, String fragmentInstanceId, String planNodeId) {
+ synchronized (queryMemoryReservations) {
+ queryMemoryReservations
+ .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
+ .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
+ .putIfAbsent(planNodeId, 0L);
+ }
+ }
+
+ /**
+ * If all fragmentInstanceIds related to one queryId have been registered, when the last fragment
+ * instance is deregister, the queryId can be cleared.
+ *
+ * <p>If some fragmentInstanceIds have not been registered when queryId is cleared, they will
+ * register queryId again with lock, so there is no concurrency problem.
+ */
+ public void deRegisterFragmentInstanceToQueryMemoryMap(
+ String queryId, String fragmentInstanceId) {
+ Map<String, Long> planNodeRelatedMemory =
+ queryMemoryReservations.get(queryId).get(fragmentInstanceId);
+ for (Long memoryReserved : planNodeRelatedMemory.values()) {
+ if (memoryReserved != 0) {
+ throw new MemoryLeakException(
+ "PlanNode related memory is not zero when deregister fragment instance from query memory pool.");
+ }
+ }
+ synchronized (queryMemoryReservations) {
+ Map<String, Map<String, Long>> queryRelatedMemory = queryMemoryReservations.get(queryId);
+ queryRelatedMemory.remove(fragmentInstanceId);
+ if (queryRelatedMemory.isEmpty()) {
+ queryMemoryReservations.remove(queryId);
+ }
+ }
}
/**
@@ -169,37 +218,23 @@ public class MemoryPool {
}
ListenableFuture<Void> result;
- synchronized (this) {
- if (maxBytes - reservedBytes < bytesToReserve
- || maxBytesCanReserve
- - queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .getOrDefault(fragmentInstanceId, Collections.emptyMap())
- .getOrDefault(planNodeId, 0L)
- < bytesToReserve) {
- LOGGER.debug(
- "Blocked reserve request: {} bytes memory for planNodeId{}",
- bytesToReserve,
- planNodeId);
- result =
- MemoryReservationFuture.create(
- queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
- memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
- return new Pair<>(result, Boolean.FALSE);
- } else {
- reservedBytes += bytesToReserve;
- queryMemoryReservations
- .computeIfAbsent(queryId, x -> new HashMap<>())
- .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
- .merge(planNodeId, bytesToReserve, Long::sum);
- result = Futures.immediateFuture(null);
- return new Pair<>(result, Boolean.TRUE);
- }
+ if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve)) {
+ result = Futures.immediateFuture(null);
+ return new Pair<>(result, Boolean.TRUE);
+ } else {
+ LOGGER.debug(
+ "Blocked reserve request: {} bytes memory for planNodeId{}", bytesToReserve, planNodeId);
+ rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
+ result =
+ MemoryReservationFuture.create(
+ queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
+ memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+ return new Pair<>(result, Boolean.FALSE);
}
}
@TestOnly
- public boolean tryReserve(
+ public boolean tryReserveForTest(
String queryId,
String fragmentInstanceId,
String planNodeId,
@@ -213,32 +248,12 @@ public class MemoryPool {
"bytes should be greater than zero while less than or equal to max bytes per fragment instance: %d",
bytesToReserve);
- if (maxBytes - reservedBytes < bytesToReserve
- || maxBytesCanReserve
- - queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .getOrDefault(fragmentInstanceId, Collections.emptyMap())
- .getOrDefault(planNodeId, 0L)
- < bytesToReserve) {
+ if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve)) {
+ return true;
+ } else {
+ rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
return false;
}
- synchronized (this) {
- if (maxBytes - reservedBytes < bytesToReserve
- || maxBytesCanReserve
- - queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .getOrDefault(fragmentInstanceId, Collections.emptyMap())
- .getOrDefault(planNodeId, 0L)
- < bytesToReserve) {
- return false;
- }
- reservedBytes += bytesToReserve;
- queryMemoryReservations
- .computeIfAbsent(queryId, x -> new HashMap<>())
- .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
- .merge(planNodeId, bytesToReserve, Long::sum);
- }
- return true;
}
/**
@@ -282,58 +297,50 @@ public class MemoryPool {
}
public void free(String queryId, String fragmentInstanceId, String planNodeId, long bytes) {
- List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
- synchronized (this) {
- Validate.notNull(queryId);
- Validate.isTrue(bytes > 0L);
-
- Long queryReservedBytes =
- queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .getOrDefault(fragmentInstanceId, Collections.emptyMap())
- .get(planNodeId);
- Validate.notNull(queryReservedBytes);
- Validate.isTrue(bytes <= queryReservedBytes);
-
- queryReservedBytes -= bytes;
+ Validate.notNull(queryId);
+ Validate.isTrue(bytes > 0L);
+
+ try {
queryMemoryReservations
.get(queryId)
.get(fragmentInstanceId)
- .put(planNodeId, queryReservedBytes);
+ .computeIfPresent(
+ planNodeId,
+ (k, reservedMemory) -> {
+ if (reservedMemory < bytes) {
+ throw new IllegalArgumentException("Free more memory than has been reserved.");
+ }
+ return reservedMemory - bytes;
+ });
+ } catch (NullPointerException e) {
+ throw new IllegalArgumentException("RelatedMemoryReserved can't be null when freeing memory");
+ }
- reservedBytes -= bytes;
+ remainingBytes.addAndGet(bytes);
- if (memoryReservationFutures.isEmpty()) {
- return;
- }
- Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
- while (iterator.hasNext()) {
- MemoryReservationFuture<Void> future = iterator.next();
- if (future.isCancelled() || future.isDone()) {
+ List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
+ if (memoryReservationFutures.isEmpty()) {
+ return;
+ }
+ Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
+ while (iterator.hasNext()) {
+ MemoryReservationFuture<Void> future = iterator.next();
+ synchronized (future) {
+ if (future.isCancelled() || future.isDone() || future.isMarked()) {
continue;
}
long bytesToReserve = future.getBytesToReserve();
String curQueryId = future.getQueryId();
String curFragmentInstanceId = future.getFragmentInstanceId();
String curPlanNodeId = future.getPlanNodeId();
- // check total reserved bytes in memory pool
- if (maxBytes - reservedBytes < bytesToReserve) {
- continue;
- }
- // check total reserved bytes of one Sink/Source handle
- if (future.getMaxBytesCanReserve()
- - queryMemoryReservations
- .getOrDefault(curQueryId, Collections.emptyMap())
- .getOrDefault(curFragmentInstanceId, Collections.emptyMap())
- .getOrDefault(curPlanNodeId, 0L)
- >= bytesToReserve) {
- reservedBytes += bytesToReserve;
- queryMemoryReservations
- .computeIfAbsent(curQueryId, x -> new HashMap<>())
- .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
- .merge(curPlanNodeId, bytesToReserve, Long::sum);
+ long maxBytesCanReserve = future.getMaxBytesCanReserve();
+ if (tryReserve(
+ curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve, maxBytesCanReserve)) {
futureList.add(future);
+ future.setMarked(true);
iterator.remove();
+ } else {
+ rollbackReserve(curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve);
}
}
}
@@ -342,10 +349,10 @@ public class MemoryPool {
// If we put this block inside the MemoryPool's lock, we will get deadlock case like the
// following:
// Assuming that thread-A: LocalSourceHandle.receive() -> A-SharedTsBlockQueue.remove() ->
- // MemoryPool.free() (hold MemoryPool's lock) -> future.set(null) -> try to get
+ // MemoryPool.free() (hold future's lock) -> future.set(null) -> try to get
// B-SharedTsBlockQueue's lock
// thread-B: LocalSourceHandle.receive() -> B-SharedTsBlockQueue.remove() (hold
- // B-SharedTsBlockQueue's lock) -> try to get MemoryPool's lock
+ // B-SharedTsBlockQueue's lock) -> try to get future's lock
for (MemoryReservationFuture<Void> future : futureList) {
try {
future.set(null);
@@ -368,26 +375,31 @@ public class MemoryPool {
}
public long getReservedBytes() {
- return reservedBytes;
+ return maxBytes - remainingBytes.get();
}
- public synchronized void clearMemoryReservationMap(
- String queryId, String fragmentInstanceId, String planNodeId) {
- if (queryMemoryReservations.get(queryId) == null
- || queryMemoryReservations.get(queryId).get(fragmentInstanceId) == null) {
- return;
- }
- Map<String, Long> planNodeIdToBytesReserved =
- queryMemoryReservations.get(queryId).get(fragmentInstanceId);
- if (planNodeIdToBytesReserved.get(planNodeId) == null
- || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
- planNodeIdToBytesReserved.remove(planNodeId);
- if (planNodeIdToBytesReserved.isEmpty()) {
- queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
- }
- if (queryMemoryReservations.get(queryId).isEmpty()) {
- queryMemoryReservations.remove(queryId);
- }
- }
+ public boolean tryReserve(
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ long bytesToReserve,
+ long maxBytesCanReserve) {
+ long tryRemainingBytes = remainingBytes.addAndGet(-bytesToReserve);
+ long queryRemainingBytes =
+ maxBytesCanReserve
+ - queryMemoryReservations
+ .get(queryId)
+ .get(fragmentInstanceId)
+ .merge(planNodeId, bytesToReserve, Long::sum);
+ return tryRemainingBytes >= 0 && queryRemainingBytes >= 0;
+ }
+
+ private void rollbackReserve(
+ String queryId, String fragmentInstanceId, String planNodeId, long bytesToReserve) {
+ queryMemoryReservations
+ .get(queryId)
+ .get(fragmentInstanceId)
+ .merge(planNodeId, -bytesToReserve, Long::sum);
+ remainingBytes.addAndGet(bytesToReserve);
}
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index 94a7c81999..a9b924c418 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -36,13 +36,14 @@ public class MemoryPoolTest {
@Before
public void before() {
pool = new MemoryPool("test", 1024L, 512L);
+ pool.registerPlanNodeIdToQueryMemoryMap(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID);
}
@Test
public void testTryReserve() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@@ -51,7 +52,7 @@ public class MemoryPoolTest {
public void testTryReserveZero() {
try {
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -61,7 +62,7 @@ public class MemoryPoolTest {
public void testTryReserveNegative() {
try {
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -71,7 +72,7 @@ public class MemoryPoolTest {
public void testTryReserveAll() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
}
@@ -79,10 +80,12 @@ public class MemoryPoolTest {
@Test
public void testOverTryReserve() {
- Assert.assertTrue(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L));
+ Assert.assertTrue(
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
- Assert.assertFalse(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 511L));
+ Assert.assertFalse(
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 511L));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@@ -206,7 +209,7 @@ public class MemoryPoolTest {
public void testFree() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -219,7 +222,7 @@ public class MemoryPoolTest {
public void testFreeAll() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -232,7 +235,7 @@ public class MemoryPoolTest {
public void testFreeZero() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -247,7 +250,7 @@ public class MemoryPoolTest {
public void testFreeNegative() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -262,7 +265,7 @@ public class MemoryPoolTest {
public void testOverFree() {
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -278,7 +281,7 @@ public class MemoryPoolTest {
// Run out of memory.
Assert.assertTrue(
- pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
ListenableFuture<Void> f =
pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L).left;