You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2022/12/19 05:56:15 UTC
[iotdb] branch master updated: [IOTDB-5117] Introduce MemoryDistributionCalculator for FragmentInstance (#8485)
This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 483f0a5f35 [IOTDB-5117] Introduce MemoryDistributionCalculator for FragmentInstance (#8485)
483f0a5f35 is described below
commit 483f0a5f35379a98c23e771eb0e047d93683d78e
Author: Liao Lanyu <14...@qq.com>
AuthorDate: Mon Dec 19 13:56:08 2022 +0800
[IOTDB-5117] Introduce MemoryDistributionCalculator for FragmentInstance (#8485)
---
.../java/org/apache/iotdb/db/conf/IoTDBConfig.java | 2 +-
.../org/apache/iotdb/db/conf/IoTDBDescriptor.java | 2 +-
.../exchange/IMPPDataExchangeManager.java | 2 +
.../db/mpp/execution/exchange/ISinkHandle.java | 3 +
.../db/mpp/execution/exchange/ISourceHandle.java | 3 +
.../db/mpp/execution/exchange/LocalSinkHandle.java | 11 +-
.../mpp/execution/exchange/LocalSourceHandle.java | 5 +
.../execution/exchange/MPPDataExchangeManager.java | 21 +-
.../mpp/execution/exchange/SharedTsBlockQueue.java | 59 ++-
.../db/mpp/execution/exchange/SinkHandle.java | 50 ++-
.../db/mpp/execution/exchange/SourceHandle.java | 62 ++-
.../mpp/execution/memory/LocalMemoryManager.java | 2 +-
.../iotdb/db/mpp/execution/memory/MemoryPool.java | 206 ++++++---
.../plan/execution/memory/MemorySourceHandle.java | 3 +
.../db/mpp/plan/planner/LocalExecutionPlanner.java | 30 ++
.../plan/planner/MemoryDistributionCalculator.java | 490 +++++++++++++++++++++
.../db/mpp/plan/planner/OperatorTreeGenerator.java | 1 +
.../execution/exchange/LocalSinkHandleTest.java | 63 ++-
.../execution/exchange/LocalSourceHandleTest.java | 4 +-
.../execution/exchange/SharedTsBlockQueueTest.java | 5 +-
.../db/mpp/execution/exchange/SinkHandleTest.java | 77 +++-
.../mpp/execution/exchange/SourceHandleTest.java | 10 +-
.../db/mpp/execution/exchange/StubSinkHandle.java | 3 +
.../iotdb/db/mpp/execution/exchange/Utils.java | 52 ++-
.../db/mpp/execution/memory/MemoryPoolTest.java | 182 ++++----
25 files changed, 1157 insertions(+), 191 deletions(-)
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index ffed1b6e82..269a8f4f56 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -1502,7 +1502,7 @@ public class IoTDBConfig {
this.subRawQueryThreadCount = subRawQueryThreadCount;
}
- public long getMaxBytesPerQuery() {
+ public long getMaxBytesPerFragmentInstance() {
return allocateMemoryForDataExchange / queryThreadCount;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
index 37f44e2f41..5978d45549 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
@@ -1276,7 +1276,7 @@ public class IoTDBDescriptor {
(int)
Math.min(
TSFileDescriptor.getInstance().getConfig().getMaxTsBlockSizeInBytes(),
- conf.getMaxBytesPerQuery()));
+ conf.getMaxBytesPerFragmentInstance()));
TSFileDescriptor.getInstance()
.getConfig()
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
index c120dc4f21..dfc686e699 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/IMPPDataExchangeManager.java
@@ -33,6 +33,7 @@ public interface IMPPDataExchangeManager {
* @param remoteEndpoint Hostname and Port of the remote fragment instance where the data blocks
* should be sent to.
* @param remotePlanNodeId The sink plan node ID of the remote fragment instance.
+ * @param localPlanNodeId The plan node ID of the local fragment instance.
* @param instanceContext The context of local fragment instance.
*/
ISinkHandle createSinkHandle(
@@ -40,6 +41,7 @@ public interface IMPPDataExchangeManager {
TEndPoint remoteEndpoint,
TFragmentInstanceId remoteFragmentInstanceId,
String remotePlanNodeId,
+ String localPlanNodeId,
FragmentInstanceContext instanceContext);
ISinkHandle createLocalSinkHandle(
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java
index 406ea5d625..4bceda0e15 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISinkHandle.java
@@ -80,4 +80,7 @@ public interface ISinkHandle {
* <p>Should only be called in normal case.
*/
void close();
+
+ /** Set max bytes this handle can reserve from memory pool */
+ void setMaxBytesCanReserve(long maxBytesCanReserve);
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java
index 2a958525bb..d056717060 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/ISourceHandle.java
@@ -82,4 +82,7 @@ public interface ISourceHandle {
* <p>Should only be called in normal case
*/
void close();
+
+ /** Set max bytes this handle can reserve from memory pool */
+ void setMaxBytesCanReserve(long maxBytesCanReserve);
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java
index c9d03ab833..91aa31d705 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandle.java
@@ -31,7 +31,6 @@ import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.Optional;
-import static com.google.common.util.concurrent.Futures.immediateFuture;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
public class LocalSinkHandle implements ISinkHandle {
@@ -44,7 +43,7 @@ public class LocalSinkHandle implements ISinkHandle {
private final SinkHandleListener sinkHandleListener;
private final SharedTsBlockQueue queue;
- private volatile ListenableFuture<Void> blocked = immediateFuture(null);
+ private volatile ListenableFuture<Void> blocked;
private boolean aborted = false;
private boolean closed = false;
@@ -60,6 +59,8 @@ public class LocalSinkHandle implements ISinkHandle {
this.sinkHandleListener = Validate.notNull(sinkHandleListener);
this.queue = Validate.notNull(queue);
this.queue.setSinkHandle(this);
+ // SinkHandle can send data after SourceHandle asks it to
+ blocked = queue.getCanAddTsBlock();
}
@Override
@@ -197,4 +198,10 @@ public class LocalSinkHandle implements ISinkHandle {
throw new IllegalStateException("Sink Handle is closed.");
}
}
+
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+ // do nothing, the maxBytesCanReserve of SharedTsBlockQueue should be set by corresponding
+ // LocalSourceHandle
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java
index 64415f41dc..1df62c84bd 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandle.java
@@ -233,4 +233,9 @@ public class LocalSourceHandle implements ISourceHandle {
SharedTsBlockQueue getSharedTsBlockQueue() {
return queue;
}
+
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+ queue.setMaxBytesCanReserve(maxBytesCanReserve);
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
index 9f7ccf8455..28146749a0 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/MPPDataExchangeManager.java
@@ -41,7 +41,9 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
@@ -364,7 +366,8 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
.getSharedTsBlockQueue();
} else {
logger.debug("Create shared tsblock queue");
- queue = new SharedTsBlockQueue(remoteFragmentInstanceId, localMemoryManager);
+ queue =
+ new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, localMemoryManager);
}
LocalSinkHandle localSinkHandle =
@@ -384,6 +387,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
TEndPoint remoteEndpoint,
TFragmentInstanceId remoteFragmentInstanceId,
String remotePlanNodeId,
+ String localPlanNodeId,
// TODO: replace with callbacks to decouple MPPDataExchangeManager from
// FragmentInstanceContext
FragmentInstanceContext instanceContext) {
@@ -402,6 +406,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
remoteEndpoint,
remoteFragmentInstanceId,
remotePlanNodeId,
+ localPlanNodeId,
localFragmentInstanceId,
localMemoryManager,
executorService,
@@ -439,7 +444,7 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
queue = ((LocalSinkHandle) sinkHandles.get(remoteFragmentInstanceId)).getSharedTsBlockQueue();
} else {
logger.debug("Create shared tsblock queue");
- queue = new SharedTsBlockQueue(localFragmentInstanceId, localMemoryManager);
+ queue = new SharedTsBlockQueue(localFragmentInstanceId, localPlanNodeId, localMemoryManager);
}
LocalSourceHandle localSourceHandle =
new LocalSourceHandle(
@@ -527,4 +532,16 @@ public class MPPDataExchangeManager implements IMPPDataExchangeManager {
+ "."
+ suffix;
}
+
+ public ISinkHandle getISinkHandle(TFragmentInstanceId fragmentInstanceId) {
+ return sinkHandles.get(fragmentInstanceId);
+ }
+
+ public List<ISourceHandle> getISourceHandle(TFragmentInstanceId fragmentInstanceId) {
+ if (sourceHandles.containsKey(fragmentInstanceId)) {
+ return new ArrayList<>(sourceHandles.get(fragmentInstanceId).values());
+ } else {
+ return new ArrayList<>();
+ }
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
index a593bf5c91..95e64b1828 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.db.mpp.execution.exchange;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
@@ -45,6 +46,8 @@ public class SharedTsBlockQueue {
private static final Logger logger = LoggerFactory.getLogger(SharedTsBlockQueue.class);
private final TFragmentInstanceId localFragmentInstanceId;
+
+ private final String localPlanNodeId;
private final LocalMemoryManager localMemoryManager;
private boolean noMoreTsBlocks = false;
@@ -55,6 +58,12 @@ public class SharedTsBlockQueue {
private SettableFuture<Void> blocked = SettableFuture.create();
+ /**
+ * this is completed after calling isBlocked for the first time which indicates this queue needs
+ * to output data
+ */
+ private final SettableFuture<Void> canAddTsBlock = SettableFuture.create();
+
private ListenableFuture<Void> blockedOnMemory;
private boolean closed = false;
@@ -62,10 +71,16 @@ public class SharedTsBlockQueue {
private LocalSourceHandle sourceHandle;
private LocalSinkHandle sinkHandle;
+ private long maxBytesCanReserve =
+ IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
+
public SharedTsBlockQueue(
- TFragmentInstanceId fragmentInstanceId, LocalMemoryManager localMemoryManager) {
+ TFragmentInstanceId fragmentInstanceId,
+ String planNodeId,
+ LocalMemoryManager localMemoryManager) {
this.localFragmentInstanceId =
Validate.notNull(fragmentInstanceId, "fragment instance ID cannot be null");
+ this.localPlanNodeId = Validate.notNull(planNodeId, "PlanNode ID cannot be null");
this.localMemoryManager =
Validate.notNull(localMemoryManager, "local memory manager cannot be null");
}
@@ -78,7 +93,18 @@ public class SharedTsBlockQueue {
return bufferRetainedSizeInBytes;
}
+ public SettableFuture<Void> getCanAddTsBlock() {
+ return canAddTsBlock;
+ }
+
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+ this.maxBytesCanReserve = maxBytesCanReserve;
+ }
+
public ListenableFuture<Void> isBlocked() {
+ if (!canAddTsBlock.isDone()) {
+ canAddTsBlock.set(null);
+ }
return blocked;
}
@@ -126,7 +152,11 @@ public class SharedTsBlockQueue {
}
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), tsBlock.getRetainedSizeInBytes());
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ tsBlock.getRetainedSizeInBytes());
bufferRetainedSizeInBytes -= tsBlock.getRetainedSizeInBytes();
if (blocked.isDone() && queue.isEmpty() && !noMoreTsBlocks) {
blocked = SettableFuture.create();
@@ -149,7 +179,12 @@ public class SharedTsBlockQueue {
Pair<ListenableFuture<Void>, Boolean> pair =
localMemoryManager
.getQueryPool()
- .reserve(localFragmentInstanceId.getQueryId(), tsBlock.getRetainedSizeInBytes());
+ .reserve(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ tsBlock.getRetainedSizeInBytes(),
+ maxBytesCanReserve);
blockedOnMemory = pair.left;
bufferRetainedSizeInBytes += tsBlock.getRetainedSizeInBytes();
@@ -191,7 +226,11 @@ public class SharedTsBlockQueue {
if (bufferRetainedSizeInBytes > 0L) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
}
@@ -212,7 +251,11 @@ public class SharedTsBlockQueue {
if (bufferRetainedSizeInBytes > 0L) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
}
@@ -233,7 +276,11 @@ public class SharedTsBlockQueue {
if (bufferRetainedSizeInBytes > 0L) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
index d92171b029..943d24a6e5 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceClient;
import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SinkHandleListener;
import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
import org.apache.iotdb.db.utils.SetThreadName;
@@ -61,6 +62,8 @@ public class SinkHandle implements ISinkHandle {
private final TEndPoint remoteEndpoint;
private final TFragmentInstanceId remoteFragmentInstanceId;
private final String remotePlanNodeId;
+
+ private final String localPlanNodeId;
private final TFragmentInstanceId localFragmentInstanceId;
private final LocalMemoryManager localMemoryManager;
private final ExecutorService executorService;
@@ -92,10 +95,15 @@ public class SinkHandle implements ISinkHandle {
private boolean noMoreTsBlocks = false;
+ /** max bytes this SinkHandle can reserve. */
+ private long maxBytesCanReserve =
+ IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
+
public SinkHandle(
TEndPoint remoteEndpoint,
TFragmentInstanceId remoteFragmentInstanceId,
String remotePlanNodeId,
+ String localPlanNodeId,
TFragmentInstanceId localFragmentInstanceId,
LocalMemoryManager localMemoryManager,
ExecutorService executorService,
@@ -106,6 +114,7 @@ public class SinkHandle implements ISinkHandle {
this.remoteEndpoint = Validate.notNull(remoteEndpoint);
this.remoteFragmentInstanceId = Validate.notNull(remoteFragmentInstanceId);
this.remotePlanNodeId = Validate.notNull(remotePlanNodeId);
+ this.localPlanNodeId = Validate.notNull(localPlanNodeId);
this.localFragmentInstanceId = Validate.notNull(localFragmentInstanceId);
this.localMemoryManager = Validate.notNull(localMemoryManager);
this.executorService = Validate.notNull(executorService);
@@ -121,7 +130,14 @@ public class SinkHandle implements ISinkHandle {
this.blocked =
localMemoryManager
.getQueryPool()
- .reserve(localFragmentInstanceId.getQueryId(), DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES)
+ .reserve(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
+ DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES) // actually we only know maxBytesCanReserve after
+ // the handle is created, so we use DEFAULT here. It is ok to use DEFAULT here because
+ // at first this SinkHandle has not reserved memory.
.left;
this.bufferRetainedSizeInBytes = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
this.currentTsBlockSize = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
@@ -153,7 +169,12 @@ public class SinkHandle implements ISinkHandle {
blocked =
localMemoryManager
.getQueryPool()
- .reserve(localFragmentInstanceId.getQueryId(), retainedSizeInBytes)
+ .reserve(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ retainedSizeInBytes,
+ maxBytesCanReserve)
.left;
bufferRetainedSizeInBytes += retainedSizeInBytes;
@@ -188,7 +209,11 @@ public class SinkHandle implements ISinkHandle {
if (bufferRetainedSizeInBytes > 0) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
sinkHandleListener.onAborted(this);
@@ -204,7 +229,11 @@ public class SinkHandle implements ISinkHandle {
if (bufferRetainedSizeInBytes > 0) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
sinkHandleListener.onFinish(this);
@@ -282,7 +311,13 @@ public class SinkHandle implements ISinkHandle {
// there may exist duplicate ack message in network caused by caller retrying, if so duplicate
// ack message's freedBytes may be zero
if (freedBytes > 0) {
- localMemoryManager.getQueryPool().free(localFragmentInstanceId.getQueryId(), freedBytes);
+ localMemoryManager
+ .getQueryPool()
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ freedBytes);
}
}
@@ -302,6 +337,11 @@ public class SinkHandle implements ISinkHandle {
return localFragmentInstanceId;
}
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+ this.maxBytesCanReserve = maxBytesCanReserve;
+ }
+
@Override
public String toString() {
return String.format(
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
index c672383ca3..cda072ff0f 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.commons.client.IClientManager;
import org.apache.iotdb.commons.client.sync.SyncDataNodeMPPDataExchangeServiceClient;
import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeManager.SourceHandleListener;
import org.apache.iotdb.db.mpp.execution.memory.LocalMemoryManager;
import org.apache.iotdb.db.utils.SetThreadName;
@@ -89,6 +90,16 @@ public class SourceHandle implements ISourceHandle {
private boolean closed = false;
+ /** max bytes this SourceHandle can reserve. */
+ private long maxBytesCanReserve =
+ IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance();
+
+ /**
+ * this is set to true after calling isBlocked() at least once which indicates that this
+ * SourceHandle needs to output data
+ */
+ private boolean canGetTsBlockFromRemote = false;
+
public SourceHandle(
TEndPoint remoteEndpoint,
TFragmentInstanceId remoteFragmentInstanceId,
@@ -127,7 +138,6 @@ public class SourceHandle implements ISourceHandle {
@Override
public synchronized ByteBuffer getSerializedTsBlock() {
try (SetThreadName sourceHandleName = new SetThreadName(threadName)) {
-
checkState();
if (!blocked.isDone()) {
@@ -142,7 +152,13 @@ public class SourceHandle implements ISourceHandle {
logger.debug("[GetTsBlockFromBuffer] sequenceId:{}, size:{}", currSequenceId, retainedSize);
currSequenceId += 1;
bufferRetainedSizeInBytes -= retainedSize;
- localMemoryManager.getQueryPool().free(localFragmentInstanceId.getQueryId(), retainedSize);
+ localMemoryManager
+ .getQueryPool()
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ retainedSize);
if (sequenceIdToTsBlock.isEmpty() && !isFinished()) {
logger.debug("[WaitForMoreTsBlock]");
@@ -177,7 +193,12 @@ public class SourceHandle implements ISourceHandle {
pair =
localMemoryManager
.getQueryPool()
- .reserve(localFragmentInstanceId.getQueryId(), bytesToReserve);
+ .reserve(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bytesToReserve,
+ maxBytesCanReserve);
bufferRetainedSizeInBytes += bytesToReserve;
endSequenceId += 1;
reservedBytes += bytesToReserve;
@@ -224,6 +245,12 @@ public class SourceHandle implements ISourceHandle {
@Override
public synchronized ListenableFuture<?> isBlocked() {
checkState();
+ if (!canGetTsBlockFromRemote) {
+ canGetTsBlockFromRemote = true;
+ // submit get data task once isBlocked is called to ensure that the blocked future will be
+ // completed in case that trySubmitGetDataBlocksTask() is not called.
+ trySubmitGetDataBlocksTask();
+ }
return nonCancellationPropagating(blocked);
}
@@ -247,7 +274,9 @@ public class SourceHandle implements ISourceHandle {
for (int i = 0; i < dataBlockSizes.size(); i++) {
sequenceIdToDataBlockSize.put(i + startSequenceId, dataBlockSizes.get(i));
}
- trySubmitGetDataBlocksTask();
+ if (canGetTsBlockFromRemote) {
+ trySubmitGetDataBlocksTask();
+ }
}
@Override
@@ -266,7 +295,11 @@ public class SourceHandle implements ISourceHandle {
if (bufferRetainedSizeInBytes > 0) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
aborted = true;
@@ -295,7 +328,11 @@ public class SourceHandle implements ISourceHandle {
if (bufferRetainedSizeInBytes > 0) {
localMemoryManager
.getQueryPool()
- .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ bufferRetainedSizeInBytes);
bufferRetainedSizeInBytes = 0;
}
closed = true;
@@ -337,6 +374,11 @@ public class SourceHandle implements ISourceHandle {
return bufferRetainedSizeInBytes;
}
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {
+ this.maxBytesCanReserve = maxBytesCanReserve;
+ }
+
@Override
public boolean isAborted() {
return aborted;
@@ -454,7 +496,13 @@ public class SourceHandle implements ISourceHandle {
return;
}
bufferRetainedSizeInBytes -= reservedBytes;
- localMemoryManager.getQueryPool().free(localFragmentInstanceId.getQueryId(), reservedBytes);
+ localMemoryManager
+ .getQueryPool()
+ .free(
+ localFragmentInstanceId.getQueryId(),
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ reservedBytes);
sourceHandleListener.onFailure(SourceHandle.this, t);
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/LocalMemoryManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/LocalMemoryManager.java
index e3d170cbf7..1152dfcf10 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/LocalMemoryManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/LocalMemoryManager.java
@@ -34,7 +34,7 @@ public class LocalMemoryManager {
new MemoryPool(
"query",
IoTDBDescriptor.getInstance().getConfig().getAllocateMemoryForDataExchange(),
- IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerQuery());
+ IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance());
}
public MemoryPool getQueryPool() {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 8494fa4a23..b247846e59 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.db.mpp.execution.memory;
+import org.apache.iotdb.commons.utils.TestOnly;
import org.apache.iotdb.tsfile.utils.Pair;
import com.google.common.util.concurrent.AbstractFuture;
@@ -31,6 +32,7 @@ import org.slf4j.LoggerFactory;
import javax.annotation.Nullable;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
@@ -44,25 +46,61 @@ public class MemoryPool {
private static final Logger LOGGER = LoggerFactory.getLogger(MemoryPool.class);
public static class MemoryReservationFuture<V> extends AbstractFuture<V> {
- private final String queryId;
- private final long bytes;
- private MemoryReservationFuture(String queryId, long bytes) {
+ private final String queryId;
+ private final String fragmentInstanceId;
+ private final String planNodeId;
+ private final long bytesToReserve;
+ /**
+ * MemoryReservationFuture is created when SinkHandle or SourceHandle tries to reserve memory
+ * from pool. This field is the max bytes that the SinkHandle or SourceHandle can reserve.
+ */
+ private final long maxBytesCanReserve;
+
+ private MemoryReservationFuture(
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ long bytesToReserve,
+ long maxBytesCanReserve) {
this.queryId = Validate.notNull(queryId, "queryId cannot be null");
- Validate.isTrue(bytes > 0L, "bytes should be greater than zero.");
- this.bytes = bytes;
- }
-
- public long getBytes() {
- return bytes;
+ this.fragmentInstanceId =
+ Validate.notNull(fragmentInstanceId, "fragmentInstanceId cannot be null");
+ this.planNodeId = Validate.notNull(planNodeId, "planNodeId cannot be null");
+ Validate.isTrue(bytesToReserve > 0L, "bytesToReserve should be greater than zero.");
+ Validate.isTrue(maxBytesCanReserve > 0L, "maxBytesCanReserve should be greater than zero.");
+ this.bytesToReserve = bytesToReserve;
+ this.maxBytesCanReserve = maxBytesCanReserve;
}
public String getQueryId() {
return queryId;
}
- public static <V> MemoryReservationFuture<V> create(String queryId, long bytes) {
- return new MemoryReservationFuture<>(queryId, bytes);
+ public String getFragmentInstanceId() {
+ return fragmentInstanceId;
+ }
+
+ public String getPlanNodeId() {
+ return planNodeId;
+ }
+
+ public long getBytesToReserve() {
+ return bytesToReserve;
+ }
+
+ public long getMaxBytesCanReserve() {
+ return maxBytesCanReserve;
+ }
+
+ public static <V> MemoryReservationFuture<V> create(
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ long bytesToReserve,
+ long maxBytesCanReserve) {
+ return new MemoryReservationFuture<>(
+ queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
}
@Override
@@ -73,22 +111,25 @@ public class MemoryPool {
private final String id;
private final long maxBytes;
- private final long maxBytesPerQuery;
+ private final long maxBytesPerFragmentInstance;
private long reservedBytes = 0L;
- private final Map<String, Long> queryMemoryReservations = new HashMap<>();
+ /** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
+ private final Map<String, Map<String, Map<String, Long>>> queryMemoryReservations =
+ new HashMap<>();
+
private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures = new LinkedList<>();
- public MemoryPool(String id, long maxBytes, long maxBytesPerQuery) {
+ public MemoryPool(String id, long maxBytes, long maxBytesPerFragmentInstance) {
this.id = Validate.notNull(id);
Validate.isTrue(maxBytes > 0L, "max bytes should be greater than zero: %d", maxBytes);
this.maxBytes = maxBytes;
Validate.isTrue(
- maxBytesPerQuery > 0L && maxBytesPerQuery <= maxBytes,
+ maxBytesPerFragmentInstance > 0L && maxBytesPerFragmentInstance <= maxBytes,
"max bytes per query should be greater than zero while less than or equal to max bytes. maxBytesPerQuery: %d, maxBytes: %d",
- maxBytesPerQuery,
+ maxBytesPerFragmentInstance,
maxBytes);
- this.maxBytesPerQuery = maxBytesPerQuery;
+ this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
}
public String getId() {
@@ -100,56 +141,93 @@ public class MemoryPool {
}
/** @return if reserve succeed, pair.right will be true, otherwise false */
- public Pair<ListenableFuture<Void>, Boolean> reserve(String queryId, long bytes) {
+ public Pair<ListenableFuture<Void>, Boolean> reserve(
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ long bytesToReserve,
+ long maxBytesCanReserve) {
+ // Reservations are accounted per (queryId, fragmentInstanceId, planNodeId), i.e. per
+ // Sink/Source handle; maxBytesCanReserve caps the total that this one handle may hold.
Validate.notNull(queryId);
+ Validate.notNull(fragmentInstanceId);
+ Validate.notNull(planNodeId);
Validate.isTrue(
- bytes > 0L && bytes <= maxBytesPerQuery,
- "bytes should be greater than zero while less than or equal to max bytes per query: %d",
- bytes);
+ bytesToReserve > 0L && bytesToReserve <= maxBytesPerFragmentInstance,
+ "bytes should be greater than zero while less than or equal to max bytes per fragment instance: %d",
+ bytesToReserve);
ListenableFuture<Void> result;
synchronized (this) {
- if (maxBytes - reservedBytes < bytes
- || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
- result = MemoryReservationFuture.create(queryId, bytes);
+ // Queue the request (incomplete future, pair.right == false) when either the whole pool or
+ // this handle's own cap cannot accommodate bytesToReserve right now.
+ // NOTE(review): if bytesToReserve > maxBytesCanReserve the request is queued even when the
+ // handle holds nothing, and free() only completes a queued request once
+ // maxBytesCanReserve - reserved >= bytesToReserve, so such a future never completes unless it
+ // is cancelled. Consider validating bytesToReserve against maxBytesCanReserve.
+ if (maxBytes - reservedBytes < bytesToReserve
+ || maxBytesCanReserve
+ - queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(planNodeId, 0L)
+ < bytesToReserve) {
+ result =
+ MemoryReservationFuture.create(
+ queryId, fragmentInstanceId, planNodeId, bytesToReserve, maxBytesCanReserve);
memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
return new Pair<>(result, Boolean.FALSE);
} else {
- reservedBytes += bytes;
- queryMemoryReservations.merge(queryId, bytes, Long::sum);
+ // Record the reservation in the pool total and under the handle's nested key.
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations
+ .computeIfAbsent(queryId, x -> new HashMap<>())
+ .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+ .merge(planNodeId, bytesToReserve, Long::sum);
result = Futures.immediateFuture(null);
return new Pair<>(result, Boolean.TRUE);
}
}
}
- public boolean tryReserve(String queryId, long bytes) {
+ /**
+ * Reserves {@code bytesToReserve} for the given handle if both the pool and the handle's cap
+ * ({@code maxBytesCanReserve}) currently allow it. Unlike reserve, it never queues a pending
+ * request.
+ *
+ * @return true if the bytes were reserved, false if nothing was reserved
+ */
+ @TestOnly
+ public boolean tryReserve(
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ long bytesToReserve,
+ long maxBytesCanReserve) {
Validate.notNull(queryId);
+ Validate.notNull(fragmentInstanceId);
+ Validate.notNull(planNodeId);
Validate.isTrue(
- bytes > 0L && bytes <= maxBytesPerQuery,
- "bytes should be greater than zero while less than or equal to max bytes per query: %d",
- bytes);
-
- if (maxBytes - reservedBytes < bytes
- || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
+ bytesToReserve > 0L && bytesToReserve <= maxBytesPerFragmentInstance,
+ "bytes should be greater than zero while less than or equal to max bytes per fragment instance: %d",
+ bytesToReserve);
+
- return false;
- }
synchronized (this) {
+ // Read reservedBytes and the nested HashMaps only while holding the monitor: reserve() and
+ // free() write them under the same lock, so any unsynchronized pre-check would race with them.
- if (maxBytes - reservedBytes < bytes
- || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
+ if (maxBytes - reservedBytes < bytesToReserve
+ || maxBytesCanReserve
+ - queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(planNodeId, 0L)
+ < bytesToReserve) {
return false;
}
- reservedBytes += bytes;
- queryMemoryReservations.merge(queryId, bytes, Long::sum);
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations
+ .computeIfAbsent(queryId, x -> new HashMap<>())
+ .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+ .merge(planNodeId, bytesToReserve, Long::sum);
}
-
return true;
}
/**
* Cancel the specified memory reservation. If the reservation has finished, do nothing.
*
- * @param future The future returned from {@link #reserve(String, long)}
+ * @param future The future returned from {@link #reserve(String, String, String, long, long)}
* @return If the future has not complete, return the number of bytes being reserved. Otherwise,
* return 0.
*/
@@ -163,13 +241,13 @@ public class MemoryPool {
future instanceof MemoryReservationFuture,
"invalid future type " + future.getClass().getSimpleName());
future.cancel(true);
- return ((MemoryReservationFuture<Void>) future).getBytes();
+ return ((MemoryReservationFuture<Void>) future).getBytesToReserve();
}
/**
* Complete the specified memory reservation. If the reservation has finished, do nothing.
*
- * @param future The future returned from {@link #reserve(String, long)}
+ * @param future The future returned from {@link #reserve(String, String, String, long, long)}
* @return If the future has not complete, return the number of bytes being reserved. Otherwise,
* return 0.
*/
@@ -183,24 +261,31 @@ public class MemoryPool {
future instanceof MemoryReservationFuture,
"invalid future type " + future.getClass().getSimpleName());
((MemoryReservationFuture<Void>) future).set(null);
- return ((MemoryReservationFuture<Void>) future).getBytes();
+ return ((MemoryReservationFuture<Void>) future).getBytesToReserve();
}
- public void free(String queryId, long bytes) {
+ public void free(String queryId, String fragmentInstanceId, String planNodeId, long bytes) {
List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
synchronized (this) {
Validate.notNull(queryId);
Validate.isTrue(bytes > 0L);
- Long queryReservedBytes = queryMemoryReservations.get(queryId);
+ Long queryReservedBytes =
+ queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .get(planNodeId);
Validate.notNull(queryReservedBytes);
Validate.isTrue(bytes <= queryReservedBytes);
queryReservedBytes -= bytes;
if (queryReservedBytes == 0) {
- queryMemoryReservations.remove(queryId);
+ queryMemoryReservations.get(queryId).get(fragmentInstanceId).remove(planNodeId);
} else {
- queryMemoryReservations.put(queryId, queryReservedBytes);
+ queryMemoryReservations
+ .get(queryId)
+ .get(fragmentInstanceId)
+ .put(planNodeId, queryReservedBytes);
}
reservedBytes -= bytes;
@@ -213,14 +298,26 @@ public class MemoryPool {
if (future.isCancelled() || future.isDone()) {
continue;
}
- long bytesToReserve = future.getBytes();
+ long bytesToReserve = future.getBytesToReserve();
+ String curQueryId = future.getQueryId();
+ String curFragmentInstanceId = future.getFragmentInstanceId();
+ String curPlanNodeId = future.getPlanNodeId();
+ // check total reserved bytes in memory pool
if (maxBytes - reservedBytes < bytesToReserve) {
- return;
+ continue;
}
- if (maxBytesPerQuery - queryMemoryReservations.getOrDefault(future.getQueryId(), 0L)
+ // check total reserved bytes of one Sink/Source handle
+ if (future.getMaxBytesCanReserve()
+ - queryMemoryReservations
+ .getOrDefault(curQueryId, Collections.emptyMap())
+ .getOrDefault(curFragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(curPlanNodeId, 0L)
>= bytesToReserve) {
reservedBytes += bytesToReserve;
- queryMemoryReservations.merge(future.getQueryId(), bytesToReserve, Long::sum);
+ queryMemoryReservations
+ .computeIfAbsent(curQueryId, x -> new HashMap<>())
+ .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
+ .merge(curPlanNodeId, bytesToReserve, Long::sum);
futureList.add(future);
iterator.remove();
}
@@ -246,7 +343,14 @@ public class MemoryPool {
}
+ /**
+ * @return total bytes currently reserved in this pool by all handles of {@code queryId}, or 0 if
+ * the query holds no reservation
+ */
public long getQueryMemoryReservedBytes(String queryId) {
- return queryMemoryReservations.getOrDefault(queryId, 0L);
+ // Hold the pool monitor: the nested maps are plain HashMaps that reserve() and free() mutate
+ // under it, and iterating them concurrently may fail fast or observe inconsistent state.
+ synchronized (this) {
+ Map<String, Map<String, Long>> fragmentInstanceReservations =
+ queryMemoryReservations.get(queryId);
+ if (fragmentInstanceReservations == null) {
+ return 0L;
+ }
+ long sum = 0L;
+ for (Map<String, Long> planNodeReservations : fragmentInstanceReservations.values()) {
+ for (long bytes : planNodeReservations.values()) {
+ sum += bytes;
+ }
+ }
+ return sum;
+ }
}
public long getReservedBytes() {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
index 29b612aa7d..e2b60ef441 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/memory/MemorySourceHandle.java
@@ -107,4 +107,7 @@ public class MemorySourceHandle implements ISourceHandle {
@Override
public void close() {}
+
+ // NOTE(review): intentionally a no-op, presumably because this handle serves results already held
+ // in memory and does not reserve from the MemoryPool, so there is no per-handle cap to apply.
+ // Confirm.
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
index 6015e20c71..5b6d2abc86 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/LocalExecutionPlanner.java
@@ -26,6 +26,8 @@ import org.apache.iotdb.db.mpp.execution.driver.DataDriver;
import org.apache.iotdb.db.mpp.execution.driver.DataDriverContext;
import org.apache.iotdb.db.mpp.execution.driver.SchemaDriver;
import org.apache.iotdb.db.mpp.execution.driver.SchemaDriverContext;
+import org.apache.iotdb.db.mpp.execution.exchange.ISourceHandle;
+import org.apache.iotdb.db.mpp.execution.exchange.MPPDataExchangeService;
import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceContext;
import org.apache.iotdb.db.mpp.execution.fragment.FragmentInstanceStateMachine;
import org.apache.iotdb.db.mpp.execution.operator.Operator;
@@ -33,6 +35,7 @@ import org.apache.iotdb.db.mpp.execution.timer.ITimeSliceAllocator;
import org.apache.iotdb.db.mpp.plan.analyze.TypeProvider;
import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.utils.SetThreadName;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.tsfile.read.filter.basic.Filter;
@@ -71,6 +74,9 @@ public class LocalExecutionPlanner {
// check whether current free memory is enough to execute current query
checkMemory(root, instanceContext.getStateMachine());
+ // calculate memory distribution of ISinkHandle/ISourceHandle
+ setMemoryLimitForHandle(instanceContext.getId().toThrift(), plan);
+
ITimeSliceAllocator timeSliceAllocator = context.getTimeSliceAllocator();
instanceContext
.getOperatorContexts()
@@ -101,6 +107,9 @@ public class LocalExecutionPlanner {
Operator root = plan.accept(new OperatorTreeGenerator(), context);
+ // calculate memory distribution of ISinkHandle/ISourceHandle
+ setMemoryLimitForHandle(instanceContext.getId().toThrift(), plan);
+
// check whether current free memory is enough to execute current query
checkMemory(root, instanceContext.getStateMachine());
@@ -114,6 +123,27 @@ public class LocalExecutionPlanner {
return new SchemaDriver(root, context.getSinkHandle(), schemaDriverContext);
}
+ /**
+ * Divides the configured per-fragment-instance memory budget evenly among the Sink/Source handles
+ * of {@code fragmentInstanceId}, using the split count produced by
+ * {@link MemoryDistributionCalculator}. Leaves the handles untouched when the plan yields no
+ * splits.
+ */
+ private void setMemoryLimitForHandle(TFragmentInstanceId fragmentInstanceId, PlanNode plan) {
+ MemoryDistributionCalculator calculator = new MemoryDistributionCalculator();
+ plan.accept(calculator, null);
+ long splits = calculator.calculateTotalSplit();
+ if (splits == 0) {
+ return;
+ }
+ long perHandleLimit =
+ IoTDBDescriptor.getInstance().getConfig().getMaxBytesPerFragmentInstance() / splits;
+ MPPDataExchangeService exchangeService = MPPDataExchangeService.getInstance();
+ for (ISourceHandle sourceHandle :
+ exchangeService.getMPPDataExchangeManager().getISourceHandle(fragmentInstanceId)) {
+ sourceHandle.setMaxBytesCanReserve(perHandleLimit);
+ }
+ exchangeService
+ .getMPPDataExchangeManager()
+ .getISinkHandle(fragmentInstanceId)
+ .setMaxBytesCanReserve(perHandleLimit);
+ }
+
private void checkMemory(Operator root, FragmentInstanceStateMachine stateMachine)
throws MemoryNotEnoughException {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/MemoryDistributionCalculator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/MemoryDistributionCalculator.java
new file mode 100644
index 0000000000..fb117804a7
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/MemoryDistributionCalculator.java
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.plan.planner;
+
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanNodeId;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.PlanVisitor;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.CountSchemaMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesCountNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.DevicesSchemaScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.LevelTimeSeriesCountNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodeManagementMemoryMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodePathsConvertNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodePathsCountNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.NodePathsSchemaScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaFetchMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaFetchScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryOrderByHeatNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.SchemaQueryScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.TimeSeriesCountNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.read.TimeSeriesSchemaScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.write.ConstructSchemaBlackListNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.metedata.write.RollbackSchemaBlackListNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.AggregationNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewIntoNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.DeviceViewNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ExchangeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.FillNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.FilterNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByLevelNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.GroupByTagNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.IntoNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.LimitNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.MergeSortNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.OffsetNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.ProjectNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SingleDeviceViewNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SlidingWindowAggregationNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.SortNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TimeJoinNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.TransformNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.VerticallyConcatNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryCollectNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryMergeNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.process.last.LastQueryNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.sink.FragmentSinkNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedLastQueryScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.AlignedSeriesScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.LastQueryScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesAggregationScanNode;
+import org.apache.iotdb.db.mpp.plan.planner.plan.node.source.SeriesScanNode;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.iotdb.db.mpp.plan.constant.DataNodeEndPoints.isSameNode;
+
+/**
+ * Counts how many shares ("splits") the per-fragment-instance memory budget has to be divided into
+ * among the Sink/Source handles of one fragment instance.
+ *
+ * <p>Visit the fragment's root with a {@code null} context, then call {@link
+ * #calculateTotalSplit()}. A node that consumes all of its children at the same time gives each
+ * child ExchangeNode its own split. A node that consumes its children one by one lets all of its
+ * child ExchangeNodes share one split. A FragmentSinkNode whose downstream endpoint is not this
+ * node adds one split for its sink handle.
+ */
+public class MemoryDistributionCalculator
+ extends PlanVisitor<Void, MemoryDistributionCalculator.MemoryDistributionContext> {
+ /**
+ * Key: the PlanNodeId a split is attributed to (the consuming parent, or an id identifying the
+ * handle itself when it has no parent in this fragment). Value: the handles counted under that
+ * key; each element is one split.
+ */
+ private final Map<PlanNodeId, List<PlanNodeId>> exchangeMap;
+
+ public MemoryDistributionCalculator() {
+ this.exchangeMap = new HashMap<>();
+ }
+
+ /** @return the total number of splits collected so far (sum of all list sizes in the map) */
+ public long calculateTotalSplit() {
+ return exchangeMap.values().stream().mapToLong(List::size).sum();
+ }
+
+ @Override
+ public Void visitPlan(PlanNode node, MemoryDistributionContext context) {
+ // Throw exception here because we want to ensure that all new PlanNodes implement
+ // this method correctly if necessary.
+ throw new UnsupportedOperationException("Should call concrete visitXX method");
+ }
+
+ /** Visits every non-null child of {@code node} with a context naming node and its type. */
+ private void visitChildren(PlanNode node, MemoryDistributionType memoryDistributionType) {
+ MemoryDistributionContext context =
+ new MemoryDistributionContext(node.getPlanNodeId(), memoryDistributionType);
+ node.getChildren()
+ .forEach(
+ child -> {
+ if (child != null) {
+ child.accept(this, context);
+ }
+ });
+ }
+
+ private void processConsumeChildrenOneByOneNode(PlanNode node) {
+ visitChildren(node, MemoryDistributionType.CONSUME_CHILDREN_ONE_BY_ONE);
+ }
+
+ private void processConsumeAllChildrenAtTheSameTime(PlanNode node) {
+ visitChildren(node, MemoryDistributionType.CONSUME_ALL_CHILDREN_AT_THE_SAME_TIME);
+ }
+
+ @Override
+ public Void visitExchange(ExchangeNode node, MemoryDistributionContext context) {
+ // we do not distinguish LocalSourceHandle/SourceHandle by not letting LocalSinkHandle update
+ // the map
+ if (context == null) {
+ // context == null means this ExchangeNode has no father, so it takes a split of its own
+ exchangeMap
+ .computeIfAbsent(node.getPlanNodeId(), x -> new ArrayList<>())
+ .add(node.getPlanNodeId());
+ } else if (context.memoryDistributionType
+ == MemoryDistributionType.CONSUME_ALL_CHILDREN_AT_THE_SAME_TIME) {
+ // every ExchangeNode child of such a father gets its own split
+ exchangeMap
+ .computeIfAbsent(context.planNodeId, x -> new ArrayList<>())
+ .add(node.getPlanNodeId());
+ } else if (context.memoryDistributionType == MemoryDistributionType.CONSUME_CHILDREN_ONE_BY_ONE
+ && !exchangeMap.containsKey(context.planNodeId)) {
+ // All children share one split, thus only one node needs to be put into the map
+ exchangeMap
+ .computeIfAbsent(context.planNodeId, x -> new ArrayList<>())
+ .add(node.getPlanNodeId());
+ }
+ return null;
+ }
+
+ @Override
+ public Void visitFragmentSink(FragmentSinkNode node, MemoryDistributionContext context) {
+ // LocalSinkHandle and LocalSourceHandle are one-to-one mapped and only LocalSourceHandle do the
+ // update
+ if (!isSameNode(node.getDownStreamEndpoint())) {
+ exchangeMap
+ .computeIfAbsent(node.getDownStreamPlanNodeId(), x -> new ArrayList<>())
+ .add(node.getDownStreamPlanNodeId());
+ }
+ // The sink has the rest of this fragment's plan as its child; that subtree may contain
+ // ExchangeNodes (source handles of this fragment instance) which must be counted too.
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSeriesScan(SeriesScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitSeriesAggregationScan(
+ SeriesAggregationScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitAlignedSeriesScan(
+ AlignedSeriesScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitAlignedSeriesAggregationScan(
+ AlignedSeriesAggregationScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitDeviceView(DeviceViewNode node, MemoryDistributionContext context) {
+ // consume children one by one
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitDeviceMerge(DeviceMergeNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitFill(FillNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitFilter(FilterNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitGroupByLevel(GroupByLevelNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitGroupByTag(GroupByTagNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSlidingWindowAggregation(
+ SlidingWindowAggregationNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitLimit(LimitNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitOffset(OffsetNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitAggregation(AggregationNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSort(SortNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitProject(ProjectNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitTimeJoin(TimeJoinNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitTransform(TransformNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitLastQueryScan(LastQueryScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitAlignedLastQueryScan(
+ AlignedLastQueryScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitLastQuery(LastQueryNode node, MemoryDistributionContext context) {
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitLastQueryMerge(LastQueryMergeNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitLastQueryCollect(LastQueryCollectNode node, MemoryDistributionContext context) {
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitInto(IntoNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitDeviceViewInto(DeviceViewIntoNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitVerticallyConcat(VerticallyConcatNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSchemaQueryMerge(SchemaQueryMergeNode node, MemoryDistributionContext context) {
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSchemaQueryScan(SchemaQueryScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitSchemaQueryOrderByHeat(
+ SchemaQueryOrderByHeatNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitTimeSeriesSchemaScan(
+ TimeSeriesSchemaScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitDevicesSchemaScan(
+ DevicesSchemaScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitDevicesCount(DevicesCountNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitTimeSeriesCount(TimeSeriesCountNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitLevelTimeSeriesCount(
+ LevelTimeSeriesCountNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitCountMerge(CountSchemaMergeNode node, MemoryDistributionContext context) {
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSchemaFetchMerge(SchemaFetchMergeNode node, MemoryDistributionContext context) {
+ processConsumeChildrenOneByOneNode(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSchemaFetchScan(SchemaFetchScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitNodePathsSchemaScan(
+ NodePathsSchemaScanNode node, MemoryDistributionContext context) {
+ // do nothing since SourceNode will not have Exchange/FragmentSink as child
+ return null;
+ }
+
+ @Override
+ public Void visitNodeManagementMemoryMerge(
+ NodeManagementMemoryMergeNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitNodePathConvert(NodePathsConvertNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitNodePathsCount(NodePathsCountNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitConstructSchemaBlackList(
+ ConstructSchemaBlackListNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitRollbackSchemaBlackList(
+ RollbackSchemaBlackListNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitSingleDeviceView(SingleDeviceViewNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ @Override
+ public Void visitMergeSort(MergeSortNode node, MemoryDistributionContext context) {
+ processConsumeAllChildrenAtTheSameTime(node);
+ return null;
+ }
+
+ enum MemoryDistributionType {
+ /**
+ * This type means that this node needs data from all the children. For example, TimeJoinNode.
+ * If the type of the father node of an ExchangeNode is CONSUME_ALL_CHILDREN_AT_THE_SAME_TIME,
+ * the ExchangeNode needs one split of the memory.
+ */
+ CONSUME_ALL_CHILDREN_AT_THE_SAME_TIME(0),
+
+ /**
+ * This type means that this node consumes data of the children one by one. For example,
+ * DeviceViewNode. If the type of the father node of an ExchangeNode is
+ * CONSUME_CHILDREN_ONE_BY_ONE, all the ExchangeNodes of that father node share one split of the
+ * memory.
+ */
+ CONSUME_CHILDREN_ONE_BY_ONE(1);
+
+ private final int id;
+
+ MemoryDistributionType(int id) {
+ this.id = id;
+ }
+
+ public int getId() {
+ return id;
+ }
+ }
+
+ /** Describes the father of the node being visited: its id and how it consumes its children. */
+ static class MemoryDistributionContext {
+ final PlanNodeId planNodeId;
+ final MemoryDistributionType memoryDistributionType;
+
+ MemoryDistributionContext(
+ PlanNodeId planNodeId, MemoryDistributionType memoryDistributionType) {
+ this.planNodeId = planNodeId;
+ this.memoryDistributionType = memoryDistributionType;
+ }
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
index c61cb93e26..123df8a63d 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/planner/OperatorTreeGenerator.java
@@ -1703,6 +1703,7 @@ public class OperatorTreeGenerator extends PlanVisitor<Operator, LocalExecutionP
downStreamEndPoint,
targetInstanceId.toThrift(),
node.getDownStreamPlanNodeId().getId(),
+ node.getPlanNodeId().getId(),
context.getInstanceContext());
context.setSinkHandle(sinkHandle);
return child;
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java
index 1c3cb4e659..37055a6a24 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSinkHandleTest.java
@@ -48,7 +48,18 @@ public class LocalSinkHandleTest {
Mockito.mock(MPPDataExchangeManager.SinkHandleListener.class);
// Construct a shared TsBlock queue.
SharedTsBlockQueue queue =
- new SharedTsBlockQueue(remoteFragmentInstanceId, mockLocalMemoryManager);
+ new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, mockLocalMemoryManager);
+
+ queue.setMaxBytesCanReserve(Long.MAX_VALUE);
+
+ // Construct SourceHandle
+ LocalSourceHandle localSourceHandle =
+ new LocalSourceHandle(
+ localFragmentInstanceId,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ queue,
+ Mockito.mock(MPPDataExchangeManager.SourceHandleListener.class));
// Construct SinkHandle.
LocalSinkHandle localSinkHandle =
@@ -58,6 +69,10 @@ public class LocalSinkHandleTest {
localFragmentInstanceId,
queue,
mockSinkHandleListener);
+ Assert.assertFalse(localSinkHandle.isFull().isDone());
+ localSourceHandle.isBlocked();
+ // blocked of LocalSinkHandle should be completed after calling isBlocked() of corresponding
+ // LocalSourceHandle
Assert.assertTrue(localSinkHandle.isFull().isDone());
Assert.assertFalse(localSinkHandle.isFinished());
Assert.assertFalse(localSinkHandle.isAborted());
@@ -69,11 +84,17 @@ public class LocalSinkHandleTest {
localSinkHandle.send(Utils.createMockTsBlock(mockTsBlockSize));
numOfSentTsblocks += 1;
}
- Assert.assertEquals(6, numOfSentTsblocks);
+ Assert.assertEquals(11, numOfSentTsblocks);
Assert.assertFalse(localSinkHandle.isFull().isDone());
Assert.assertFalse(localSinkHandle.isFinished());
- Assert.assertEquals(6 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
- Mockito.verify(spyMemoryPool, Mockito.times(6)).reserve(queryId, mockTsBlockSize);
+ Assert.assertEquals(11 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
+ Mockito.verify(spyMemoryPool, Mockito.times(11))
+ .reserve(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ remotePlanNodeId,
+ mockTsBlockSize,
+ Long.MAX_VALUE);
// Receive TsBlocks.
int numOfReceivedTsblocks = 0;
@@ -81,11 +102,12 @@ public class LocalSinkHandleTest {
queue.remove();
numOfReceivedTsblocks += 1;
}
- Assert.assertEquals(6, numOfReceivedTsblocks);
+ Assert.assertEquals(11, numOfReceivedTsblocks);
Assert.assertTrue(localSinkHandle.isFull().isDone());
Assert.assertFalse(localSinkHandle.isFinished());
Assert.assertEquals(0L, localSinkHandle.getBufferRetainedSizeInBytes());
- Mockito.verify(spyMemoryPool, Mockito.times(6)).free(queryId, mockTsBlockSize);
+ Mockito.verify(spyMemoryPool, Mockito.times(11))
+ .free(queryId, localFragmentInstanceId.getInstanceId(), remotePlanNodeId, mockTsBlockSize);
// Set no-more-TsBlocks.
localSinkHandle.setNoMoreTsBlocks();
@@ -113,7 +135,18 @@ public class LocalSinkHandleTest {
Mockito.mock(MPPDataExchangeManager.SinkHandleListener.class);
// Construct a shared tsblock queue.
SharedTsBlockQueue queue =
- new SharedTsBlockQueue(remoteFragmentInstanceId, mockLocalMemoryManager);
+ new SharedTsBlockQueue(remoteFragmentInstanceId, remotePlanNodeId, mockLocalMemoryManager);
+
+ queue.setMaxBytesCanReserve(Long.MAX_VALUE);
+
+ // Construct SourceHandle
+ LocalSourceHandle localSourceHandle =
+ new LocalSourceHandle(
+ localFragmentInstanceId,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ queue,
+ Mockito.mock(MPPDataExchangeManager.SourceHandleListener.class));
// Construct SinkHandle.
LocalSinkHandle localSinkHandle =
@@ -123,6 +156,10 @@ public class LocalSinkHandleTest {
localFragmentInstanceId,
queue,
mockSinkHandleListener);
+ Assert.assertFalse(localSinkHandle.isFull().isDone());
+ localSourceHandle.isBlocked();
+ // blocked of LocalSinkHandle should be completed after calling isBlocked() of corresponding
+ // LocalSourceHandle
Assert.assertTrue(localSinkHandle.isFull().isDone());
Assert.assertFalse(localSinkHandle.isFinished());
Assert.assertFalse(localSinkHandle.isAborted());
@@ -134,12 +171,18 @@ public class LocalSinkHandleTest {
localSinkHandle.send(Utils.createMockTsBlock(mockTsBlockSize));
numOfSentTsblocks += 1;
}
- Assert.assertEquals(6, numOfSentTsblocks);
+ Assert.assertEquals(11, numOfSentTsblocks);
ListenableFuture<?> blocked = localSinkHandle.isFull();
Assert.assertFalse(blocked.isDone());
Assert.assertFalse(localSinkHandle.isFinished());
- Assert.assertEquals(6 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
- Mockito.verify(spyMemoryPool, Mockito.times(6)).reserve(queryId, mockTsBlockSize);
+ Assert.assertEquals(11 * mockTsBlockSize, localSinkHandle.getBufferRetainedSizeInBytes());
+ Mockito.verify(spyMemoryPool, Mockito.times(11))
+ .reserve(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ remotePlanNodeId,
+ mockTsBlockSize,
+ Long.MAX_VALUE);
// Abort.
localSinkHandle.abort();
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
index e1f49f5c77..3c9ee13e39 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/LocalSourceHandleTest.java
@@ -46,7 +46,7 @@ public class LocalSourceHandleTest {
SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
// Construct a shared TsBlock queue.
SharedTsBlockQueue queue =
- new SharedTsBlockQueue(localFragmentInstanceId, mockLocalMemoryManager);
+ new SharedTsBlockQueue(localFragmentInstanceId, localPlanNodeId, mockLocalMemoryManager);
LocalSourceHandle localSourceHandle =
new LocalSourceHandle(
@@ -94,7 +94,7 @@ public class LocalSourceHandleTest {
SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
// Construct a shared tsblock queue.
SharedTsBlockQueue queue =
- new SharedTsBlockQueue(localFragmentInstanceId, mockLocalMemoryManager);
+ new SharedTsBlockQueue(localFragmentInstanceId, localPlanNodeId, mockLocalMemoryManager);
LocalSourceHandle localSourceHandle =
new LocalSourceHandle(
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueueTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueueTest.java
index 971e84fbdb..e5336d2ac2 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueueTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueueTest.java
@@ -45,7 +45,10 @@ public class SharedTsBlockQueueTest {
Mockito.spy(new MemoryPool("test", 10 * mockTsBlockSize, 5 * mockTsBlockSize));
Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(spyMemoryPool);
SharedTsBlockQueue queue =
- new SharedTsBlockQueue(new TFragmentInstanceId(queryId, 0, "0"), mockLocalMemoryManager);
+ new SharedTsBlockQueue(
+ new TFragmentInstanceId(queryId, 0, "0"), "test", mockLocalMemoryManager);
+ queue.getCanAddTsBlock().set(null);
+ queue.setMaxBytesCanReserve(Long.MAX_VALUE);
ExecutorService executor = Executors.newFixedThreadPool(2);
AtomicReference<Integer> numOfTimesSenderBlocked = new AtomicReference<>(0);
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java
index 2119c2ffcf..029eac1ccf 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandleTest.java
@@ -54,6 +54,7 @@ public class SinkHandleTest {
new TEndPoint("remote", IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort());
final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId(queryId, 0, "0");
final String remotePlanNodeId = "exchange_0";
+ final String localPlanNodeId = "fragmentSink_0";
final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, 1, "0");
// Construct a mock LocalMemoryManager that returns unblocked futures.
@@ -88,6 +89,7 @@ public class SinkHandleTest {
remoteEndpoint,
remoteFragmentInstanceId,
remotePlanNodeId,
+ localPlanNodeId,
localFragmentInstanceId,
mockLocalMemoryManager,
Executors.newSingleThreadExecutor(),
@@ -110,8 +112,13 @@ public class SinkHandleTest {
mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
sinkHandle.getBufferRetainedSizeInBytes());
Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
- Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(2))
- .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ // Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
+ // .reserve(
+ // queryId,
+ // localFragmentInstanceId.getInstanceId(),
+ // localPlanNodeId,
+ // mockTsBlockSize * numOfMockTsBlock,
+ // Long.MAX_VALUE);
try {
Mockito.verify(mockClient, Mockito.timeout(10_000).times(1))
.onNewDataBlockEvent(
@@ -154,7 +161,11 @@ public class SinkHandleTest {
Assert.assertFalse(sinkHandle.isAborted());
Assert.assertEquals(mockTsBlockSize, sinkHandle.getBufferRetainedSizeInBytes());
Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
- .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+ .free(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ numOfMockTsBlock * mockTsBlockSize);
Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkHandle);
try {
@@ -181,12 +192,18 @@ public class SinkHandleTest {
new TEndPoint("remote", IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort());
final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId(queryId, 0, "0");
final String remotePlanNodeId = "exchange_0";
+ final String localPlanNodeId = "fragmentSink_0";
final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, 1, "0");
// Construct a mock LocalMemoryManager that returns blocked futures.
LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
MemoryPool mockMemoryPool =
- Utils.createMockBlockedMemoryPool(queryId, numOfMockTsBlock, mockTsBlockSize);
+ Utils.createMockBlockedMemoryPool(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ numOfMockTsBlock,
+ mockTsBlockSize);
Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
// Construct a mock SinkHandleListener.
@@ -217,6 +234,7 @@ public class SinkHandleTest {
remoteEndpoint,
remoteFragmentInstanceId,
remotePlanNodeId,
+ localPlanNodeId,
localFragmentInstanceId,
mockLocalMemoryManager,
Executors.newSingleThreadExecutor(),
@@ -239,8 +257,13 @@ public class SinkHandleTest {
mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
sinkHandle.getBufferRetainedSizeInBytes());
Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
- Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(2))
- .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ // Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
+ // .reserve(
+ // queryId,
+ // localFragmentInstanceId.getInstanceId(),
+ // localPlanNodeId,
+ // mockTsBlockSize * numOfMockTsBlock,
+ // Long.MAX_VALUE);
try {
Mockito.verify(mockClient, Mockito.timeout(10_000).times(1))
.onNewDataBlockEvent(
@@ -276,7 +299,11 @@ public class SinkHandleTest {
Assert.assertEquals(
DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
- .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+ .free(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ numOfMockTsBlock * mockTsBlockSize);
// Send tsblocks.
sinkHandle.send(mockTsBlocks.get(0));
@@ -287,8 +314,13 @@ public class SinkHandleTest {
mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
sinkHandle.getBufferRetainedSizeInBytes());
Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
- Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(3))
- .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ // Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(3))
+ // .reserve(
+ // queryId,
+ // localFragmentInstanceId.getInstanceId(),
+ // localPlanNodeId,
+ // mockTsBlockSize * numOfMockTsBlock,
+ // Long.MAX_VALUE);
try {
Mockito.verify(mockClient, Mockito.timeout(10_000).times(1))
.onNewDataBlockEvent(
@@ -343,7 +375,11 @@ public class SinkHandleTest {
Assert.assertEquals(
DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES, sinkHandle.getBufferRetainedSizeInBytes());
Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(2))
- .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+ .free(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ numOfMockTsBlock * mockTsBlockSize);
Mockito.verify(mockSinkHandleListener, Mockito.timeout(10_0000).times(1)).onFinish(sinkHandle);
}
@@ -356,12 +392,18 @@ public class SinkHandleTest {
new TEndPoint("remote", IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort());
final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId(queryId, 0, "0");
final String remotePlanNodeId = "exchange_0";
+ final String localPlanNodeId = "fragmentSink_0";
final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, 1, "0");
// Construct a mock LocalMemoryManager that returns blocked futures.
LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
MemoryPool mockMemoryPool =
- Utils.createMockBlockedMemoryPool(queryId, numOfMockTsBlock, mockTsBlockSize);
+ Utils.createMockBlockedMemoryPool(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ numOfMockTsBlock,
+ mockTsBlockSize);
Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
// Construct a mock SinkHandleListener.
SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
@@ -392,6 +434,7 @@ public class SinkHandleTest {
remoteEndpoint,
remoteFragmentInstanceId,
remotePlanNodeId,
+ localPlanNodeId,
localFragmentInstanceId,
mockLocalMemoryManager,
Executors.newSingleThreadExecutor(),
@@ -415,8 +458,13 @@ public class SinkHandleTest {
mockTsBlockSize * numOfMockTsBlock + DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES,
sinkHandle.getBufferRetainedSizeInBytes());
Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
- Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(2))
- .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ // Mockito.verify(mockMemoryPool, Mockito.timeout(10_0000).times(1))
+ // .reserve(
+ // queryId,
+ // localFragmentInstanceId.getInstanceId(),
+ // localPlanNodeId,
+ // mockTsBlockSize * numOfMockTsBlock,
+ // Long.MAX_VALUE);
try {
Mockito.verify(mockClient, Mockito.timeout(10_000).times(SinkHandle.MAX_ATTEMPT_TIMES))
.onNewDataBlockEvent(
@@ -456,6 +504,7 @@ public class SinkHandleTest {
new TEndPoint("remote", IoTDBDescriptor.getInstance().getConfig().getMppDataExchangePort());
final TFragmentInstanceId remoteFragmentInstanceId = new TFragmentInstanceId(queryId, 0, "0");
final String remotePlanNodeId = "exchange_0";
+ final String localPlanNodeId = "fragmentSink_0";
final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, 1, "0");
// Construct a mock LocalMemoryManager that returns blocked futures.
@@ -494,12 +543,14 @@ public class SinkHandleTest {
remoteEndpoint,
remoteFragmentInstanceId,
remotePlanNodeId,
+ localPlanNodeId,
localFragmentInstanceId,
mockLocalMemoryManager,
Executors.newSingleThreadExecutor(),
Utils.createMockTsBlockSerde(mockTsBlockSize),
mockSinkHandleListener,
mockClientManager);
+ sinkHandle.setMaxBytesCanReserve(Long.MAX_VALUE);
Assert.assertTrue(sinkHandle.isFull().isDone());
Assert.assertFalse(sinkHandle.isFinished());
Assert.assertFalse(sinkHandle.isAborted());
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
index 1fd70c4d78..7c345b5d9f 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
@@ -217,6 +217,7 @@ public class SourceHandleTest {
mockTsBlockSerde,
mockSourceHandleListener,
mockClientManager);
+ sourceHandle.setMaxBytesCanReserve(5 * mockTsBlockSize);
Assert.assertFalse(sourceHandle.isBlocked().isDone());
Assert.assertFalse(sourceHandle.isAborted());
Assert.assertFalse(sourceHandle.isFinished());
@@ -230,7 +231,12 @@ public class SourceHandleTest {
.collect(Collectors.toList()));
try {
Mockito.verify(spyMemoryPool, Mockito.timeout(10_000).times(6))
- .reserve(queryId, mockTsBlockSize);
+ .reserve(
+ queryId,
+ localFragmentInstanceId.getInstanceId(),
+ localPlanNodeId,
+ mockTsBlockSize,
+ 5 * mockTsBlockSize);
Mockito.verify(mockClient, Mockito.timeout(10_0000).times(1))
.getDataBlock(
Mockito.argThat(
@@ -257,7 +263,7 @@ public class SourceHandleTest {
// The local fragment instance consumes the data blocks.
for (int i = 0; i < numOfMockTsBlock; i++) {
Mockito.verify(spyMemoryPool, Mockito.timeout(10_0000).times(i))
- .free(queryId, mockTsBlockSize);
+ .free(queryId, localFragmentInstanceId.getInstanceId(), localPlanNodeId, mockTsBlockSize);
sourceHandle.receive();
try {
if (i < 5) {
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java
index 2930abe48b..0e7b5ffb96 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/StubSinkHandle.java
@@ -99,6 +99,9 @@ public class StubSinkHandle implements ISinkHandle {
tsBlocks.clear();
}
+ @Override
+ public void setMaxBytesCanReserve(long maxBytesCanReserve) {}
+
public List<TsBlock> getTsBlocks() {
return tsBlocks;
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
index a848f582ee..0ea87d58fd 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
@@ -54,19 +54,29 @@ public class Utils {
}
public static MemoryPool createMockBlockedMemoryPool(
- String queryId, int numOfMockTsBlock, long mockTsBlockSize) {
+ String queryId,
+ String fragmentInstanceId,
+ String planNodeId,
+ int numOfMockTsBlock,
+ long mockTsBlockSize) {
long capacityInBytes = numOfMockTsBlock * mockTsBlockSize;
MemoryPool mockMemoryPool = Mockito.mock(MemoryPool.class);
AtomicReference<SettableFuture<Void>> settableFuture = new AtomicReference<>();
settableFuture.set(SettableFuture.create());
settableFuture.get().set(null);
AtomicReference<Long> reservedBytes = new AtomicReference<>(0L);
- Mockito.when(mockMemoryPool.reserve(Mockito.eq(queryId), Mockito.anyLong()))
+ Mockito.when(
+ mockMemoryPool.reserve(
+ Mockito.eq(queryId),
+ Mockito.eq(fragmentInstanceId),
+ Mockito.eq(planNodeId),
+ Mockito.anyLong(),
+ Mockito.anyLong()))
.thenAnswer(
invocation -> {
- long bytesToReserve = invocation.getArgument(1);
+ long bytesToReserve = invocation.getArgument(3);
if (reservedBytes.get() + bytesToReserve <= capacityInBytes) {
- reservedBytes.updateAndGet(v -> v + (long) invocation.getArgument(1));
+ reservedBytes.updateAndGet(v -> v + (long) invocation.getArgument(3));
return new Pair<>(settableFuture.get(), true);
} else {
settableFuture.set(SettableFuture.create());
@@ -76,7 +86,7 @@ public class Utils {
Mockito.doAnswer(
(Answer<Void>)
invocation -> {
- reservedBytes.updateAndGet(v -> v - (long) invocation.getArgument(1));
+ reservedBytes.updateAndGet(v -> v - (long) invocation.getArgument(3));
if (reservedBytes.get() <= 0) {
settableFuture.get().set(null);
reservedBytes.updateAndGet(v -> v + mockTsBlockSize);
@@ -84,11 +94,21 @@ public class Utils {
return null;
})
.when(mockMemoryPool)
- .free(Mockito.eq(queryId), Mockito.anyLong());
- Mockito.when(mockMemoryPool.tryReserve(Mockito.eq(queryId), Mockito.anyLong()))
+ .free(
+ Mockito.eq(queryId),
+ Mockito.eq(fragmentInstanceId),
+ Mockito.eq(planNodeId),
+ Mockito.anyLong());
+ Mockito.when(
+ mockMemoryPool.tryReserve(
+ Mockito.eq(queryId),
+ Mockito.eq(fragmentInstanceId),
+ Mockito.eq(planNodeId),
+ Mockito.anyLong(),
+ Mockito.anyLong()))
.thenAnswer(
invocation -> {
- long bytesToReserve = invocation.getArgument(1);
+ long bytesToReserve = invocation.getArgument(3);
if (reservedBytes.get() + bytesToReserve > capacityInBytes) {
return false;
} else {
@@ -101,9 +121,21 @@ public class Utils {
public static MemoryPool createMockNonBlockedMemoryPool() {
MemoryPool mockMemoryPool = Mockito.mock(MemoryPool.class);
- Mockito.when(mockMemoryPool.reserve(Mockito.anyString(), Mockito.anyLong()))
+ Mockito.when(
+ mockMemoryPool.reserve(
+ Mockito.anyString(),
+ Mockito.anyString(),
+ Mockito.anyString(),
+ Mockito.anyLong(),
+ Mockito.anyLong()))
.thenReturn(new Pair<>(immediateFuture(null), true));
- Mockito.when(mockMemoryPool.tryReserve(Mockito.anyString(), Mockito.anyLong()))
+ Mockito.when(
+ mockMemoryPool.tryReserve(
+ Mockito.anyString(),
+ Mockito.anyString(),
+ Mockito.anyString(),
+ Mockito.anyLong(),
+ Mockito.anyLong()))
.thenReturn(true);
return mockMemoryPool;
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index 3d77d651b9..94a7c81999 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -28,6 +28,11 @@ public class MemoryPoolTest {
MemoryPool pool;
+ private final String QUERY_ID = "q0";
+
+ private final String FRAGMENT_INSTANCE_ID = "f0";
+ private final String PLAN_NODE_ID = "p0";
+
@Before
public void before() {
pool = new MemoryPool("test", 1024L, 512L);
@@ -35,17 +40,18 @@ public class MemoryPoolTest {
@Test
public void testTryReserve() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 256L));
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@Test
public void testTryReserveZero() {
- String queryId = "q0";
+
try {
- pool.tryReserve(queryId, 0L);
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -53,9 +59,9 @@ public class MemoryPoolTest {
@Test
public void testTryReserveNegative() {
- String queryId = "q0";
+
try {
- pool.tryReserve(queryId, -1L);
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -63,37 +69,39 @@ public class MemoryPoolTest {
@Test
public void testTryReserveAll() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
}
@Test
public void testOverTryReserve() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 256L));
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
- Assert.assertFalse(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertFalse(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 511L));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@Test
public void testReserve() {
- String queryId = "q0";
- ListenableFuture<Void> future = pool.reserve(queryId, 256L).left;
+
+ ListenableFuture<Void> future =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE).left;
Assert.assertTrue(future.isDone());
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@Test
public void tesReserveZero() {
- String queryId = "q0";
+
try {
- pool.reserve(queryId, 0L);
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -101,9 +109,9 @@ public class MemoryPoolTest {
@Test
public void testReserveNegative() {
- String queryId = "q0";
+
try {
- pool.reserve(queryId, -1L);
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -111,110 +119,125 @@ public class MemoryPoolTest {
@Test
public void testReserveAll() {
- String queryId = "q0";
- ListenableFuture<Void> future = pool.reserve(queryId, 512L).left;
+
+ ListenableFuture<Void> future =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE).left;
Assert.assertTrue(future.isDone());
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
}
@Test
public void testOverReserve() {
- String queryId = "q0";
- ListenableFuture<Void> future = pool.reserve(queryId, 256L).left;
+
+ ListenableFuture<Void> future =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE).left;
Assert.assertTrue(future.isDone());
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
- future = pool.reserve(queryId, 512L).left;
+ future = pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 513L).left;
Assert.assertFalse(future.isDone());
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@Test
public void testReserveAndFree() {
- String queryId = "q0";
- Assert.assertTrue(pool.reserve(queryId, 512L).left.isDone());
- ListenableFuture<Void> future = pool.reserve(queryId, 512L).left;
+
+ Assert.assertTrue(
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE)
+ .left
+ .isDone());
+ ListenableFuture<Void> future =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 513L).left;
Assert.assertFalse(future.isDone());
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 512);
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512);
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
Assert.assertTrue(future.isDone());
}
@Test
public void testMultiReserveAndFree() {
- String queryId = "q0";
- Assert.assertTrue(pool.reserve(queryId, 256L).left.isDone());
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE)
+ .left
+ .isDone());
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
- ListenableFuture<Void> future1 = pool.reserve(queryId, 512L).left;
- ListenableFuture<Void> future2 = pool.reserve(queryId, 512L).left;
- ListenableFuture<Void> future3 = pool.reserve(queryId, 512L).left;
+ ListenableFuture<Void> future1 =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 513L).left;
+ ListenableFuture<Void> future2 =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 513L).left;
+ ListenableFuture<Void> future3 =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 513L).left;
Assert.assertFalse(future1.isDone());
Assert.assertFalse(future2.isDone());
Assert.assertFalse(future3.isDone());
- pool.free(queryId, 256L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L);
Assert.assertTrue(future1.isDone());
Assert.assertFalse(future2.isDone());
Assert.assertFalse(future3.isDone());
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 512L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L);
Assert.assertTrue(future2.isDone());
Assert.assertFalse(future3.isDone());
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 512L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L);
Assert.assertTrue(future3.isDone());
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 512L);
- Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(queryId));
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L);
+ Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(0L, pool.getReservedBytes());
}
@Test
public void testFree() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 256L);
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L);
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@Test
public void testFreeAll() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
- pool.free(queryId, 512L);
- Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(queryId));
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L);
+ Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(0L, pool.getReservedBytes());
}
@Test
public void testFreeZero() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
try {
- pool.free(queryId, 0L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -222,13 +245,14 @@ public class MemoryPoolTest {
@Test
public void testFreeNegative() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
try {
- pool.free(queryId, -1L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -236,13 +260,14 @@ public class MemoryPoolTest {
@Test
public void testOverFree() {
- String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
- Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
try {
- pool.free(queryId, 513L);
+ pool.free(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 513L);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -250,11 +275,13 @@ public class MemoryPoolTest {
@Test
public void testTryCancelBlockedReservation() {
- String queryId = "q0";
+
// Run out of memory.
- Assert.assertTrue(pool.tryReserve(queryId, 512L));
+ Assert.assertTrue(
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, Long.MAX_VALUE));
- ListenableFuture<Void> f = pool.reserve(queryId, 256L).left;
+ ListenableFuture<Void> f =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 512L).left;
Assert.assertFalse(f.isDone());
// Cancel the reservation.
Assert.assertEquals(256L, pool.tryCancel(f));
@@ -264,8 +291,9 @@ public class MemoryPoolTest {
@Test
public void testTryCancelCompletedReservation() {
- String queryId = "q0";
- ListenableFuture<Void> f = pool.reserve(queryId, 256L).left;
+
+ ListenableFuture<Void> f =
+ pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, Long.MAX_VALUE).left;
Assert.assertTrue(f.isDone());
// Cancel the reservation.
Assert.assertEquals(0L, pool.tryCancel(f));