You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by xi...@apache.org on 2022/04/07 10:36:35 UTC
[iotdb] branch master updated: [IOTDB-2727] data block manager impl (#5367)
This is an automated email from the ASF dual-hosted git repository.
xingtanzjr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 3174c501cd [IOTDB-2727] data block manager impl (#5367)
3174c501cd is described below
commit 3174c501cdfa7f40e4b3a223641b74e2fe02fa80
Author: Zhong Wang <wa...@alibaba-inc.com>
AuthorDate: Thu Apr 7 18:36:29 2022 +0800
[IOTDB-2727] data block manager impl (#5367)
* [IOTDB-2727] data block manager impl
* Refactor
1. remove internal methods from interfaces
2. use impl independent exception
3. add logs for important methods
* Make SourceHandleListener and SinkHandleListener public
* Make TsBlockSerde#serialized and TsBlockSerde#deserialize public
* Fix SinkHandle#close
* Fix sourceHandle#isFinished
* Optimize UT(s)
* Optimize exception handling
* Fix memory reservations being associated with the wrong query
* Fix DataBlockServiceImpl#getDataBlock
Co-authored-by: Jackie Tien <ja...@gmail.com>
* Optimizations & fixes.
1. Stop checking blocked#isDone when creating a blocked listenable future in SourceHandle#receive
2. Avoid getting the SinkHandle multiple times from the map in DataBlockServiceImpl#getDataBlock
3. Style issues.
4. Throw an exception when the target SourceHandle of an event doesn't exist.
* Fix style issue
* Optimizations & fixes
1. Use LinkedHashMap as the container of tsblock in SinkHandle
for better performance
2. Add AcknowledgeDataBlockEvent to handle retry of
GetDataBlockRequest
* Resolve comments
1. Add configurations for memory manager and data block manager.
2. Add TODO(s)
* Resolve comments
1. SourceHandle#trySubmitGetDataBlocksTask triggers another call of trySubmitGetDataBlocksTask
* Rebase
* Rebase
* Fix UT
* Resolve comments
1. Fix MemoryPool#free incorrectly stops invoking queued listeners
2. Replace the word operator with plan node
3. Add shuffle configurations
* Fix style issue
* Fix UT
Co-authored-by: Jackie Tien <ja...@gmail.com>
---
.../resources/conf/iotdb-engine.properties | 15 +
.../java/org/apache/iotdb/db/conf/IoTDBConfig.java | 44 ++
.../org/apache/iotdb/db/conf/IoTDBDescriptor.java | 25 +
.../iotdb/db/mpp/buffer/DataBlockManager.java | 352 +++++++++---
.../db/mpp/buffer/DataBlockManagerService.java | 90 ----
.../iotdb/db/mpp/buffer/DataBlockService.java | 137 +++++
.../mpp/buffer/DataBlockServiceClientFactory.java | 24 +-
.../iotdb/db/mpp/buffer/DataBlockServiceImpl.java | 50 --
...ler.java => DataBlockServiceThriftHandler.java} | 2 +-
.../iotdb/db/mpp/buffer/IDataBlockManager.java | 20 +-
.../apache/iotdb/db/mpp/buffer/ISinkHandle.java | 35 +-
.../apache/iotdb/db/mpp/buffer/ISourceHandle.java | 21 +-
.../org/apache/iotdb/db/mpp/buffer/SinkHandle.java | 365 +++++++++++++
.../apache/iotdb/db/mpp/buffer/SourceHandle.java | 371 ++++++++++++-
.../apache/iotdb/db/mpp/buffer/StubSinkHandle.java | 29 +-
.../{ISourceHandle.java => TsBlockSerde.java} | 29 +-
...ISourceHandle.java => TsBlockSerdeFactory.java} | 26 +-
.../apache/iotdb/db/mpp/execution/DataDriver.java | 3 +-
.../iotdb/db/mpp/execution/SchemaDriver.java | 3 +-
.../iotdb/db/mpp/memory/LocalMemoryManager.java | 17 +-
.../org/apache/iotdb/db/mpp/memory/MemoryPool.java | 108 +++-
.../apache/iotdb/db/mpp/buffer/SinkHandleTest.java | 448 ++++++++++++++++
.../iotdb/db/mpp/buffer/SourceHandleTest.java | 591 +++++++++++++++++++++
.../java/org/apache/iotdb/db/mpp/buffer/Utils.java | 105 ++++
.../apache/iotdb/db/mpp/memory/MemoryPoolTest.java | 140 ++++-
thrift/src/main/thrift/mpp.thrift | 27 +-
26 files changed, 2729 insertions(+), 348 deletions(-)
diff --git a/server/src/assembly/resources/conf/iotdb-engine.properties b/server/src/assembly/resources/conf/iotdb-engine.properties
index d6598a5d14..a5b7527b2e 100644
--- a/server/src/assembly/resources/conf/iotdb-engine.properties
+++ b/server/src/assembly/resources/conf/iotdb-engine.properties
@@ -932,3 +932,18 @@ timestamp_precision=ms
####################
# Datatype: float
# group_by_fill_cache_size_in_mb=1.0
+
+####################
+### Shuffle Configuration
+####################
+# Datatype: int
+# data_block_manager_port=7777
+
+# Datatype: int
+# data_block_manager_core_pool_size=1
+
+# Datatype: int
+# data_block_manager_max_pool_size=5
+
+# Datatype: int
+# data_block_manager_keep_alive_time_in_ms=1000
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index 580d8f7552..ffb2f2fc7e 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -826,6 +826,18 @@ public class IoTDBConfig {
/** The max time of data node waiting to join into the cluster */
private long joinClusterTimeOutMs = TimeUnit.SECONDS.toMillis(60);
+ /** Port that the data block manager thrift service listens on. */
+ private int dataBlockManagerPort = 7777;
+
+ /** Core pool size of data block manager. */
+ private int dataBlockManagerCorePoolSize = 1;
+
+ /** Max pool size of data block manager. */
+ private int dataBlockManagerMaxPoolSize = 5;
+
+ /** Thread keep alive time in ms of data block manager. */
+ private int dataBlockManagerKeepAliveTimeInMs = 1000;
+
public IoTDBConfig() {
try {
internalIp = InetAddress.getLocalHost().getHostAddress();
@@ -2586,4 +2598,36 @@ public class IoTDBConfig {
public void setJoinClusterTimeOutMs(long joinClusterTimeOutMs) {
this.joinClusterTimeOutMs = joinClusterTimeOutMs;
}
+
+ /** Returns the port that the data block manager thrift service listens on. */
+ public int getDataBlockManagerPort() {
+ return dataBlockManagerPort;
+ }
+
+ /** Sets the port that the data block manager thrift service listens on. */
+ public void setDataBlockManagerPort(int dataBlockManagerPort) {
+ this.dataBlockManagerPort = dataBlockManagerPort;
+ }
+
+ /** Returns the core pool size of the data block manager. */
+ public int getDataBlockManagerCorePoolSize() {
+ return dataBlockManagerCorePoolSize;
+ }
+
+ /** Sets the core pool size of the data block manager. */
+ public void setDataBlockManagerCorePoolSize(int dataBlockManagerCorePoolSize) {
+ this.dataBlockManagerCorePoolSize = dataBlockManagerCorePoolSize;
+ }
+
+ /** Returns the max pool size of the data block manager. */
+ public int getDataBlockManagerMaxPoolSize() {
+ return dataBlockManagerMaxPoolSize;
+ }
+
+ /** Sets the max pool size of the data block manager. */
+ public void setDataBlockManagerMaxPoolSize(int dataBlockManagerMaxPoolSize) {
+ this.dataBlockManagerMaxPoolSize = dataBlockManagerMaxPoolSize;
+ }
+
+ /** Returns the thread keep-alive time, in milliseconds, of the data block manager. */
+ public int getDataBlockManagerKeepAliveTimeInMs() {
+ return dataBlockManagerKeepAliveTimeInMs;
+ }
+
+ /** Sets the thread keep-alive time, in milliseconds, of the data block manager. */
+ public void setDataBlockManagerKeepAliveTimeInMs(int dataBlockManagerKeepAliveTimeInMs) {
+ this.dataBlockManagerKeepAliveTimeInMs = dataBlockManagerKeepAliveTimeInMs;
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
index e384b5c2b1..8e904f0c81 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
@@ -833,6 +833,9 @@ public class IoTDBDescriptor {
// cluster
loadClusterProps(properties);
+
+ // shuffle
+ loadShuffleProps(properties);
} catch (FileNotFoundException e) {
logger.warn("Fail to find config file {}", url, e);
} catch (IOException e) {
@@ -1474,6 +1477,28 @@ public class IoTDBDescriptor {
properties.getProperty("internal_port", Integer.toString(conf.getInternalPort()))));
}
+ /**
+ * Load the shuffle (data block manager) settings from {@code properties} into {@code conf}.
+ *
+ * <p>Every key falls back to the value currently held by {@code conf} when it is absent. Values
+ * are trimmed before parsing because {@link Properties} keeps trailing whitespace of a value,
+ * which would otherwise make {@link Integer#parseInt(String)} fail.
+ *
+ * @param properties the loaded configuration properties
+ * @throws NumberFormatException if a configured value is not a valid integer
+ */
+ public void loadShuffleProps(Properties properties) {
+ conf.setDataBlockManagerPort(
+ Integer.parseInt(
+ properties
+ .getProperty(
+ "data_block_manager_port", Integer.toString(conf.getDataBlockManagerPort()))
+ .trim()));
+ conf.setDataBlockManagerCorePoolSize(
+ Integer.parseInt(
+ properties
+ .getProperty(
+ "data_block_manager_core_pool_size",
+ Integer.toString(conf.getDataBlockManagerCorePoolSize()))
+ .trim()));
+ conf.setDataBlockManagerMaxPoolSize(
+ Integer.parseInt(
+ properties
+ .getProperty(
+ "data_block_manager_max_pool_size",
+ Integer.toString(conf.getDataBlockManagerMaxPoolSize()))
+ .trim()));
+ conf.setDataBlockManagerKeepAliveTimeInMs(
+ Integer.parseInt(
+ properties
+ .getProperty(
+ "data_block_manager_keep_alive_time_in_ms",
+ Integer.toString(conf.getDataBlockManagerKeepAliveTimeInMs()))
+ .trim()));
+ }
+
/** Get default encode algorithm by data type */
public TSEncoding getDefaultEncodingByType(TSDataType dataType) {
switch (dataType) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManager.java
index 7b83950d43..42857b65e9 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManager.java
@@ -19,98 +19,318 @@
package org.apache.iotdb.db.mpp.buffer;
-import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory;
import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
-import org.apache.iotdb.db.mpp.schedule.task.FragmentInstanceTask;
+import org.apache.iotdb.mpp.rpc.thrift.AcknowledgeDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService;
+import org.apache.iotdb.mpp.rpc.thrift.EndOfDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockRequest;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockResponse;
+import org.apache.iotdb.mpp.rpc.thrift.NewDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
import org.apache.commons.lang3.Validate;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.util.HashMap;
-import java.util.List;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
import java.util.Map;
-import java.util.concurrent.ScheduledExecutorService;
-
-public class DataBlockManager {
-
- public static class FragmentInstanceInfo {
- private String hostname;
- private String queryId;
- private String fragmentId;
- private String instanceId;
-
- public FragmentInstanceInfo(
- String hostname, String queryId, String fragmentId, String instanceId) {
- this.hostname = Validate.notNull(hostname);
- this.queryId = Validate.notNull(queryId);
- this.fragmentId = Validate.notNull(fragmentId);
- this.instanceId = Validate.notNull(instanceId);
+import java.util.Map.Entry;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Supplier;
+
+public class DataBlockManager implements IDataBlockManager {
+
+ private static final Logger logger = LoggerFactory.getLogger(DataBlockManager.class);
+
+ /**
+ * Callbacks for state changes of a {@link SourceHandle}.
+ *
+ * <p>NOTE(review): the exact moment each callback fires is decided by SourceHandle, which is not
+ * visible here - confirm against that class.
+ */
+ public interface SourceHandleListener {
+ /** Called for a source handle that has finished. */
+ void onFinished(SourceHandle sourceHandle);
+
+ /** Called for a source handle that has been closed. */
+ void onClosed(SourceHandle sourceHandle);
+ }
+
+ /**
+ * Callbacks for state changes of a {@link SinkHandle}.
+ *
+ * <p>NOTE(review): the exact moment each callback fires is decided by SinkHandle, which is not
+ * visible here - confirm against that class.
+ */
+ public interface SinkHandleListener {
+ /** Called for a sink handle that has finished. */
+ void onFinish(SinkHandle sinkHandle);
+
+ /** Called for a sink handle that has been closed. */
+ void onClosed(SinkHandle sinkHandle);
+
+ /** Called for a sink handle that has been aborted. */
+ void onAborted(SinkHandle sinkHandle);
+ }
+
+ /** Handle thrift communications. */
+ class DataBlockServiceImpl implements DataBlockService.Iface {
+
+ /**
+ * Return the serialized data blocks whose sequence IDs are in {@code [startSequenceId,
+ * endSequenceId)} from the sink handle registered for the requested source fragment instance.
+ *
+ * @throws TException if no sink handle is registered for the source fragment instance
+ */
+ @Override
+ public GetDataBlockResponse getDataBlock(GetDataBlockRequest req) throws TException {
+ logger.debug(
+ "Get data block request received, for data blocks whose sequence ID in [{}, {}) from {}.",
+ req.getStartSequenceId(),
+ req.getEndSequenceId(),
+ req.getSourceFragmentInstanceId());
+ // Look the handle up exactly once: a containsKey() + get() pair can race with a concurrent
+ // removal of the handle and end in a NullPointerException.
+ SinkHandle sinkHandle = sinkHandles.get(req.getSourceFragmentInstanceId());
+ if (sinkHandle == null) {
+ throw new TException(
+ "Source fragment instance not found. Fragment instance ID: "
+ + req.getSourceFragmentInstanceId()
+ + ".");
+ }
+ GetDataBlockResponse resp = new GetDataBlockResponse();
+ for (int i = req.getStartSequenceId(); i < req.getEndSequenceId(); i++) {
+ ByteBuffer serializedTsBlock = sinkHandle.getSerializedTsBlock(i);
+ resp.addToTsBlocks(serializedTsBlock);
+ }
+ return resp;
}
- public String getHostname() {
- return hostname;
+ /**
+ * Forward an acknowledgement for the data blocks whose sequence IDs are in {@code
+ * [startSequenceId, endSequenceId)} to the sink handle of the source fragment instance.
+ *
+ * @throws TException if no sink handle is registered for the source fragment instance
+ */
+ @Override
+ public void onAcknowledgeDataBlockEvent(AcknowledgeDataBlockEvent e) throws TException {
+ logger.debug(
+ "Acknowledge data block event received, for data blocks whose sequence ID in [{}, {}) from {}.",
+ e.getStartSequenceId(),
+ e.getEndSequenceId(),
+ e.getSourceFragmentInstanceId());
+ // Single lookup: a containsKey() + get() pair can race with a concurrent removal of the
+ // handle and end in a NullPointerException.
+ SinkHandle sinkHandle = sinkHandles.get(e.getSourceFragmentInstanceId());
+ if (sinkHandle == null) {
+ throw new TException(
+ "Source fragment instance not found. Fragment instance ID: "
+ + e.getSourceFragmentInstanceId()
+ + ".");
+ }
+ sinkHandle.acknowledgeTsBlock(e.getStartSequenceId(), e.getEndSequenceId());
}
- public String getQueryId() {
- return queryId;
+ /**
+ * Tell the target source handle that new data blocks are available upstream.
+ *
+ * @throws TException if the target source handle is not registered or is already closed
+ */
+ @Override
+ public void onNewDataBlockEvent(NewDataBlockEvent e) throws TException {
+ logger.debug(
+ "New data block event received, for plan node {} of {} from {}.",
+ e.getTargetPlanNodeId(),
+ e.getTargetFragmentInstanceId(),
+ e.getSourceFragmentInstanceId());
+ // Resolve the handle once. Repeated containsKey()/get() calls can race with a concurrent
+ // removal of the handle and end in a NullPointerException.
+ SourceHandle sourceHandle =
+ sourceHandles
+ .getOrDefault(e.getTargetFragmentInstanceId(), Collections.emptyMap())
+ .get(e.getTargetPlanNodeId());
+ if (sourceHandle == null || sourceHandle.isClosed()) {
+ throw new TException(
+ "Target fragment instance not found. Fragment instance ID: "
+ + e.getTargetFragmentInstanceId()
+ + ".");
+ }
+
+ sourceHandle.updatePendingDataBlockInfo(e.getStartSequenceId(), e.getBlockSizes());
}
- public String getFragmentId() {
- return fragmentId;
+ /**
+ * Tell the target source handle that no data block follows the one with the given last
+ * sequence ID.
+ *
+ * @throws TException if the target source handle is not registered or is already closed
+ */
+ @Override
+ public void onEndOfDataBlockEvent(EndOfDataBlockEvent e) throws TException {
+ logger.debug(
+ "End of data block event received, for plan node {} of {} from {}.",
+ e.getTargetPlanNodeId(),
+ e.getTargetFragmentInstanceId(),
+ e.getSourceFragmentInstanceId());
+ // Resolve the handle once. Checking the maps first and fetching again afterwards can race
+ // with a concurrent removal of the handle and end in a NullPointerException.
+ SourceHandle sourceHandle =
+ sourceHandles
+ .getOrDefault(e.getTargetFragmentInstanceId(), Collections.emptyMap())
+ .get(e.getTargetPlanNodeId());
+ if (sourceHandle == null || sourceHandle.isClosed()) {
+ throw new TException(
+ "Target fragment instance not found. Fragment instance ID: "
+ + e.getTargetFragmentInstanceId()
+ + ".");
+ }
+ sourceHandle.setNoMoreTsBlocks(e.getLastSequenceId());
+ }
+ }
+
+ /** Listen to the state changes of a source handle. */
+ class SourceHandleListenerImpl implements SourceHandleListener {
+ /**
+ * Deregister a finished source handle, and drop the per-fragment-instance map once it holds no
+ * more handles. Calling this for a handle that is no longer registered only logs.
+ */
+ @Override
+ public void onFinished(SourceHandle sourceHandle) {
+ logger.info("Release resources of finished source handle {}", sourceHandle);
+ Map<String, SourceHandle> planNodeIdToSourceHandle =
+ sourceHandles.get(sourceHandle.getLocalFragmentInstanceId());
+ if (planNodeIdToSourceHandle == null) {
+ // Nothing is registered for this fragment instance any more. Return here instead of
+ // dereferencing the missing map below.
+ logger.info(
+ "Resources of finished source handle {} has already been released", sourceHandle);
+ return;
+ }
+ if (planNodeIdToSourceHandle.remove(sourceHandle.getLocalPlanNodeId()) == null) {
+ logger.info(
+ "Resources of finished source handle {} has already been released", sourceHandle);
+ }
+ if (planNodeIdToSourceHandle.isEmpty()) {
+ sourceHandles.remove(sourceHandle.getLocalFragmentInstanceId());
+ }
}
- public String getInstanceId() {
- return instanceId;
+ /** Release the resources of a closed source handle by delegating to {@code onFinished}. */
+ @Override
+ public void onClosed(SourceHandle sourceHandle) {
+ onFinished(sourceHandle);
}
}
- /**
- * Create a sink handle.
- *
- * @param local The {@link FragmentInstanceInfo} of local fragment instance.
- * @param remote The {@link FragmentInstanceInfo} of downstream instance.
- */
- public ISinkHandle createSinkHandle(FragmentInstanceInfo local, FragmentInstanceInfo remote) {
- throw new UnsupportedOperationException();
+ /** Listen to the state changes of a sink handle. */
+ class SinkHandleListenerImpl implements SinkHandleListener {
+
+ /** Deregister a finished sink handle. A handle that is no longer registered is only logged. */
+ @Override
+ public void onFinish(SinkHandle sinkHandle) {
+ // Log the handle being released, not the sourceHandles map.
+ logger.info("Release resources of finished sink handle {}", sinkHandle);
+ if (sinkHandles.remove(sinkHandle.getLocalFragmentInstanceId()) == null) {
+ logger.info("Resources of finished sink handle {} has already been released", sinkHandle);
+ }
+ }
+
+ /** Intentionally a no-op: closing a sink handle does not deregister it here. */
+ @Override
+ public void onClosed(SinkHandle sinkHandle) {}
+
+ /** Deregister an aborted sink handle. A handle that is no longer registered is only logged. */
+ @Override
+ public void onAborted(SinkHandle sinkHandle) {
+ // Log the handle being released, not the sourceHandles map.
+ logger.info("Release resources of aborted sink handle {}", sinkHandle);
+ if (sinkHandles.remove(sinkHandle.getLocalFragmentInstanceId()) == null) {
+ logger.info("Resources of aborted sink handle {} has already been released", sinkHandle);
+ }
+ }
}
- public ISinkHandle createPartitionedSinkHandle(
- FragmentInstanceInfo local, List<FragmentInstanceInfo> remotes) {
- throw new UnsupportedOperationException();
+ // Collaborators handed to every sink/source handle this manager creates.
+ private final LocalMemoryManager localMemoryManager;
+ private final Supplier<TsBlockSerde> tsBlockSerdeFactory;
+ private final ExecutorService executorService;
+ private final DataBlockServiceClientFactory clientFactory;
+ // Source handles keyed by local fragment instance ID, then by local plan node ID.
+ private final Map<TFragmentInstanceId, Map<String, SourceHandle>> sourceHandles;
+ // Sink handles keyed by local fragment instance ID.
+ private final Map<TFragmentInstanceId, SinkHandle> sinkHandles;
+
+ // Thrift service implementation handed out by getOrCreateDataBlockServiceImpl().
+ private DataBlockServiceImpl dataBlockService;
+
+ /**
+ * Create a manager with empty source and sink handle registries.
+ *
+ * <p>All arguments must be non-null; each is checked with {@code Validate.notNull}.
+ *
+ * @param localMemoryManager memory manager passed to the created handles
+ * @param tsBlockSerdeFactory supplies one {@link TsBlockSerde} per created handle
+ * @param executorService executor passed to the created handles
+ * @param clientFactory creates the thrift clients used to reach remote hosts
+ */
+ public DataBlockManager(
+ LocalMemoryManager localMemoryManager,
+ Supplier<TsBlockSerde> tsBlockSerdeFactory,
+ ExecutorService executorService,
+ DataBlockServiceClientFactory clientFactory) {
+ this.localMemoryManager = Validate.notNull(localMemoryManager);
+ this.tsBlockSerdeFactory = Validate.notNull(tsBlockSerdeFactory);
+ this.executorService = Validate.notNull(executorService);
+ this.clientFactory = Validate.notNull(clientFactory);
+ sourceHandles = new ConcurrentHashMap<>();
+ sinkHandles = new ConcurrentHashMap<>();
}
- /**
- * Create a source handle.
- *
- * @param local The {@link FragmentInstanceInfo} of local fragment instance.
- * @param remote The {@link FragmentInstanceInfo} of downstream instance.
- */
- public ISourceHandle createSourceHandle(FragmentInstanceInfo local, FragmentInstanceInfo remote) {
- throw new UnsupportedOperationException();
+ /**
+ * Return the thrift service implementation of this manager, creating it on first use.
+ *
+ * @return the service implementation, never {@code null}
+ */
+ public DataBlockServiceImpl getOrCreateDataBlockServiceImpl() {
+ // Create only when absent. The inverted check (!= null) never created the service and always
+ // returned null.
+ if (dataBlockService == null) {
+ dataBlockService = new DataBlockServiceImpl();
+ }
+ return dataBlockService;
+ }
+
+ /**
+ * Create and register the sink handle of a local fragment instance.
+ *
+ * @param localFragmentInstanceId ID of the local fragment instance; the registry key
+ * @param remoteHostname host the thrift client connects to
+ * @param remoteFragmentInstanceId ID of the remote fragment instance
+ * @param remotePlanNodeId ID of the remote plan node
+ * @return the newly created sink handle
+ * @throws IllegalStateException if a sink handle is already registered for the local fragment
+ * instance
+ * @throws IOException declared by the thrift client factory call
+ */
+ @Override
+ public ISinkHandle createSinkHandle(
+ TFragmentInstanceId localFragmentInstanceId,
+ String remoteHostname,
+ TFragmentInstanceId remoteFragmentInstanceId,
+ String remotePlanNodeId)
+ throws IOException {
+ if (sinkHandles.containsKey(localFragmentInstanceId)) {
+ throw new IllegalStateException("Sink handle for " + localFragmentInstanceId + " exists.");
+ }
+
+ logger.info(
+ "Create sink handle to plan node {} of {} for {}",
+ remotePlanNodeId,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId);
+
+ SinkHandle sinkHandle =
+ new SinkHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ localMemoryManager,
+ executorService,
+ // TODO: hard coded port.
+ clientFactory.getDataBlockServiceClient(remoteHostname, 7777),
+ tsBlockSerdeFactory.get(),
+ new SinkHandleListenerImpl());
+ sinkHandles.put(localFragmentInstanceId, sinkHandle);
+ return sinkHandle;
+ }
+
+ /**
+ * Create and register the source handle of a plan node of a local fragment instance.
+ *
+ * @param localFragmentInstanceId ID of the local fragment instance; the first registry key
+ * @param localPlanNodeId ID of the local plan node; the second registry key
+ * @param remoteHostname host the thrift client connects to
+ * @param remoteFragmentInstanceId ID of the remote fragment instance
+ * @return the newly created source handle
+ * @throws IllegalStateException if a source handle is already registered for the plan node of
+ * the local fragment instance
+ * @throws IOException declared by the thrift client factory call
+ */
+ @Override
+ public ISourceHandle createSourceHandle(
+ TFragmentInstanceId localFragmentInstanceId,
+ String localPlanNodeId,
+ String remoteHostname,
+ TFragmentInstanceId remoteFragmentInstanceId)
+ throws IOException {
+ if (sourceHandles.containsKey(localFragmentInstanceId)
+ && sourceHandles.get(localFragmentInstanceId).containsKey(localPlanNodeId)) {
+ throw new IllegalStateException(
+ "Source handle for plan node "
+ + localPlanNodeId
+ + " of "
+ + localFragmentInstanceId
+ + " exists.");
+ }
+
+ logger.info(
+ "Create source handle from {} for plan node {} of {}",
+ remoteFragmentInstanceId,
+ localPlanNodeId,
+ localFragmentInstanceId);
+
+ SourceHandle sourceHandle =
+ new SourceHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId,
+ localPlanNodeId,
+ localMemoryManager,
+ executorService,
+ // TODO: hard coded port.
+ clientFactory.getDataBlockServiceClient(remoteHostname, 7777),
+ tsBlockSerdeFactory.get(),
+ new SourceHandleListenerImpl());
+ // Create the per-fragment-instance map on first use.
+ sourceHandles
+ .computeIfAbsent(localFragmentInstanceId, key -> new ConcurrentHashMap<>())
+ .put(localPlanNodeId, sourceHandle);
+ return sourceHandle;
}
/**
- * Release all the related resources, including data blocks that are not yet sent to downstream
+ * Release all the related resources, including data blocks that are not yet fetched by downstream
* fragment instances.
*
* <p>This method should be called when a fragment instance finished in an abnormal state.
*/
- public void forceDeregisterFragmentInstance(FragmentInstanceTask task) {
- throw new UnsupportedOperationException();
- }
-
- LocalMemoryManager localMemoryManager;
- ScheduledExecutorService scheduledExecutorService;
- DataBlockServiceClientFactory clientFactory;
- Map<String, Map<String, Map<String, ISourceHandle>>> sourceHandles;
- Map<String, Map<String, Map<String, ISinkHandle>>> sinkHandles;
-
- public DataBlockManager(LocalMemoryManager localMemoryManager) {
- this.localMemoryManager = Validate.notNull(localMemoryManager);
- // TODO: configurable number of threads
- scheduledExecutorService =
- IoTDBThreadPoolFactory.newScheduledThreadPoolWithDaemon(5, "get-data-block");
- clientFactory = new DataBlockServiceClientFactory();
- sourceHandles = new HashMap<>();
- sinkHandles = new HashMap<>();
+ public void forceDeregisterFragmentInstance(TFragmentInstanceId fragmentInstanceId) {
+ logger.info("Force deregister fragment instance {}", fragmentInstanceId);
+ // Fetch each entry once: a containsKey() + get() pair can race with a concurrent removal of
+ // the entry.
+ ISinkHandle sinkHandle = sinkHandles.get(fragmentInstanceId);
+ if (sinkHandle != null) {
+ logger.info("Abort sink handle {}", sinkHandle);
+ sinkHandle.abort();
+ sinkHandles.remove(fragmentInstanceId);
+ }
+ Map<String, SourceHandle> planNodeIdToSourceHandle = sourceHandles.get(fragmentInstanceId);
+ if (planNodeIdToSourceHandle != null) {
+ for (SourceHandle sourceHandle : planNodeIdToSourceHandle.values()) {
+ // Log the handle being closed, not the whole sourceHandles registry.
+ logger.info("Close source handle {}", sourceHandle);
+ sourceHandle.close();
+ }
+ sourceHandles.remove(fragmentInstanceId);
+ }
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerService.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerService.java
deleted file mode 100644
index 1e4168ce01..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerService.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.mpp.buffer;
-
-import org.apache.iotdb.commons.concurrent.ThreadName;
-import org.apache.iotdb.commons.exception.runtime.RPCServiceException;
-import org.apache.iotdb.commons.service.ServiceType;
-import org.apache.iotdb.commons.service.ThriftService;
-import org.apache.iotdb.commons.service.ThriftServiceThread;
-import org.apache.iotdb.mpp.rpc.thrift.DataBlockService.Processor;
-
-public class DataBlockManagerService extends ThriftService {
-
- private DataBlockServiceImpl impl;
-
- @Override
- public ThriftService getImplementation() {
- return DataBlockManagerServiceHolder.INSTANCE;
- }
-
- @Override
- public void initTProcessor()
- throws ClassNotFoundException, IllegalAccessException, InstantiationException {
- impl = new DataBlockServiceImpl();
- processor = new Processor<>(impl);
- }
-
- @Override
- public void initThriftServiceThread()
- throws IllegalAccessException, InstantiationException, ClassNotFoundException {
- try {
- thriftServiceThread =
- new ThriftServiceThread(
- processor,
- getID().getName(),
- ThreadName.DATA_BLOCK_MANAGER_CLIENT.getName(),
- getBindIP(),
- getBindPort(),
- // TODO: hard coded maxWorkerThreads & timeoutSecond
- 32,
- 60,
- new DataBlockManagerServiceThriftHandler(),
- // TODO: hard coded compress strategy
- true);
- } catch (RPCServiceException e) {
- throw new IllegalAccessException(e.getMessage());
- }
- thriftServiceThread.setName(ThreadName.DATA_BLOCK_MANAGER_SERVICE.getName());
- }
-
- @Override
- public String getBindIP() {
- // TODO: hard coded bind IP.
- return "0.0.0.0";
- }
-
- @Override
- public int getBindPort() {
- // TODO: hard coded bind port.
- return 7777;
- }
-
- @Override
- public ServiceType getID() {
- return ServiceType.DATA_BLOCK_MANAGER_SERVICE;
- }
-
- private static class DataBlockManagerServiceHolder {
- private static final DataBlockManagerService INSTANCE = new DataBlockManagerService();
-
- private DataBlockManagerServiceHolder() {}
- }
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockService.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockService.java
new file mode 100644
index 0000000000..e2cd12461e
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockService.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.buffer;
+
+import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory;
+import org.apache.iotdb.commons.concurrent.IoTThreadFactory;
+import org.apache.iotdb.commons.concurrent.ThreadName;
+import org.apache.iotdb.commons.exception.runtime.RPCServiceException;
+import org.apache.iotdb.commons.service.ServiceType;
+import org.apache.iotdb.commons.service.ThriftService;
+import org.apache.iotdb.commons.service.ThriftServiceThread;
+import org.apache.iotdb.db.conf.IoTDBConfig;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService.Processor;
+
+import org.apache.commons.lang3.Validate;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Thrift service that owns the {@link DataBlockManager} and exposes its RPC handler.
+ *
+ * <p>{@link #setLocalMemoryManager} and {@link #setTsBlockSerdeFactory} must be called before
+ * {@link #initTProcessor()}, which passes both values to the {@link DataBlockManager} constructor.
+ */
+public class DataBlockService extends ThriftService {
+
+ private LocalMemoryManager localMemoryManager;
+ private TsBlockSerdeFactory tsBlockSerdeFactory;
+ private DataBlockManager dataBlockManager;
+ // Created in initTProcessor(); stays null until then.
+ private ExecutorService executorService;
+ private DataBlockServiceClientFactory clientFactory;
+
+ private DataBlockService() {}
+
+ @Override
+ public ThriftService getImplementation() {
+ return DataBlockManagerServiceHolder.INSTANCE;
+ }
+
+ /**
+ * Build the task executor, the client factory and the {@link DataBlockManager} from the current
+ * configuration, then register the manager's thrift handler as the processor.
+ */
+ @Override
+ public void initTProcessor()
+ throws ClassNotFoundException, IllegalAccessException, InstantiationException {
+ IoTDBConfig config = IoTDBDescriptor.getInstance().getConfig();
+ executorService =
+ IoTDBThreadPoolFactory.newThreadPool(
+ config.getDataBlockManagerCorePoolSize(),
+ config.getDataBlockManagerMaxPoolSize(),
+ config.getDataBlockManagerKeepAliveTimeInMs(),
+ TimeUnit.MILLISECONDS,
+ // TODO: Use a priority queue.
+ new LinkedBlockingQueue<>(),
+ new IoTThreadFactory("data-block-manager-task-executors"),
+ "data-block-manager-task-executors");
+ clientFactory = new DataBlockServiceClientFactory();
+ this.dataBlockManager =
+ new DataBlockManager(
+ localMemoryManager, tsBlockSerdeFactory, executorService, clientFactory);
+ processor = new Processor<>(dataBlockManager.getOrCreateDataBlockServiceImpl());
+ }
+
+ /** Set the memory manager handed to the {@link DataBlockManager}; must be non-null. */
+ public void setLocalMemoryManager(LocalMemoryManager localMemoryManager) {
+ this.localMemoryManager = Validate.notNull(localMemoryManager);
+ }
+
+ /** Set the serde factory handed to the {@link DataBlockManager}; must be non-null. */
+ public void setTsBlockSerdeFactory(TsBlockSerdeFactory tsBlockSerdeFactory) {
+ this.tsBlockSerdeFactory = Validate.notNull(tsBlockSerdeFactory);
+ }
+
+ /** Returns the manager created by {@link #initTProcessor()}, or {@code null} before that. */
+ public DataBlockManager getDataBlockManager() {
+ return dataBlockManager;
+ }
+
+ /**
+ * Create the thrift service thread bound to {@link #getBindIP()} and {@link #getBindPort()}.
+ *
+ * @throws IllegalAccessException if the thrift service thread cannot be created; the original
+ * {@link RPCServiceException} is attached as the cause
+ */
+ @Override
+ public void initThriftServiceThread()
+ throws IllegalAccessException, InstantiationException, ClassNotFoundException {
+ try {
+ IoTDBConfig config = IoTDBDescriptor.getInstance().getConfig();
+ thriftServiceThread =
+ new ThriftServiceThread(
+ processor,
+ getID().getName(),
+ ThreadName.DATA_BLOCK_MANAGER_CLIENT.getName(),
+ getBindIP(),
+ getBindPort(),
+ config.getRpcMaxConcurrentClientNum(),
+ config.getThriftServerAwaitTimeForStopService(),
+ new DataBlockServiceThriftHandler(),
+ // TODO: hard coded compress strategy
+ true);
+ } catch (RPCServiceException e) {
+ // IllegalAccessException has no cause-accepting constructor, so attach the cause explicitly
+ // to keep the original stack trace.
+ IllegalAccessException wrapped = new IllegalAccessException(e.getMessage());
+ wrapped.initCause(e);
+ throw wrapped;
+ }
+ thriftServiceThread.setName(ThreadName.DATA_BLOCK_MANAGER_SERVICE.getName());
+ }
+
+ @Override
+ public String getBindIP() {
+ return IoTDBDescriptor.getInstance().getConfig().getRpcAddress();
+ }
+
+ @Override
+ public int getBindPort() {
+ return IoTDBDescriptor.getInstance().getConfig().getDataBlockManagerPort();
+ }
+
+ @Override
+ public ServiceType getID() {
+ return ServiceType.DATA_BLOCK_MANAGER_SERVICE;
+ }
+
+ /** Stop the thrift service, then shut down the task executor if it was ever created. */
+ @Override
+ public void stop() {
+ super.stop();
+ // The executor only exists after initTProcessor() has run.
+ if (executorService != null) {
+ executorService.shutdown();
+ }
+ }
+
+ private static class DataBlockManagerServiceHolder {
+ private static final DataBlockService INSTANCE = new DataBlockService();
+
+ private DataBlockManagerServiceHolder() {}
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceClientFactory.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceClientFactory.java
index 4f4765c5b8..f289662d80 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceClientFactory.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceClientFactory.java
@@ -24,21 +24,27 @@ import org.apache.iotdb.mpp.rpc.thrift.DataBlockService;
import org.apache.iotdb.mpp.rpc.thrift.DataBlockService.Client;
import org.apache.iotdb.rpc.RpcTransportFactory;
+import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;
-import org.apache.thrift.transport.TTransportException;
+
+import java.io.IOException;
public class DataBlockServiceClientFactory {
public DataBlockService.Client getDataBlockServiceClient(String hostname, int port)
- throws TTransportException {
- TTransport transport = RpcTransportFactory.INSTANCE.getTransportWithNoTimeout(hostname, port);
- transport.open();
- TProtocol protocol =
- IoTDBDescriptor.getInstance().getConfig().isRpcThriftCompressionEnable()
- ? new TCompactProtocol(transport)
- : new TBinaryProtocol(transport);
- return new Client(protocol);
+ throws IOException {
+ try {
+ TTransport transport = RpcTransportFactory.INSTANCE.getTransportWithNoTimeout(hostname, port);
+ transport.open();
+ TProtocol protocol =
+ IoTDBDescriptor.getInstance().getConfig().isRpcThriftCompressionEnable()
+ ? new TCompactProtocol(transport)
+ : new TBinaryProtocol(transport);
+ return new Client(protocol);
+ } catch (TException e) {
+ throw new IOException(e);
+ }
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceImpl.java
deleted file mode 100644
index 15352889a2..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceImpl.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.mpp.buffer;
-
-import org.apache.iotdb.mpp.rpc.thrift.DataBlockService;
-import org.apache.iotdb.mpp.rpc.thrift.EndOfDataBlockEvent;
-import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockRequest;
-import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockResponse;
-import org.apache.iotdb.mpp.rpc.thrift.NewDataBlockEvent;
-
-import org.apache.thrift.TException;
-
-public class DataBlockServiceImpl implements DataBlockService.Iface {
-
- public DataBlockServiceImpl() {
- super();
- }
-
- @Override
- public GetDataBlockResponse getDataBlock(GetDataBlockRequest req) throws TException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void onNewDataBlockEvent(NewDataBlockEvent e) throws TException {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void onEndOfDataBlockEvent(EndOfDataBlockEvent e) throws TException {
- throw new UnsupportedOperationException();
- }
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerServiceThriftHandler.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceThriftHandler.java
similarity index 94%
rename from server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerServiceThriftHandler.java
rename to server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceThriftHandler.java
index d7d40b83d1..334268ba71 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockManagerServiceThriftHandler.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/DataBlockServiceThriftHandler.java
@@ -24,7 +24,7 @@ import org.apache.thrift.server.ServerContext;
import org.apache.thrift.server.TServerEventHandler;
import org.apache.thrift.transport.TTransport;
-public class DataBlockManagerServiceThriftHandler implements TServerEventHandler {
+public class DataBlockServiceThriftHandler implements TServerEventHandler {
@Override
public void preServe() {}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/IDataBlockManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/IDataBlockManager.java
index 2d4a1154f7..c8d6ba714e 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/IDataBlockManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/IDataBlockManager.java
@@ -21,6 +21,10 @@ package org.apache.iotdb.db.mpp.buffer;
import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+import org.apache.thrift.transport.TTransportException;
+
+import java.io.IOException;
+
public interface IDataBlockManager {
/**
* Create a sink handle who sends data blocks to a remote downstream fragment instance in async
@@ -31,30 +35,32 @@ public interface IDataBlockManager {
* @param remoteHostname Hostname of the remote fragment instance where the data blocks should be
* sent to.
* @param remoteFragmentInstanceId ID of the remote fragment instance.
- * @param remoteOperatorId The sink operator ID of the remote fragment instance.
+ * @param remotePlanNodeId The sink plan node ID of the remote fragment instance.
*/
ISinkHandle createSinkHandle(
TFragmentInstanceId localFragmentInstanceId,
String remoteHostname,
TFragmentInstanceId remoteFragmentInstanceId,
- String remoteOperatorId);
+ String remotePlanNodeId)
+ throws TTransportException, IOException;
/**
- * Create a source handle who fetches data blocks from a remote upstream fragment instance for an
- * operator of a local fragment instance in async manner.
+ * Create a source handle who fetches data blocks from a remote upstream fragment instance for a
+ * plan node of a local fragment instance in async manner.
*
* @param localFragmentInstanceId ID of the local fragment instance who receives data blocks from
* the source handle.
- * @param localOperatorId The local sink operator ID.
+ * @param localPlanNodeId The local sink plan node ID.
* @param remoteHostname Hostname of the remote fragment instance where the data blocks should be
* received from.
* @param remoteFragmentInstanceId ID of the remote fragment instance.
*/
ISourceHandle createSourceHandle(
TFragmentInstanceId localFragmentInstanceId,
- String localOperatorId,
+ String localPlanNodeId,
String remoteHostname,
- TFragmentInstanceId remoteFragmentInstanceId);
+ TFragmentInstanceId remoteFragmentInstanceId)
+ throws IOException;
/**
* Release all the related resources of a fragment instance, including data blocks that are not
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISinkHandle.java
index 6b4e07cc30..6300c5beef 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISinkHandle.java
@@ -22,22 +22,33 @@ import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import com.google.common.util.concurrent.ListenableFuture;
+import java.io.IOException;
+import java.util.List;
+
public interface ISinkHandle extends AutoCloseable {
+ /** Get the total amount of memory used by buffered tsblocks. */
+ long getBufferRetainedSizeInBytes();
+
+ /** Get the number of buffered tsblocks. */
+ int getNumOfBufferedTsBlocks();
+
/** Get a future that will be completed when the output buffer is not full. */
ListenableFuture<Void> isFull();
/**
- * Send a {@link TsBlock} to an unpartitioned output buffer. If no-more-tsblocks has been set, the
- * send tsblock call is ignored. This can happen with limit queries.
+ * Send a list of tsblocks to an unpartitioned output buffer. If no-more-tsblocks has been set,
+ * the send tsblock call is ignored. This can happen with limit queries. A {@link
+ * RuntimeException} will be thrown if any exception happened during the data transmission.
*/
- void send(TsBlock tsBlock);
+ void send(List<TsBlock> tsBlocks) throws IOException;
/**
* Send a {@link TsBlock} to a specific partition. If no-more-tsblocks has been set, the send
- * tsblock call is ignored. This can happen with limit queries.
+ * tsblock call is ignored. This can happen with limit queries. A {@link RuntimeException} will be
+ * thrown if any exception happened during the data transmission.
*/
- void send(int partition, TsBlock tsBlock);
+ void send(int partition, List<TsBlock> tsBlocks) throws IOException;
/**
* Notify the handle that no more tsblocks will be sent. Any future calls to send a tsblock should
@@ -45,12 +56,22 @@ public interface ISinkHandle extends AutoCloseable {
*/
void setNoMoreTsBlocks();
+ /** If the handle is closed. */
+ public boolean isClosed();
+
/**
- * Close the handle. Keep the output buffer until all tsblocks are fetched by downstream
+ * If no more tsblocks will be sent and all the tsblocks have been fetched by downstream fragment
* instances.
*/
+ public boolean isFinished();
+
+ /**
+ * Close the handle. The output buffer will not be cleared until all tsblocks are fetched by
+ * downstream instances. A {@link RuntimeException} will be thrown if any exception happened
+ * during the data transmission.
+ */
@Override
- void close();
+ void close() throws IOException;
/** Abort the sink handle, discarding all tsblocks which may still be in memory buffer. */
void abort();
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
index ef2919c787..dfb9257b97 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
@@ -23,18 +23,31 @@ import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import com.google.common.util.concurrent.ListenableFuture;
import java.io.Closeable;
+import java.io.IOException;
public interface ISourceHandle extends Closeable {
- /** Get a {@link TsBlock} from the input buffer. */
- TsBlock receive();
+ /** Get the total amount of memory used by buffered tsblocks. */
+ long getBufferRetainedSizeInBytes();
- /** Check if there are more tsblocks. */
+ /**
+ * Get a {@link TsBlock}. If the source handle is blocked, a null will be returned. A {@link
+ * RuntimeException} will be thrown if any error happened.
+ */
+ TsBlock receive() throws IOException;
+
+ /** If there are no more tsblocks. */
boolean isFinished();
- /** Get a future that will be completed when the input buffer is not empty. */
+ /**
+ * Get a future that will be completed when the input buffer is not empty. The future will not
+ * complete even when the handle is finished or closed.
+ */
ListenableFuture<Void> isBlocked();
+ /** If this handle is closed. */
+ boolean isClosed();
+
/** Close the handle. Discarding all tsblocks which may still be in memory buffer. */
@Override
void close();
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SinkHandle.java
new file mode 100644
index 0000000000..a3e106b5f9
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SinkHandle.java
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.buffer;
+
+import org.apache.iotdb.db.mpp.buffer.DataBlockManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService;
+import org.apache.iotdb.mpp.rpc.thrift.EndOfDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.NewDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.commons.lang3.Validate;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.StringJoiner;
+import java.util.concurrent.ExecutorService;
+
+import static com.google.common.util.concurrent.Futures.immediateFuture;
+import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
+
+public class SinkHandle implements ISinkHandle {
+
+ private static final Logger logger = LoggerFactory.getLogger(SinkHandle.class);
+
+ public static final int MAX_ATTEMPT_TIMES = 3;
+
+ private final String remoteHostname;
+ private final TFragmentInstanceId remoteFragmentInstanceId;
+ private final String remotePlanNodeId;
+ private final TFragmentInstanceId localFragmentInstanceId;
+ private final LocalMemoryManager localMemoryManager;
+ private final ExecutorService executorService;
+ private final DataBlockService.Client client;
+ private final TsBlockSerde serde;
+ private final SinkHandleListener sinkHandleListener;
+
+ // Use LinkedHashMap to meet 2 needs,
+ // 1. Predictable iteration order so that removing buffered tsblocks can be efficient.
+ // 2. Fast lookup.
+ private final LinkedHashMap<Integer, TsBlock> sequenceIdToTsBlock = new LinkedHashMap<>();
+
+ private volatile ListenableFuture<Void> blocked = immediateFuture(null);
+ private int nextSequenceId = 0;
+ private long bufferRetainedSizeInBytes;
+ private boolean closed;
+ private boolean noMoreTsBlocks;
+ private Throwable throwable;
+
+ public SinkHandle(
+ String remoteHostname,
+ TFragmentInstanceId remoteFragmentInstanceId,
+ String remotePlanNodeId,
+ TFragmentInstanceId localFragmentInstanceId,
+ LocalMemoryManager localMemoryManager,
+ ExecutorService executorService,
+ DataBlockService.Client client,
+ TsBlockSerde serde,
+ SinkHandleListener sinkHandleListener) {
+ this.remoteHostname = Validate.notNull(remoteHostname);
+ this.remoteFragmentInstanceId = Validate.notNull(remoteFragmentInstanceId);
+ this.remotePlanNodeId = Validate.notNull(remotePlanNodeId);
+ this.localFragmentInstanceId = Validate.notNull(localFragmentInstanceId);
+ this.localMemoryManager = Validate.notNull(localMemoryManager);
+ this.executorService = Validate.notNull(executorService);
+ this.client = Validate.notNull(client);
+ this.serde = Validate.notNull(serde);
+ this.sinkHandleListener = Validate.notNull(sinkHandleListener);
+ }
+
+ @Override
+ public ListenableFuture<Void> isFull() {
+ if (closed) {
+ throw new IllegalStateException("Sink handle is closed.");
+ }
+ return nonCancellationPropagating(blocked);
+ }
+
+ private void submitSendNewDataBlockEventTask(int startSequenceId, List<Long> blockSizes) {
+ executorService.submit(new SendNewDataBlockEventTask(startSequenceId, blockSizes));
+ }
+
+ @Override
+ public void send(List<TsBlock> tsBlocks) throws IOException {
+ Validate.notNull(tsBlocks, "tsBlocks is null");
+ if (throwable != null) {
+ throw new IOException(throwable);
+ }
+ if (closed) {
+ throw new IllegalStateException("Sink handle is closed.");
+ }
+ if (!blocked.isDone()) {
+ throw new IllegalStateException("Sink handle is blocked.");
+ }
+ if (noMoreTsBlocks) {
+ return;
+ }
+
+ long retainedSizeInBytes = 0L;
+ for (TsBlock tsBlock : tsBlocks) {
+ retainedSizeInBytes += tsBlock.getRetainedSizeInBytes();
+ }
+ int startSequenceId;
+ List<Long> tsBlockSizes = new ArrayList<>();
+ synchronized (this) {
+ startSequenceId = nextSequenceId;
+ blocked =
+ localMemoryManager
+ .getQueryPool()
+ .reserve(localFragmentInstanceId.getQueryId(), retainedSizeInBytes);
+ bufferRetainedSizeInBytes += retainedSizeInBytes;
+ for (TsBlock tsBlock : tsBlocks) {
+ sequenceIdToTsBlock.put(nextSequenceId, tsBlock);
+ nextSequenceId += 1;
+ }
+ for (int i = startSequenceId; i < nextSequenceId; i++) {
+ tsBlockSizes.add(sequenceIdToTsBlock.get(i).getRetainedSizeInBytes());
+ }
+ }
+
+ // TODO: consider merge multiple NewDataBlockEvent for less network traffic.
+ submitSendNewDataBlockEventTask(startSequenceId, tsBlockSizes);
+ }
+
+ @Override
+ public void send(int partition, List<TsBlock> tsBlocks) {
+ throw new UnsupportedOperationException();
+ }
+
+ private void sendEndOfDataBlockEvent() throws TException {
+ logger.debug(
+ "Send end of data block event to plan node {} of {}.",
+ remotePlanNodeId,
+ remoteFragmentInstanceId);
+ int attempt = 0;
+ EndOfDataBlockEvent endOfDataBlockEvent =
+ new EndOfDataBlockEvent(
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ nextSequenceId - 1);
+ while (attempt < MAX_ATTEMPT_TIMES) {
+ attempt += 1;
+ try {
+ client.onEndOfDataBlockEvent(endOfDataBlockEvent);
+ break;
+ } catch (TException e) {
+ logger.error(
+ "Failed to send end of data block event to plan node {} of {} due to {}, attempt times: {}",
+ remotePlanNodeId,
+ remoteFragmentInstanceId,
+ e.getMessage(),
+ attempt);
+ if (attempt == MAX_ATTEMPT_TIMES) {
+ throw e;
+ }
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ logger.info("Sink handle {} is being closed.", this);
+ if (throwable != null) {
+ throw new IOException(throwable);
+ }
+ if (closed) {
+ return;
+ }
+ synchronized (this) {
+ closed = true;
+ noMoreTsBlocks = true;
+ }
+ sinkHandleListener.onClosed(this);
+ try {
+ sendEndOfDataBlockEvent();
+ } catch (TException e) {
+ throw new IOException(e);
+ }
+ logger.info("Sink handle {} is closed.", this);
+ }
+
+ @Override
+ public void abort() {
+ logger.info("Sink handle {} is being aborted.", this);
+ synchronized (this) {
+ sequenceIdToTsBlock.clear();
+ closed = true;
+ localMemoryManager
+ .getQueryPool()
+ .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ bufferRetainedSizeInBytes = 0;
+ }
+ sinkHandleListener.onAborted(this);
+ logger.info("Sink handle {} is aborted", this);
+ }
+
+ @Override
+ public synchronized void setNoMoreTsBlocks() {
+ noMoreTsBlocks = true;
+ }
+
+ @Override
+ public boolean isClosed() {
+ return closed;
+ }
+
+ @Override
+ public boolean isFinished() {
+ return throwable == null && noMoreTsBlocks && sequenceIdToTsBlock.isEmpty();
+ }
+
+ @Override
+ public long getBufferRetainedSizeInBytes() {
+ return bufferRetainedSizeInBytes;
+ }
+
+ @Override
+ public int getNumOfBufferedTsBlocks() {
+ return sequenceIdToTsBlock.size();
+ }
+
+ ByteBuffer getSerializedTsBlock(int partition, int sequenceId) {
+ throw new UnsupportedOperationException();
+ }
+
+ ByteBuffer getSerializedTsBlock(int sequenceId) {
+ TsBlock tsBlock;
+ tsBlock = sequenceIdToTsBlock.get(sequenceId);
+ if (tsBlock == null) {
+ throw new IllegalStateException("The data block doesn't exist. Sequence ID: " + sequenceId);
+ }
+ return serde.serialized(tsBlock);
+ }
+
+ void acknowledgeTsBlock(int startSequenceId, int endSequenceId) {
+ long freedBytes = 0L;
+ synchronized (this) {
+ Iterator<Entry<Integer, TsBlock>> iterator = sequenceIdToTsBlock.entrySet().iterator();
+ while (iterator.hasNext()) {
+ Entry<Integer, TsBlock> entry = iterator.next();
+ if (entry.getKey() < startSequenceId) {
+ continue;
+ }
+ if (entry.getKey() >= endSequenceId) {
+ break;
+ }
+ freedBytes += entry.getValue().getRetainedSizeInBytes();
+ bufferRetainedSizeInBytes -= entry.getValue().getRetainedSizeInBytes();
+ iterator.remove();
+ }
+ }
+ if (isFinished()) {
+ sinkHandleListener.onFinish(this);
+ }
+ localMemoryManager.getQueryPool().free(localFragmentInstanceId.getQueryId(), freedBytes);
+ }
+
+ String getRemoteHostname() {
+ return remoteHostname;
+ }
+
+ TFragmentInstanceId getRemoteFragmentInstanceId() {
+ return remoteFragmentInstanceId;
+ }
+
+ String getRemotePlanNodeId() {
+ return remotePlanNodeId;
+ }
+
+ TFragmentInstanceId getLocalFragmentInstanceId() {
+ return localFragmentInstanceId;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", SinkHandle.class.getSimpleName() + "[", "]")
+ .add("remoteHostname='" + remoteHostname + "'")
+ .add("remoteFragmentInstanceId=" + remoteFragmentInstanceId)
+ .add("remotePlanNodeId='" + remotePlanNodeId + "'")
+ .add("localFragmentInstanceId=" + localFragmentInstanceId)
+ .toString();
+ }
+
+ /** Sends a {@link NewDataBlockEvent} to the downstream fragment instance, with bounded retries. */
+ class SendNewDataBlockEventTask implements Runnable {
+
+ private final int startSequenceId;
+ private final List<Long> blockSizes;
+
+ SendNewDataBlockEventTask(int startSequenceId, List<Long> blockSizes) {
+ Validate.isTrue(
+ startSequenceId >= 0,
+ "Start sequence ID should be greater than or equal to zero, but was: "
+ + startSequenceId
+ + ".");
+ this.startSequenceId = startSequenceId;
+ this.blockSizes = Validate.notNull(blockSizes);
+ }
+
+ @Override
+ public void run() {
+ logger.debug(
+ "Send new data block event [{}, {}) to plan node {} of {}.",
+ startSequenceId,
+ startSequenceId + blockSizes.size(),
+ remotePlanNodeId,
+ remoteFragmentInstanceId);
+ int attempt = 0;
+ NewDataBlockEvent newDataBlockEvent =
+ new NewDataBlockEvent(
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ startSequenceId,
+ blockSizes);
+ while (attempt < MAX_ATTEMPT_TIMES) {
+ attempt += 1;
+ try {
+ client.onNewDataBlockEvent(newDataBlockEvent);
+ break;
+ } catch (TException e) {
+ logger.error(
+ "Failed to send new data block event to plan node {} of {} due to {}, attempt times: {}",
+ remotePlanNodeId,
+ remoteFragmentInstanceId,
+ e.getMessage(),
+ attempt);
+ if (attempt == MAX_ATTEMPT_TIMES) { // Retries exhausted; record the failure on the handle.
+ synchronized (SinkHandle.this) { // Lock the enclosing handle; a bare "this" is the task itself.
+ throwable = e;
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SourceHandle.java
index 5f8607e67e..7366f3630c 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/SourceHandle.java
@@ -19,72 +19,389 @@
package org.apache.iotdb.db.mpp.buffer;
+import org.apache.iotdb.db.mpp.buffer.DataBlockManager.SourceHandleListener;
+import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
+import org.apache.iotdb.mpp.rpc.thrift.AcknowledgeDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockRequest;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockResponse;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import org.apache.commons.lang3.Validate;
+import org.apache.thrift.TException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import java.util.ArrayDeque;
-import java.util.Queue;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.StringJoiner;
+import java.util.concurrent.ExecutorService;
import static com.google.common.util.concurrent.Futures.nonCancellationPropagating;
public class SourceHandle implements ISourceHandle {
- private final long bufferCapacityInBytes;
+ private static final Logger logger = LoggerFactory.getLogger(SourceHandle.class);
+
+ public static final int MAX_ATTEMPT_TIMES = 3;
+
+ private final String remoteHostname;
+ private final TFragmentInstanceId remoteFragmentInstanceId;
+ private final TFragmentInstanceId localFragmentInstanceId;
+ private final String localPlanNodeId;
+ private final LocalMemoryManager localMemoryManager;
+ private final ExecutorService executorService;
+ private final DataBlockService.Client client;
+ private final TsBlockSerde serde;
+ private final SourceHandleListener sourceHandleListener;
+
+ private final Map<Integer, TsBlock> sequenceIdToTsBlock = new HashMap<>();
+ private final Map<Integer, Long> sequenceIdToDataBlockSize = new HashMap<>();
- private final Queue<TsBlock> bufferedTsBlocks = new ArrayDeque<>();
private volatile SettableFuture<Void> blocked = SettableFuture.create();
- private volatile long bufferRetainedSizeInBytes;
- private boolean finished;
+ private long bufferRetainedSizeInBytes;
+ private int currSequenceId = 0;
+ private int nextSequenceId = 0;
+ private int lastSequenceId = Integer.MAX_VALUE;
+ private int numActiveGetDataBlocksTask = 0;
+ private boolean noMoreTsBlocks;
private boolean closed;
private Throwable throwable;
- public SourceHandle(long bufferCapacityInBytes) {
- Validate.isTrue(bufferCapacityInBytes > 0L, "capacity cannot be less or equal to zero.");
- this.bufferCapacityInBytes = bufferCapacityInBytes;
+ public SourceHandle(
+ String remoteHostname,
+ TFragmentInstanceId remoteFragmentInstanceId,
+ TFragmentInstanceId localFragmentInstanceId,
+ String localPlanNodeId,
+ LocalMemoryManager localMemoryManager,
+ ExecutorService executorService,
+ DataBlockService.Client client,
+ TsBlockSerde serde,
+ SourceHandleListener sourceHandleListener) {
+ this.remoteHostname = Validate.notNull(remoteHostname);
+ this.remoteFragmentInstanceId = Validate.notNull(remoteFragmentInstanceId);
+ this.localFragmentInstanceId = Validate.notNull(localFragmentInstanceId);
+ this.localPlanNodeId = Validate.notNull(localPlanNodeId);
+ this.localMemoryManager = Validate.notNull(localMemoryManager);
+ this.executorService = Validate.notNull(executorService);
+ this.client = Validate.notNull(client);
+ this.serde = Validate.notNull(serde);
+ this.sourceHandleListener = Validate.notNull(sourceHandleListener);
+ bufferRetainedSizeInBytes = 0L;
}
@Override
- public TsBlock receive() {
+ public TsBlock receive() throws IOException {
if (throwable != null) {
- throw new RuntimeException(throwable);
+ throw new IOException(throwable);
}
if (closed) {
- throw new IllegalStateException("Source handle has been closed.");
+ throw new IllegalStateException("Source handle is closed.");
}
- TsBlock tsBlock = bufferedTsBlocks.poll();
- if (tsBlock != null) {
- bufferRetainedSizeInBytes -= getRetainedSizeInBytes(tsBlock);
+ if (!blocked.isDone()) {
+ throw new IllegalStateException("Source handle is blocked.");
}
- if (bufferedTsBlocks.isEmpty() && !finished && blocked.isDone()) {
- blocked = SettableFuture.create();
+ TsBlock tsBlock;
+ synchronized (this) {
+ tsBlock = sequenceIdToTsBlock.remove(currSequenceId);
+ currSequenceId += 1;
+ bufferRetainedSizeInBytes -= tsBlock.getRetainedSizeInBytes();
+ localMemoryManager
+ .getQueryPool()
+ .free(localFragmentInstanceId.getQueryId(), tsBlock.getRetainedSizeInBytes());
+
+ if (sequenceIdToTsBlock.isEmpty() && !isFinished()) {
+ blocked = SettableFuture.create();
+ }
+ }
+ if (isFinished()) {
+ sourceHandleListener.onFinished(this);
}
+ trySubmitGetDataBlocksTask();
return tsBlock;
}
- private long getRetainedSizeInBytes(TsBlock tsBlock) {
- throw new UnsupportedOperationException();
- }
+ private synchronized void trySubmitGetDataBlocksTask() {
+ final int startSequenceId = nextSequenceId;
+ int endSequenceId = nextSequenceId;
+ long reservedBytes = 0L;
+ ListenableFuture<?> future = null;
+ while (sequenceIdToDataBlockSize.containsKey(endSequenceId)) {
+ Long bytesToReserve = sequenceIdToDataBlockSize.get(endSequenceId);
+ if (bytesToReserve == null) {
+ throw new IllegalStateException("Data block size is null.");
+ }
+ future =
+ localMemoryManager
+ .getQueryPool()
+ .reserve(localFragmentInstanceId.getQueryId(), bytesToReserve);
+ if (future.isDone()) {
+ endSequenceId += 1;
+ reservedBytes += bytesToReserve;
+ bufferRetainedSizeInBytes += bytesToReserve;
+ } else {
+ break;
+ }
+ }
- @Override
- public boolean isFinished() {
- return finished;
+ if (future == null) {
+ // Next data block not generated yet. Do nothing.
+ return;
+ }
+
+ if (future.isDone()) {
+ nextSequenceId = endSequenceId;
+ executorService.submit(new GetDataBlocksTask(startSequenceId, endSequenceId, reservedBytes));
+ numActiveGetDataBlocksTask += 1;
+ } else {
+ nextSequenceId = endSequenceId + 1;
+ // The future being not completed indicates,
+ // 1. Memory has been reserved for blocks in [startSequenceId, endSequenceId).
+ // 2. Memory reservation for block whose sequence ID equals endSequenceId is blocked.
+ // 3. Memory has not been reserved for the rest of the blocks.
+ //
+ // startSequenceId endSequenceId endSequenceId + 1
+ // |-------- reserved --------|--- blocked ---|--- not reserved ---|
+
+ if (endSequenceId > startSequenceId) {
+ // Memory has been reserved. Submit a GetDataBlocksTask for these blocks.
+ executorService.submit(
+ new GetDataBlocksTask(startSequenceId, endSequenceId, reservedBytes));
+ numActiveGetDataBlocksTask += 1;
+ }
+
+ // Submit a GetDataBlocksTask when memory is freed.
+ final int sequenceIdOfUnReservedDataBlock = endSequenceId;
+ final long sizeOfUnReservedDataBlock = sequenceIdToDataBlockSize.get(endSequenceId);
+ future.addListener(
+ () -> {
+ executorService.submit(
+ new GetDataBlocksTask(
+ sequenceIdOfUnReservedDataBlock,
+ sequenceIdOfUnReservedDataBlock + 1,
+ sizeOfUnReservedDataBlock));
+ numActiveGetDataBlocksTask += 1;
+ bufferRetainedSizeInBytes += sizeOfUnReservedDataBlock;
+ },
+ executorService);
+
+ // Schedule another call of trySubmitGetDataBlocksTask for the rest of blocks.
+ future.addListener(SourceHandle.this::trySubmitGetDataBlocksTask, executorService);
+ }
}
+ @Override
public ListenableFuture<Void> isBlocked() {
+ if (throwable != null) {
+ throw new RuntimeException(throwable);
+ }
+ if (closed) {
+ throw new IllegalStateException("Source handle is closed.");
+ }
return nonCancellationPropagating(blocked);
}
+ synchronized void setNoMoreTsBlocks(int lastSequenceId) {
+ this.lastSequenceId = lastSequenceId;
+ noMoreTsBlocks = true;
+ }
+
+ synchronized void updatePendingDataBlockInfo(int startSequenceId, List<Long> dataBlockSizes) {
+ for (int i = 0; i < dataBlockSizes.size(); i++) {
+ sequenceIdToDataBlockSize.put(i + startSequenceId, dataBlockSizes.get(i));
+ }
+ trySubmitGetDataBlocksTask();
+ }
+
+ /**
+ * Closes this handle; calling it again after the first call is a no-op.
+ *
+ * <p>Clears the pending block-size map, returns any retained bytes to the query memory pool,
+ * marks the handle as closed and finally notifies {@code sourceHandleListener}.
+ */
@Override
- public void close() {
+ public synchronized void close() {
if (closed) {
return;
}
- bufferedTsBlocks.clear();
- bufferRetainedSizeInBytes = 0;
+ sequenceIdToDataBlockSize.clear();
+ // Only call free() when something is actually retained.
+ if (bufferRetainedSizeInBytes > 0) {
+ localMemoryManager
+ .getQueryPool()
+ .free(localFragmentInstanceId.getQueryId(), bufferRetainedSizeInBytes);
+ bufferRetainedSizeInBytes = 0;
+ }
closed = true;
- if (!blocked.isDone()) {}
+ sourceHandleListener.onClosed(this);
+ }
+
+ /**
+ * Reports whether this handle is done: no failure recorded, no more blocks announced, no
+ * fetch task active, the next sequence ID is one past the last one, and the block buffer is
+ * empty.
+ */
+ @Override
+ public boolean isFinished() {
+ if (throwable != null || !noMoreTsBlocks) {
+ return false;
+ }
+ if (numActiveGetDataBlocksTask != 0) {
+ return false;
+ }
+ return nextSequenceId - 1 == lastSequenceId && sequenceIdToTsBlock.isEmpty();
+ }
+
+ /** Returns the {@code remoteHostname} field. */
+ String getRemoteHostname() {
+ return remoteHostname;
+ }
+
+ /** Returns a deep copy of the remote fragment instance ID, not the internal object. */
+ TFragmentInstanceId getRemoteFragmentInstanceId() {
+ return remoteFragmentInstanceId.deepCopy();
+ }
+
+ /**
+ * Returns the local fragment instance ID object itself (no copy is made).
+ *
+ * <p>NOTE(review): the remote ID getter hands out a deep copy while this one exposes the
+ * internal instance; confirm that the difference is intended.
+ */
+ TFragmentInstanceId getLocalFragmentInstanceId() {
+ return localFragmentInstanceId;
+ }
+
+ /** Returns the {@code localPlanNodeId} field. */
+ String getLocalPlanNodeId() {
+ return localPlanNodeId;
+ }
+
+ /** Returns the current value of {@code bufferRetainedSizeInBytes}; read without locking. */
+ @Override
+ public long getBufferRetainedSizeInBytes() {
+ return bufferRetainedSizeInBytes;
+ }
+
+ /** Returns the {@code closed} flag, which {@code close()} sets to true. */
+ @Override
+ public boolean isClosed() {
+ return closed;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", SourceHandle.class.getSimpleName() + "[", "]")
+ .add("remoteHostname='" + remoteHostname + "'")
+ .add("remoteFragmentInstanceId=" + remoteFragmentInstanceId)
+ .add("localFragmentInstanceId=" + localFragmentInstanceId)
+ .add("localPlanNodeId='" + localPlanNodeId + "'")
+ .toString();
+ }
+
+ /**
+ * Get data blocks from an upstream fragment instance.
+ *
+ * <p>Requests the blocks in {@code [startSequenceId, endSequenceId)} with at most {@code
+ * MAX_ATTEMPT_TIMES} attempts. On success the blocks are buffered, the {@code blocked} future
+ * is completed and an acknowledgement task is submitted. When the last attempt fails, the
+ * error is recorded in {@code throwable} and the memory reserved for this task is released.
+ * The active-task counter is decremented exactly once, however many attempts were made.
+ */
+ class GetDataBlocksTask implements Runnable {
+ private final int startSequenceId;
+ private final int endSequenceId;
+ private final long reservedBytes;
+
+ /**
+ * @param startSequenceId first sequence ID to fetch (inclusive); must be non-negative
+ * @param endSequenceId end of the range (exclusive); must be greater than the start
+ * @param reservedBytes bytes reserved for the blocks of this range; must be positive
+ */
+ GetDataBlocksTask(int startSequenceId, int endSequenceId, long reservedBytes) {
+ Validate.isTrue(
+ startSequenceId >= 0,
+ "Start sequence ID should be greater than or equal to zero. Start sequence ID: "
+ + startSequenceId);
+ this.startSequenceId = startSequenceId;
+ Validate.isTrue(
+ endSequenceId > startSequenceId,
+ "End sequence ID should be greater than the start sequence ID. Start sequence ID: "
+ + startSequenceId
+ + ", end sequence ID: "
+ + endSequenceId);
+ this.endSequenceId = endSequenceId;
+ Validate.isTrue(reservedBytes > 0L, "Reserved bytes should be greater than zero.");
+ this.reservedBytes = reservedBytes;
+ }
+
+ @Override
+ public void run() {
+ logger.debug(
+ "Get data blocks [{}, {}) from {} for plan node {} of {}.",
+ startSequenceId,
+ endSequenceId,
+ remoteFragmentInstanceId,
+ localPlanNodeId,
+ localFragmentInstanceId);
+ GetDataBlockRequest req =
+ new GetDataBlockRequest(remoteFragmentInstanceId, startSequenceId, endSequenceId);
+ try {
+ int attempt = 0;
+ while (attempt < MAX_ATTEMPT_TIMES) {
+ attempt += 1;
+ try {
+ GetDataBlockResponse resp = client.getDataBlock(req);
+ List<TsBlock> tsBlocks = new ArrayList<>(resp.getTsBlocks().size());
+ for (ByteBuffer byteBuffer : resp.getTsBlocks()) {
+ TsBlock tsBlock = serde.deserialize(byteBuffer);
+ tsBlocks.add(tsBlock);
+ }
+ synchronized (SourceHandle.this) {
+ if (closed) {
+ return;
+ }
+ for (int i = startSequenceId; i < endSequenceId; i++) {
+ sequenceIdToTsBlock.put(i, tsBlocks.get(i - startSequenceId));
+ }
+ if (!blocked.isDone()) {
+ blocked.set(null);
+ }
+ }
+ executorService.submit(
+ new SendAcknowledgeDataBlockEventTask(startSequenceId, endSequenceId));
+ break;
+ } catch (TException e) {
+ logger.error(
+ "Failed to get data block from {} due to {}, attempt times: {}",
+ remoteFragmentInstanceId,
+ e.getMessage(),
+ attempt);
+ if (attempt == MAX_ATTEMPT_TIMES) {
+ synchronized (SourceHandle.this) {
+ throwable = e;
+ // close() already returns everything counted in bufferRetainedSizeInBytes to the
+ // pool, so releasing this task's share again after a close would free it twice.
+ if (!closed) {
+ bufferRetainedSizeInBytes -= reservedBytes;
+ localMemoryManager
+ .getQueryPool()
+ .free(localFragmentInstanceId.getQueryId(), reservedBytes);
+ }
+ }
+ }
+ }
+ }
+ } finally {
+ // The counter is incremented once when the task is submitted, so it must be decremented
+ // once per task. Doing it inside the retry loop would decrement once per attempt.
+ synchronized (SourceHandle.this) {
+ numActiveGetDataBlocksTask -= 1;
+ }
+ }
+ // TODO: try to issue another GetDataBlocksTask to make the query run faster.
+ }
+ }
+
+ /**
+ * Tells the upstream fragment instance that the blocks in {@code [startSequenceId,
+ * endSequenceId)} have been received, retrying up to {@code MAX_ATTEMPT_TIMES} times. When
+ * the last attempt fails, the error is recorded in the enclosing handle's {@code throwable}.
+ */
+ class SendAcknowledgeDataBlockEventTask implements Runnable {
+
+ private final int startSequenceId;
+ private final int endSequenceId;
+
+ /**
+ * @param startSequenceId first acknowledged sequence ID (inclusive)
+ * @param endSequenceId end of the acknowledged range (exclusive)
+ */
+ public SendAcknowledgeDataBlockEventTask(int startSequenceId, int endSequenceId) {
+ this.startSequenceId = startSequenceId;
+ this.endSequenceId = endSequenceId;
+ }
+
+ @Override
+ public void run() {
+ logger.debug(
+ "Send ack data block event [{}, {}) to {}.",
+ startSequenceId,
+ endSequenceId,
+ remoteFragmentInstanceId);
+ int attempt = 0;
+ AcknowledgeDataBlockEvent acknowledgeDataBlockEvent =
+ new AcknowledgeDataBlockEvent(remoteFragmentInstanceId, startSequenceId, endSequenceId);
+ while (attempt < MAX_ATTEMPT_TIMES) {
+ attempt += 1;
+ try {
+ client.onAcknowledgeDataBlockEvent(acknowledgeDataBlockEvent);
+ break;
+ } catch (TException e) {
+ logger.error(
+ "Failed to send ack data block event [{}, {}) to {} due to {}, attempt times: {}",
+ startSequenceId,
+ endSequenceId,
+ remoteFragmentInstanceId,
+ e.getMessage(),
+ attempt);
+ if (attempt == MAX_ATTEMPT_TIMES) {
+ // Lock the enclosing handle, not this task: throwable belongs to the SourceHandle
+ // and GetDataBlocksTask guards it with the same monitor.
+ synchronized (SourceHandle.this) {
+ throwable = e;
+ }
+ }
+ }
+ }
+ }
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/StubSinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/StubSinkHandle.java
index b920eb01ad..dc09a1c038 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/StubSinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/StubSinkHandle.java
@@ -22,6 +22,7 @@ import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import com.google.common.util.concurrent.ListenableFuture;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -33,24 +34,44 @@ public class StubSinkHandle implements ISinkHandle {
private final List<TsBlock> tsBlocks = new ArrayList<>();
+ /** Stub: always reports zero retained bytes. */
+ @Override
+ public long getBufferRetainedSizeInBytes() {
+ return 0;
+ }
+
+ /** Stub: always reports zero buffered blocks, regardless of what {@code send} collected. */
+ @Override
+ public int getNumOfBufferedTsBlocks() {
+ return 0;
+ }
+
+ /** Stub: always returns {@code NOT_BLOCKED}. */
@Override
public ListenableFuture<Void> isFull() {
return NOT_BLOCKED;
}
+ /** Appends all given blocks to the in-memory list; this body throws no {@code IOException}. */
@Override
- public void send(TsBlock tsBlock) {
- tsBlocks.add(tsBlock);
+ public void send(List<TsBlock> tsBlocks) throws IOException {
+ this.tsBlocks.addAll(tsBlocks);
}
+ /** Appends all given blocks to the in-memory list; the {@code partition} argument is ignored. */
@Override
- public void send(int partition, TsBlock tsBlock) {
- tsBlocks.add(tsBlock);
+ public void send(int partition, List<TsBlock> tsBlocks) throws IOException {
+ this.tsBlocks.addAll(tsBlocks);
}
+ /** Stub: intentionally does nothing. */
@Override
public void setNoMoreTsBlocks() {}
+ /** Stub: always returns false. */
+ @Override
+ public boolean isClosed() {
+ return false;
+ }
+
+ /** Stub: always returns false. */
+ @Override
+ public boolean isFinished() {
+ return false;
+ }
+
@Override
public void close() {
tsBlocks.clear();
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerde.java
similarity index 60%
copy from server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
copy to server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerde.java
index ef2919c787..8fd0329772 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerde.java
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
@@ -16,26 +16,21 @@
* specific language governing permissions and limitations
* under the License.
*/
+
package org.apache.iotdb.db.mpp.buffer;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
-import com.google.common.util.concurrent.ListenableFuture;
-
-import java.io.Closeable;
-
-public interface ISourceHandle extends Closeable {
-
- /** Get a {@link TsBlock} from the input buffer. */
- TsBlock receive();
-
- /** Check if there are more tsblocks. */
- boolean isFinished();
+import java.nio.ByteBuffer;
- /** Get a future that will be completed when the input buffer is not empty. */
- ListenableFuture<Void> isBlocked();
+/**
+ * Placeholder serializer/deserializer for {@link TsBlock}. Neither direction is implemented yet:
+ * both methods currently return {@code null}.
+ */
+public class TsBlockSerde {
+ /** Not implemented yet; always returns {@code null} and ignores {@code tsBlock}. */
+ public ByteBuffer serialized(TsBlock tsBlock) {
+ // TODO: implement
+ return null;
+ }
- /** Close the handle. Discarding all tsblocks which may still be in memory buffer. */
- @Override
- void close();
+ /** Not implemented yet; always returns {@code null} and ignores {@code buffer}. */
+ public TsBlock deserialize(ByteBuffer buffer) {
+ // TODO: implement
+ return null;
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerdeFactory.java
similarity index 57%
copy from server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
copy to server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerdeFactory.java
index ef2919c787..7aa7329987 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/buffer/ISourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/buffer/TsBlockSerdeFactory.java
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
@@ -16,26 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.iotdb.db.mpp.buffer;
-
-import org.apache.iotdb.tsfile.read.common.block.TsBlock;
-
-import com.google.common.util.concurrent.ListenableFuture;
-
-import java.io.Closeable;
-public interface ISourceHandle extends Closeable {
-
- /** Get a {@link TsBlock} from the input buffer. */
- TsBlock receive();
-
- /** Check if there are more tsblocks. */
- boolean isFinished();
+package org.apache.iotdb.db.mpp.buffer;
- /** Get a future that will be completed when the input buffer is not empty. */
- ListenableFuture<Void> isBlocked();
+import java.util.function.Supplier;
- /** Close the handle. Discarding all tsblocks which may still be in memory buffer. */
+/** Supplier that hands out a new {@link TsBlockSerde} on every call to {@link #get()}. */
+public class TsBlockSerdeFactory implements Supplier<TsBlockSerde> {
@Override
- void close();
+ public TsBlockSerde get() {
+ return new TsBlockSerde();
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/DataDriver.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/DataDriver.java
index 4b401386c9..5a1168fac3 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/DataDriver.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/DataDriver.java
@@ -40,6 +40,7 @@ import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.NotThreadSafe;
import java.io.IOException;
+import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@@ -295,7 +296,7 @@ public class DataDriver implements Driver {
if (root.hasNext()) {
TsBlock tsBlock = root.next();
if (tsBlock != null && !tsBlock.isEmpty()) {
- sinkHandle.send(tsBlock);
+ sinkHandle.send(Collections.singletonList(tsBlock));
}
}
return NOT_BLOCKED;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/SchemaDriver.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/SchemaDriver.java
index 5a7dfb4654..4a1eef21fa 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/SchemaDriver.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/SchemaDriver.java
@@ -32,6 +32,7 @@ import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.NotThreadSafe;
import java.io.IOException;
+import java.util.Collections;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
@@ -115,7 +116,7 @@ public class SchemaDriver implements Driver {
if (root.hasNext()) {
TsBlock tsBlock = root.next();
if (tsBlock != null && !tsBlock.isEmpty()) {
- sinkHandle.send(tsBlock);
+ sinkHandle.send(Collections.singletonList(tsBlock));
}
}
return NOT_BLOCKED;
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/memory/LocalMemoryManager.java b/server/src/main/java/org/apache/iotdb/db/mpp/memory/LocalMemoryManager.java
index 797075f217..08c59d30b2 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/memory/LocalMemoryManager.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/memory/LocalMemoryManager.java
@@ -19,25 +19,22 @@
package org.apache.iotdb.db.mpp.memory;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+
/**
* Manages memory of a data node. The memory is divided into two memory pools so that the memory for
* read and for write can be isolated.
*/
public class LocalMemoryManager {
- private final long maxBytes;
private final MemoryPool queryPool;
public LocalMemoryManager() {
- long maxMemory = Runtime.getRuntime().maxMemory();
- // Save 20% memory for untracked allocations.
- maxBytes = (long) (maxMemory * 0.8);
- // Allocate 50% memory for query execution.
- queryPool = new MemoryPool("query", (long) (maxBytes * 0.5));
- }
-
- public long getMaxBytes() {
- return maxBytes;
+ // The pool's total size is the configured read memory (second argument); a single query may
+ // reserve at most half of it (third argument).
+ queryPool =
+ new MemoryPool(
+ "query",
+ IoTDBDescriptor.getInstance().getConfig().getAllocateMemoryForRead(),
+ (long) (IoTDBDescriptor.getInstance().getConfig().getAllocateMemoryForRead() * 0.5));
}
public MemoryPool getQueryPool() {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/memory/MemoryPool.java
index aa42e75efe..f634b71a66 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/memory/MemoryPool.java
@@ -19,24 +19,66 @@
package org.apache.iotdb.db.mpp.memory;
+import com.google.common.util.concurrent.AbstractFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
import org.apache.commons.lang3.Validate;
+import javax.annotation.Nullable;
+
import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedList;
import java.util.Map;
+import java.util.Queue;
-/** Manages certain amount of memory. */
+/** A thread-safe memory pool. */
public class MemoryPool {
+ /**
+ * A pending memory reservation: remembers which query asked for how many bytes so that the
+ * pool can grant the request and complete the future once memory is freed.
+ */
+ private static class MemoryReservationFuture<V> extends AbstractFuture<V> {
+ private final long bytes;
+ private final String queryId;
+
+ private MemoryReservationFuture(String queryId, long bytes) {
+ Validate.notNull(queryId, "queryId cannot be null");
+ Validate.isTrue(bytes > 0L, "bytes should be greater than zero.");
+ this.queryId = queryId;
+ this.bytes = bytes;
+ }
+
+ /** Factory for a not-yet-completed reservation of {@code bytes} for {@code queryId}. */
+ public static <V> MemoryReservationFuture<V> create(String queryId, long bytes) {
+ return new MemoryReservationFuture<>(queryId, bytes);
+ }
+
+ public String getQueryId() {
+ return queryId;
+ }
+
+ public long getBytes() {
+ return bytes;
+ }
+
+ /** Re-declared as public so code outside the future can complete it. */
+ @Override
+ public boolean set(@Nullable V value) {
+ return super.set(value);
+ }
+ }
+
private final String id;
private final long maxBytes;
+ private final long maxBytesPerQuery;
private long reservedBytes = 0L;
private final Map<String, Long> queryMemoryReservations = new HashMap<>();
+ private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures = new LinkedList<>();
+ /**
+ * Creates a pool.
+ *
+ * @param id pool identifier; must not be null
+ * @param maxBytes total capacity; must be greater than zero
+ * @param maxBytesPerQuery cap per query; must be in {@code (0, maxBytes]}
+ */
- public MemoryPool(String id, long maxBytes) {
+ public MemoryPool(String id, long maxBytes, long maxBytesPerQuery) {
this.id = Validate.notNull(id);
- Validate.isTrue(maxBytes > 0L);
+ Validate.isTrue(maxBytes > 0L, "max bytes should be greater than zero.");
this.maxBytes = maxBytes;
+ Validate.isTrue(
+ maxBytesPerQuery > 0L && maxBytesPerQuery <= maxBytes,
+ "max bytes per query should be greater than zero while less than or equal to max bytes.");
+ this.maxBytesPerQuery = maxBytesPerQuery;
}
public String getId() {
@@ -47,15 +89,41 @@ public class MemoryPool {
return maxBytes;
}
+ public ListenableFuture<Void> reserve(String queryId, long bytes) {
+ Validate.notNull(queryId);
+ Validate.isTrue(
+ bytes > 0L && bytes <= maxBytesPerQuery,
+ "bytes should be greater than zero while less than or equal to max bytes per query.");
+
+ ListenableFuture<Void> result;
+ synchronized (this) {
+ if (maxBytes - reservedBytes < bytes
+ || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
+ result = MemoryReservationFuture.create(queryId, bytes);
+ memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+ } else {
+ reservedBytes += bytes;
+ queryMemoryReservations.merge(queryId, bytes, Long::sum);
+ result = Futures.immediateFuture(null);
+ }
+ }
+
+ return result;
+ }
+
public boolean tryReserve(String queryId, long bytes) {
Validate.notNull(queryId);
- Validate.isTrue(bytes > 0L);
+ Validate.isTrue(
+ bytes > 0L && bytes <= maxBytesPerQuery,
+ "bytes should be greater than zero while less than or equal to max bytes per query.");
- if (maxBytes - reservedBytes < bytes) {
+ if (maxBytes - reservedBytes < bytes
+ || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
return false;
}
synchronized (this) {
- if (maxBytes - reservedBytes < bytes) {
+ if (maxBytes - reservedBytes < bytes
+ || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
return false;
}
reservedBytes += bytes;
@@ -79,11 +147,35 @@ public class MemoryPool {
} else {
queryMemoryReservations.put(queryId, queryReservedBytes);
}
-
reservedBytes -= bytes;
+
+ if (memoryReservationFutures.isEmpty()) {
+ return;
+ }
+ Iterator<MemoryReservationFuture<Void>> iterator = memoryReservationFutures.iterator();
+ while (iterator.hasNext()) {
+ MemoryReservationFuture<Void> future = iterator.next();
+
+ if (future.isCancelled()) {
+ iterator.remove();
+ continue;
+ }
+
+ long bytesToReserve = future.getBytes();
+ if (maxBytes - reservedBytes < bytesToReserve) {
+ return;
+ }
+ if (maxBytesPerQuery - queryMemoryReservations.getOrDefault(future.getQueryId(), 0L)
+ >= bytesToReserve) {
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations.merge(future.getQueryId(), bytesToReserve, Long::sum);
+ future.set(null);
+ iterator.remove();
+ }
+ }
}
+ /**
+ * Returns the bytes currently reserved by {@code queryId}, or 0 when it holds none.
+ *
+ * <p>Stays synchronized because {@code queryMemoryReservations} is a plain {@code HashMap}
+ * that other methods of this pool modify while holding this object's monitor; an unlocked
+ * read could race with those updates.
+ */
public synchronized long getQueryMemoryReservedBytes(String queryId) {
return queryMemoryReservations.getOrDefault(queryId, 0L);
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SinkHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SinkHandleTest.java
new file mode 100644
index 0000000000..0c03412f39
--- /dev/null
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SinkHandleTest.java
@@ -0,0 +1,448 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.buffer;
+
+import org.apache.iotdb.db.mpp.buffer.DataBlockManager.SinkHandleListener;
+import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
+import org.apache.iotdb.db.mpp.memory.MemoryPool;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService.Client;
+import org.apache.iotdb.mpp.rpc.thrift.EndOfDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.NewDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import org.apache.thrift.TException;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Executors;
+
+public class SinkHandleTest {
+
+ @Test
+ public void testOneTimeNotBlockedSend() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f0", "0");
+ final String remotePlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f1", "0");
+
+ // Construct a mock LocalMemoryManager that returns unblocked futures.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool = Utils.createMockNonBlockedMemoryPool();
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+ // Construct a mock client.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doNothing()
+ .when(mockClient)
+ .onEndOfDataBlockEvent(Mockito.any(EndOfDataBlockEvent.class));
+ Mockito.doNothing()
+ .when(mockClient)
+ .onNewDataBlockEvent(Mockito.any(NewDataBlockEvent.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // Construct a mock SinkHandleListener.
+ SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+ // Construct several mock TsBlock(s).
+ List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
+
+ // Construct SinkHandle.
+ SinkHandle sinkHandle =
+ new SinkHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ new TsBlockSerde(),
+ mockSinkHandleListener);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+
+ // Send tsblocks.
+ try {
+ sinkHandle.send(mockTsBlocks);
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ sinkHandle.setNoMoreTsBlocks();
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(
+ mockTsBlockSize * numOfMockTsBlock, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+ Mockito.verify(mockMemoryPool, Mockito.times(1))
+ .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onNewDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && e.getStartSequenceId() == 0
+ && e.getBlockSizes().size() == numOfMockTsBlock));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Get tsblocks.
+ for (int i = 0; i < numOfMockTsBlock; i++) {
+ sinkHandle.getSerializedTsBlock(i);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ }
+ Assert.assertFalse(sinkHandle.isFinished());
+
+ // Ack tsblocks.
+ sinkHandle.acknowledgeTsBlock(0, numOfMockTsBlock);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertTrue(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Mockito.verify(mockMemoryPool, Mockito.times(1))
+ .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+ Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onFinish(sinkHandle);
+
+ // Close the SinkHandle.
+ try {
+ sinkHandle.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertTrue(sinkHandle.isClosed());
+ Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onClosed(sinkHandle);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onEndOfDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && numOfMockTsBlock - 1 == e.getLastSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ }
+
+ @Test
+ public void testMultiTimesBlockedSend() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f0", "0");
+ final String remotePlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f1", "0");
+
+ // Construct a mock LocalMemoryManager that returns blocked futures.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool =
+ Utils.createMockBlockedMemoryPool(queryId, numOfMockTsBlock, mockTsBlockSize);
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+
+ // Construct a mock SinkHandleListener.
+ SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+ // Construct several mock TsBlock(s).
+ List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
+ // Construct a mock client.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doNothing()
+ .when(mockClient)
+ .onEndOfDataBlockEvent(Mockito.any(EndOfDataBlockEvent.class));
+ Mockito.doNothing()
+ .when(mockClient)
+ .onNewDataBlockEvent(Mockito.any(NewDataBlockEvent.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Construct SinkHandle.
+ SinkHandle sinkHandle =
+ new SinkHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ new TsBlockSerde(),
+ mockSinkHandleListener);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+
+ // Send tsblocks.
+ try {
+ sinkHandle.send(mockTsBlocks);
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertFalse(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(
+ mockTsBlockSize * numOfMockTsBlock, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+ Mockito.verify(mockMemoryPool, Mockito.times(1))
+ .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onNewDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && e.getStartSequenceId() == 0
+ && e.getBlockSizes().size() == numOfMockTsBlock));
+ } catch (TException | InterruptedException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Get tsblocks.
+ for (int i = 0; i < numOfMockTsBlock; i++) {
+ sinkHandle.getSerializedTsBlock(i);
+ Assert.assertFalse(sinkHandle.isFull().isDone());
+ }
+ Assert.assertFalse(sinkHandle.isFinished());
+
+ // Ack tsblocks.
+ sinkHandle.acknowledgeTsBlock(0, numOfMockTsBlock);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Mockito.verify(mockMemoryPool, Mockito.times(1))
+ .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+
+ // Send tsblocks.
+ try {
+ sinkHandle.send(mockTsBlocks);
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertFalse(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(
+ mockTsBlockSize * numOfMockTsBlock, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+ Mockito.verify(mockMemoryPool, Mockito.times(2))
+ .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onNewDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && e.getStartSequenceId() == numOfMockTsBlock
+ && e.getBlockSizes().size() == numOfMockTsBlock));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Close the SinkHandle.
+ sinkHandle.setNoMoreTsBlocks();
+ Assert.assertFalse(sinkHandle.isFinished());
+ try {
+ sinkHandle.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertTrue(sinkHandle.isClosed());
+ Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onClosed(sinkHandle);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onEndOfDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && numOfMockTsBlock * 2 - 1 == e.getLastSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Get tsblocks after the SinkHandle is closed.
+ for (int i = numOfMockTsBlock; i < numOfMockTsBlock * 2; i++) {
+ sinkHandle.getSerializedTsBlock(i);
+ }
+ Assert.assertFalse(sinkHandle.isFinished());
+
+ // Ack tsblocks.
+ sinkHandle.acknowledgeTsBlock(numOfMockTsBlock, numOfMockTsBlock * 2);
+ Assert.assertTrue(sinkHandle.isFinished());
+ Assert.assertTrue(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Mockito.verify(mockMemoryPool, Mockito.times(2))
+ .free(queryId, numOfMockTsBlock * mockTsBlockSize);
+ Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onFinish(sinkHandle);
+ }
+
+ @Test
+ public void testFailedSend() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f0", "0");
+ final String remotePlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f1", "0");
+
+ // Construct a mock LocalMemoryManager that returns blocked futures.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool =
+ Utils.createMockBlockedMemoryPool(queryId, numOfMockTsBlock, mockTsBlockSize);
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+ // Construct a mock SinkHandleListener.
+ SinkHandleListener mockSinkHandleListener = Mockito.mock(SinkHandleListener.class);
+ // Construct several mock TsBlock(s).
+ List<TsBlock> mockTsBlocks = Utils.createMockTsBlocks(numOfMockTsBlock, mockTsBlockSize);
+ // Construct a mock client.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doThrow(new TException("Mock exception"))
+ .when(mockClient)
+ .onEndOfDataBlockEvent(Mockito.any(EndOfDataBlockEvent.class));
+ Mockito.doThrow(new TException("Mock exception"))
+ .when(mockClient)
+ .onNewDataBlockEvent(Mockito.any(NewDataBlockEvent.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ // Construct SinkHandle.
+ SinkHandle sinkHandle =
+ new SinkHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ remotePlanNodeId,
+ localFragmentInstanceId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ new TsBlockSerde(),
+ mockSinkHandleListener);
+ Assert.assertTrue(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(0L, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(0, sinkHandle.getNumOfBufferedTsBlocks());
+
+ // Send tsblocks.
+ try {
+ sinkHandle.send(mockTsBlocks);
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ sinkHandle.setNoMoreTsBlocks();
+ Assert.assertFalse(sinkHandle.isFull().isDone());
+ Assert.assertFalse(sinkHandle.isFinished());
+ Assert.assertFalse(sinkHandle.isClosed());
+ Assert.assertEquals(
+ mockTsBlockSize * numOfMockTsBlock, sinkHandle.getBufferRetainedSizeInBytes());
+ Assert.assertEquals(numOfMockTsBlock, sinkHandle.getNumOfBufferedTsBlocks());
+ Mockito.verify(mockMemoryPool, Mockito.times(1))
+ .reserve(queryId, mockTsBlockSize * numOfMockTsBlock);
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(SinkHandle.MAX_ATTEMPT_TIMES))
+ .onNewDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getTargetFragmentInstanceId())
+ && remotePlanNodeId.equals(e.getTargetPlanNodeId())
+ && localFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && e.getStartSequenceId() == 0
+ && e.getBlockSizes().size() == numOfMockTsBlock));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ try {
+ sinkHandle.send(Collections.singletonList(Mockito.mock(TsBlock.class)));
+ Assert.fail("Expect an IOException.");
+ } catch (IOException e) {
+ Assert.assertEquals("org.apache.thrift.TException: Mock exception", e.getMessage());
+ }
+
+ // Close the SinkHandle.
+ try {
+ sinkHandle.close();
+ Assert.fail("Expect an IOException.");
+ } catch (IOException e) {
+ Assert.assertEquals("org.apache.thrift.TException: Mock exception", e.getMessage());
+ }
+ Assert.assertFalse(sinkHandle.isClosed());
+ Mockito.verify(mockSinkHandleListener, Mockito.times(0)).onClosed(sinkHandle);
+
+ // Abort the SinkHandle.
+ sinkHandle.abort();
+ Assert.assertTrue(sinkHandle.isClosed());
+ Mockito.verify(mockSinkHandleListener, Mockito.times(1)).onAborted(sinkHandle);
+ Mockito.verify(mockSinkHandleListener, Mockito.times(0)).onFinish(sinkHandle);
+ }
+}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SourceHandleTest.java
new file mode 100644
index 0000000000..77b987a1c1
--- /dev/null
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/SourceHandleTest.java
@@ -0,0 +1,591 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.buffer;
+
+import org.apache.iotdb.db.mpp.buffer.DataBlockManager.SourceHandleListener;
+import org.apache.iotdb.db.mpp.memory.LocalMemoryManager;
+import org.apache.iotdb.db.mpp.memory.MemoryPool;
+import org.apache.iotdb.mpp.rpc.thrift.AcknowledgeDataBlockEvent;
+import org.apache.iotdb.mpp.rpc.thrift.DataBlockService.Client;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockRequest;
+import org.apache.iotdb.mpp.rpc.thrift.GetDataBlockResponse;
+import org.apache.iotdb.mpp.rpc.thrift.TFragmentInstanceId;
+
+import org.apache.thrift.TException;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class SourceHandleTest {
+
+ /**
+ * Exercises a {@link SourceHandle} whose memory pool never blocks a reservation: a single
+ * new-data-block event covering {@code numOfMockTsBlock} blocks must result in exactly one
+ * {@code getDataBlock} call and one acknowledgement for sequence ids [0, numOfMockTsBlock).
+ * The blocks are then consumed one by one through {@code receive()}, and the handle is expected
+ * to report finished once the end-of-data event arrives and closed after {@code close()}.
+ */
+ @Test
+ public void testNonBlockedOneTimeReceive() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f1", "0");
+ final String localPlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f0", "0");
+
+ // Construct a mock LocalMemoryManager that does not block any reservation.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool = Utils.createMockNonBlockedMemoryPool();
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+ // Construct a mock client that answers every getDataBlock request with one empty buffer per
+ // requested sequence id.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doAnswer(
+ invocation -> {
+ GetDataBlockRequest req = invocation.getArgument(0);
+ List<ByteBuffer> byteBuffers =
+ new ArrayList<>(req.getEndSequenceId() - req.getStartSequenceId());
+ for (int i = 0; i < req.getEndSequenceId() - req.getStartSequenceId(); i++) {
+ byteBuffers.add(ByteBuffer.allocate(0));
+ }
+ return new GetDataBlockResponse(byteBuffers);
+ })
+ .when(mockClient)
+ .getDataBlock(Mockito.any(GetDataBlockRequest.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // Construct a mock SourceHandleListener.
+ SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
+ // Construct a mock TsBlockSerde that deserializes any bytebuffer into a mock TsBlock.
+ TsBlockSerde mockTsBlockSerde = Utils.createMockTsBlockSerde(mockTsBlockSize);
+
+ SourceHandle sourceHandle =
+ new SourceHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId,
+ localPlanNodeId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ mockTsBlockSerde,
+ mockSourceHandleListener);
+ // A freshly created handle has nothing buffered, so it is blocked, open and unfinished.
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // New data blocks event arrived.
+ sourceHandle.updatePendingDataBlockInfo(
+ 0,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ // Wait briefly before verifying; presumably the fetch runs asynchronously on the executor
+ // handed to the SourceHandle.
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .getDataBlock(
+ Mockito.argThat(
+ req ->
+ remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
+ && 0 == req.getStartSequenceId()
+ && numOfMockTsBlock == req.getEndSequenceId()));
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onAcknowledgeDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && 0 == e.getStartSequenceId()
+ && numOfMockTsBlock == e.getEndSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // All blocks are buffered now, so the handle is unblocked and retains their total size.
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ numOfMockTsBlock * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The local fragment instance consumes the data blocks.
+ for (int i = 0; i < numOfMockTsBlock; i++) {
+ try {
+ sourceHandle.receive();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // The handle stays unblocked until the last buffered block has been taken.
+ if (i < numOfMockTsBlock - 1) {
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ } else {
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ }
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ (numOfMockTsBlock - 1 - i) * mockTsBlockSize,
+ sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ // Receive EndOfDataBlock event from upstream fragment instance.
+ sourceHandle.setNoMoreTsBlocks(numOfMockTsBlock - 1);
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ sourceHandle.close();
+ Assert.assertTrue(sourceHandle.isClosed());
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ /**
+ * Exercises a {@link SourceHandle} backed by a real (spied) {@link MemoryPool} that is too small
+ * to hold all announced blocks at once. The test expects the first fetch to cover only sequence
+ * ids [0, 5), and every subsequent {@code receive()} among the first five to trigger one more
+ * single-block fetch, so that the buffer stays at five blocks until the remaining ones drain.
+ */
+ @Test
+ public void testBlockedOneTimeReceive() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f1", "0");
+ final String localPlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f0", "0");
+
+ // Construct a mock LocalMemoryManager whose query pool is a spied MemoryPool created with the
+ // limits 10 * mockTsBlockSize and 5 * mockTsBlockSize (presumably the total and the per-query
+ // maximum; confirm against MemoryPool).
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool spyMemoryPool =
+ Mockito.spy(new MemoryPool("test", 10 * mockTsBlockSize, 5 * mockTsBlockSize));
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(spyMemoryPool);
+ // Construct a mock client that answers every getDataBlock request with one empty buffer per
+ // requested sequence id.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doAnswer(
+ invocation -> {
+ GetDataBlockRequest req = invocation.getArgument(0);
+ List<ByteBuffer> byteBuffers =
+ new ArrayList<>(req.getEndSequenceId() - req.getStartSequenceId());
+ for (int i = 0; i < req.getEndSequenceId() - req.getStartSequenceId(); i++) {
+ byteBuffers.add(ByteBuffer.allocate(0));
+ }
+ return new GetDataBlockResponse(byteBuffers);
+ })
+ .when(mockClient)
+ .getDataBlock(Mockito.any(GetDataBlockRequest.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // Construct a mock SourceHandleListener.
+ SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
+ // Construct a mock TsBlockSerde that deserializes any bytebuffer into a mock TsBlock.
+ TsBlockSerde mockTsBlockSerde = Utils.createMockTsBlockSerde(mockTsBlockSize);
+
+ SourceHandle sourceHandle =
+ new SourceHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId,
+ localPlanNodeId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ mockTsBlockSerde,
+ mockSourceHandleListener);
+ // A freshly created handle has nothing buffered, so it is blocked, open and unfinished.
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // New data blocks event arrived.
+ sourceHandle.updatePendingDataBlockInfo(
+ 0,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ // Wait briefly before verifying; presumably the fetch runs asynchronously on the executor
+ // handed to the SourceHandle.
+ Thread.sleep(100L);
+ // Six single-block reservations are expected while only five blocks are fetched; presumably
+ // the sixth reservation is the one that blocks on the pool limit.
+ Mockito.verify(spyMemoryPool, Mockito.times(6)).reserve(queryId, mockTsBlockSize);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .getDataBlock(
+ Mockito.argThat(
+ req ->
+ remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
+ && 0 == req.getStartSequenceId()
+ && 5 == req.getEndSequenceId()));
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onAcknowledgeDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && 0 == e.getStartSequenceId()
+ && 5 == e.getEndSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(5 * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The local fragment instance consumes the data blocks.
+ for (int i = 0; i < numOfMockTsBlock; i++) {
+ // Before the i-th receive, exactly i blocks must have been freed back to the pool.
+ Mockito.verify(spyMemoryPool, Mockito.times(i)).free(queryId, mockTsBlockSize);
+ try {
+ sourceHandle.receive();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ try {
+ Thread.sleep(100L);
+ if (i < 5) {
+ // While blocks 5..9 are still pending, each consumed block is expected to be replaced by
+ // a single-block fetch of sequence id 5 + i, keeping five blocks buffered.
+ Assert.assertEquals(5 * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
+ final int startSequenceId = 5 + i;
+ Mockito.verify(mockClient, Mockito.times(1))
+ .getDataBlock(
+ Mockito.argThat(
+ req ->
+ remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
+ && startSequenceId == req.getStartSequenceId()
+ && startSequenceId + 1 == req.getEndSequenceId()));
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onAcknowledgeDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && startSequenceId == e.getStartSequenceId()
+ && startSequenceId + 1 == e.getEndSequenceId()));
+ } else {
+ // Nothing is left to fetch, so the buffer shrinks by one block per receive.
+ Assert.assertEquals(
+ (numOfMockTsBlock - 1 - i) * mockTsBlockSize,
+ sourceHandle.getBufferRetainedSizeInBytes());
+ }
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // The handle stays unblocked until the last buffered block has been taken.
+ if (i < numOfMockTsBlock - 1) {
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ } else {
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ }
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ }
+
+ // Receive EndOfDataBlock event from upstream fragment instance.
+ sourceHandle.setNoMoreTsBlocks(numOfMockTsBlock - 1);
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ sourceHandle.close();
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ /**
+ * Exercises a {@link SourceHandle} that receives several new-data-block events, the first of
+ * which arrives out of order. The test expects no fetch while sequence ids
+ * [0, numOfMockTsBlock) are unknown, one fetch for [0, 2 * numOfMockTsBlock) once the gap is
+ * filled, and a further fetch for [2 * numOfMockTsBlock, 3 * numOfMockTsBlock) after the
+ * buffered blocks have been consumed and a third event arrives.
+ */
+ @Test
+ public void testMultiTimesReceive() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f1", "0");
+ final String localPlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f0", "0");
+
+ // Construct a mock LocalMemoryManager that returns unblocked futures.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool = Utils.createMockNonBlockedMemoryPool();
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+ // Construct a mock SourceHandleListener.
+ SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
+ // Construct a mock TsBlockSerde that deserializes any bytebuffer into a mock TsBlock.
+ TsBlockSerde mockTsBlockSerde = Utils.createMockTsBlockSerde(mockTsBlockSize);
+ // Construct a mock client that answers every getDataBlock request with one empty buffer per
+ // requested sequence id.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doAnswer(
+ invocation -> {
+ GetDataBlockRequest req = invocation.getArgument(0);
+ List<ByteBuffer> byteBuffers =
+ new ArrayList<>(req.getEndSequenceId() - req.getStartSequenceId());
+ for (int i = 0; i < req.getEndSequenceId() - req.getStartSequenceId(); i++) {
+ byteBuffers.add(ByteBuffer.allocate(0));
+ }
+ return new GetDataBlockResponse(byteBuffers);
+ })
+ .when(mockClient)
+ .getDataBlock(Mockito.any(GetDataBlockRequest.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ SourceHandle sourceHandle =
+ new SourceHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId,
+ localPlanNodeId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ mockTsBlockSerde,
+ mockSourceHandleListener);
+ // A freshly created handle has nothing buffered, so it is blocked, open and unfinished.
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // A new-data-block event arrives out of order: it starts at sequence id numOfMockTsBlock
+ // while the blocks [0, numOfMockTsBlock) have not been announced yet.
+ sourceHandle.updatePendingDataBlockInfo(
+ numOfMockTsBlock,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ // Wait briefly before verifying that nothing was fetched or acknowledged.
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(0))
+ .getDataBlock(Mockito.any(GetDataBlockRequest.class));
+ Mockito.verify(mockClient, Mockito.times(0))
+ .onAcknowledgeDataBlockEvent(Mockito.any(AcknowledgeDataBlockEvent.class));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The missing event for [0, numOfMockTsBlock) arrives; both ranges are expected to be fetched
+ // with a single request.
+ sourceHandle.updatePendingDataBlockInfo(
+ 0,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .getDataBlock(
+ Mockito.argThat(
+ req ->
+ remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
+ && 0 == req.getStartSequenceId()
+ && numOfMockTsBlock * 2 == req.getEndSequenceId()));
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onAcknowledgeDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && 0 == e.getStartSequenceId()
+ && numOfMockTsBlock * 2 == e.getEndSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ numOfMockTsBlock * 2 * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The local fragment instance consumes the data blocks.
+ for (int i = 0; i < 2 * numOfMockTsBlock; i++) {
+ try {
+ sourceHandle.receive();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // The handle stays unblocked until the last buffered block has been taken.
+ if (i < 2 * numOfMockTsBlock - 1) {
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ } else {
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ }
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ (2 * numOfMockTsBlock - 1 - i) * mockTsBlockSize,
+ sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ // New data blocks event arrived.
+ sourceHandle.updatePendingDataBlockInfo(
+ numOfMockTsBlock * 2,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ Thread.sleep(100L);
+ Mockito.verify(mockClient, Mockito.times(1))
+ .getDataBlock(
+ Mockito.argThat(
+ req ->
+ remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
+ && numOfMockTsBlock * 2 == req.getStartSequenceId()
+ && numOfMockTsBlock * 3 == req.getEndSequenceId()));
+ Mockito.verify(mockClient, Mockito.times(1))
+ .onAcknowledgeDataBlockEvent(
+ Mockito.argThat(
+ e ->
+ remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
+ && numOfMockTsBlock * 2 == e.getStartSequenceId()
+ && numOfMockTsBlock * 3 == e.getEndSequenceId()));
+ } catch (InterruptedException | TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ numOfMockTsBlock * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The local fragment instance consumes the data blocks.
+ for (int i = 0; i < numOfMockTsBlock; i++) {
+ try {
+ sourceHandle.receive();
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ if (i < numOfMockTsBlock - 1) {
+ Assert.assertTrue(sourceHandle.isBlocked().isDone());
+ } else {
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ }
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(
+ (numOfMockTsBlock - 1 - i) * mockTsBlockSize,
+ sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ // Receive EndOfDataBlock event from upstream fragment instance.
+ sourceHandle.setNoMoreTsBlocks(3 * numOfMockTsBlock - 1);
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ sourceHandle.close();
+ Assert.assertTrue(sourceHandle.isClosed());
+ Assert.assertTrue(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+ }
+
+ /**
+ * Exercises a {@link SourceHandle} whose client fails every {@code getDataBlock} call with a
+ * {@link TException}. After a new-data-block event, {@code isBlocked()} is expected to throw a
+ * {@link RuntimeException} and {@code receive()} an {@link IOException}, both carrying the
+ * original failure in their message. The handle must buffer nothing and must not report
+ * finished, even after the end-of-data event and {@code close()}.
+ */
+ @Test
+ public void testFailedReceive() {
+ final String queryId = "q0";
+ final long mockTsBlockSize = 1024L * 1024L;
+ final int numOfMockTsBlock = 10;
+ final String remoteHostname = "remote";
+ final TFragmentInstanceId remoteFragmentInstanceId =
+ new TFragmentInstanceId(queryId, "f1", "0");
+ final String localPlanNodeId = "exchange_0";
+ final TFragmentInstanceId localFragmentInstanceId = new TFragmentInstanceId(queryId, "f0", "0");
+
+ // Construct a mock LocalMemoryManager that returns unblocked futures.
+ LocalMemoryManager mockLocalMemoryManager = Mockito.mock(LocalMemoryManager.class);
+ MemoryPool mockMemoryPool = Utils.createMockNonBlockedMemoryPool();
+ Mockito.when(mockLocalMemoryManager.getQueryPool()).thenReturn(mockMemoryPool);
+ // Construct a mock SourceHandleListener.
+ SourceHandleListener mockSourceHandleListener = Mockito.mock(SourceHandleListener.class);
+ // Construct a mock TsBlockSerde that deserializes any bytebuffer into a mock TsBlock.
+ TsBlockSerde mockTsBlockSerde = Utils.createMockTsBlockSerde(mockTsBlockSize);
+ // Construct a mock client whose getDataBlock always fails.
+ Client mockClient = Mockito.mock(Client.class);
+ try {
+ Mockito.doThrow(new TException("Mock exception"))
+ .when(mockClient)
+ .getDataBlock(Mockito.any(GetDataBlockRequest.class));
+ } catch (TException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+
+ SourceHandle sourceHandle =
+ new SourceHandle(
+ remoteHostname,
+ remoteFragmentInstanceId,
+ localFragmentInstanceId,
+ localPlanNodeId,
+ mockLocalMemoryManager,
+ Executors.newSingleThreadExecutor(),
+ mockClient,
+ mockTsBlockSerde,
+ mockSourceHandleListener);
+ // A freshly created handle has nothing buffered, so it is blocked, open and unfinished.
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // Receive data blocks from upstream fragment instance.
+ // New data blocks event arrived.
+ sourceHandle.updatePendingDataBlockInfo(
+ 0,
+ Stream.generate(() -> mockTsBlockSize)
+ .limit(numOfMockTsBlock)
+ .collect(Collectors.toList()));
+ try {
+ // Wait briefly so the failing fetch can be attempted before the handle is inspected.
+ Thread.sleep(100L);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ Assert.fail();
+ }
+ // isBlocked() itself is expected to throw. Assert.fail raises an AssertionError, which the
+ // RuntimeException handler below does not swallow.
+ try {
+ Assert.assertFalse(sourceHandle.isBlocked().isDone());
+ Assert.fail("Expect a RuntimeException.");
+ } catch (RuntimeException e) {
+ Assert.assertEquals("org.apache.thrift.TException: Mock exception", e.getMessage());
+ }
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ // The local fragment instance tries to consume a data block and must see the fetch failure.
+ try {
+ sourceHandle.receive();
+ Assert.fail("Expect an IOException.");
+ } catch (IOException e) {
+ Assert.assertEquals("org.apache.thrift.TException: Mock exception", e.getMessage());
+ }
+
+ // Receive EndOfDataBlock event from upstream fragment instance.
+ sourceHandle.setNoMoreTsBlocks(numOfMockTsBlock - 1);
+ Assert.assertFalse(sourceHandle.isClosed());
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+
+ sourceHandle.close();
+ Assert.assertFalse(sourceHandle.isFinished());
+ Assert.assertEquals(0L, sourceHandle.getBufferRetainedSizeInBytes());
+ }
+}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/buffer/Utils.java b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/Utils.java
new file mode 100644
index 0000000000..e61e623937
--- /dev/null
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/buffer/Utils.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.mpp.buffer;
+
+import org.apache.iotdb.db.mpp.memory.MemoryPool;
+import org.apache.iotdb.tsfile.read.common.block.TsBlock;
+
+import com.google.common.util.concurrent.SettableFuture;
+import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static com.google.common.util.concurrent.Futures.immediateFuture;
+
+/** Factory methods for the Mockito test doubles shared by the data block manager unit tests. */
+public class Utils {
+ /**
+ * Builds a list of mock {@link TsBlock}s that all report the same retained size.
+ *
+ * @param numOfTsBlocks number of mock blocks to create
+ * @param mockTsBlockSize value each block returns from {@code getRetainedSizeInBytes()}
+ * @return a list holding {@code numOfTsBlocks} distinct mock blocks
+ */
+ public static List<TsBlock> createMockTsBlocks(int numOfTsBlocks, long mockTsBlockSize) {
+ List<TsBlock> blocks = new ArrayList<>(numOfTsBlocks);
+ for (int remaining = numOfTsBlocks; remaining > 0; remaining--) {
+ TsBlock block = Mockito.mock(TsBlock.class);
+ Mockito.when(block.getRetainedSizeInBytes()).thenReturn(mockTsBlockSize);
+ blocks.add(block);
+ }
+ return blocks;
+ }
+
+ /**
+ * Builds a mock {@link MemoryPool} that tracks the bytes reserved for {@code queryId}.
+ *
+ * <p>{@code reserve} adds to the tracked total and hands out a fresh, uncompleted future;
+ * {@code free} subtracts from the total and completes the most recently handed-out future once
+ * the total is no longer positive; {@code tryReserve} succeeds only while the total stays within
+ * {@code numOfMockTsBlock * mockTsBlockSize} bytes.
+ *
+ * @param queryId the only query id the stubs respond to
+ * @param numOfMockTsBlock capacity of the pool, counted in blocks
+ * @param mockTsBlockSize size of one block in bytes
+ * @return the stubbed mock pool
+ */
+ public static MemoryPool createMockBlockedMemoryPool(
+ String queryId, int numOfMockTsBlock, long mockTsBlockSize) {
+ MemoryPool pool = Mockito.mock(MemoryPool.class);
+ final long capacityInBytes = numOfMockTsBlock * mockTsBlockSize;
+ AtomicReference<Long> reserved = new AtomicReference<>(0L);
+ // Start from an already completed future so that free() always has a future to complete.
+ SettableFuture<Void> completed = SettableFuture.create();
+ completed.set(null);
+ AtomicReference<SettableFuture<Void>> pending = new AtomicReference<>(completed);
+
+ Mockito.when(pool.reserve(Mockito.eq(queryId), Mockito.anyLong()))
+ .thenAnswer(
+ invocation -> {
+ long requested = invocation.getArgument(1);
+ reserved.updateAndGet(current -> current + requested);
+ SettableFuture<Void> blocked = SettableFuture.create();
+ pending.set(blocked);
+ return blocked;
+ });
+ Mockito.doAnswer(
+ (Answer<Void>)
+ invocation -> {
+ long released = invocation.getArgument(1);
+ reserved.updateAndGet(current -> current - released);
+ if (reserved.get() <= 0) {
+ pending.get().set(null);
+ }
+ return null;
+ })
+ .when(pool)
+ .free(Mockito.eq(queryId), Mockito.anyLong());
+ Mockito.when(pool.tryReserve(Mockito.eq(queryId), Mockito.anyLong()))
+ .thenAnswer(
+ invocation -> {
+ long requested = invocation.getArgument(1);
+ boolean exceedsCapacity = reserved.get() + requested > capacityInBytes;
+ if (!exceedsCapacity) {
+ reserved.updateAndGet(current -> current + requested);
+ }
+ return !exceedsCapacity;
+ });
+ return pool;
+ }
+
+ /**
+ * Builds a mock {@link MemoryPool} that grants every request: {@code tryReserve} always returns
+ * {@code true} and {@code reserve} always returns an already completed future.
+ *
+ * @return the stubbed mock pool
+ */
+ public static MemoryPool createMockNonBlockedMemoryPool() {
+ MemoryPool pool = Mockito.mock(MemoryPool.class);
+ Mockito.when(pool.tryReserve(Mockito.anyString(), Mockito.anyLong())).thenReturn(true);
+ Mockito.when(pool.reserve(Mockito.anyString(), Mockito.anyLong()))
+ .thenReturn(immediateFuture(null));
+ return pool;
+ }
+
+ /**
+ * Builds a mock {@link TsBlockSerde} whose {@code deserialize} returns the same mock
+ * {@link TsBlock} for any {@link ByteBuffer}.
+ *
+ * @param mockTsBlockSize value the returned block reports from {@code getRetainedSizeInBytes()}
+ * @return the stubbed mock serde
+ */
+ public static TsBlockSerde createMockTsBlockSerde(long mockTsBlockSize) {
+ TsBlock block = Mockito.mock(TsBlock.class);
+ Mockito.when(block.getRetainedSizeInBytes()).thenReturn(mockTsBlockSize);
+ TsBlockSerde serde = Mockito.mock(TsBlockSerde.class);
+ Mockito.when(serde.deserialize(Mockito.any(ByteBuffer.class))).thenReturn(block);
+ return serde;
+ }
+}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/memory/MemoryPoolTest.java
index 8f21c7dc77..c1ca224868 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/memory/MemoryPoolTest.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.db.mpp.memory;
+import com.google.common.util.concurrent.ListenableFuture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -29,23 +30,70 @@ public class MemoryPoolTest {
@Before
public void before() {
- pool = new MemoryPool("test", 1024L);
+ pool = new MemoryPool("test", 1024L, 512L);
}
@Test
- public void testReserve() {
+ public void testTryReserve() {
+ String queryId = "q0";
+ Assert.assertTrue(pool.tryReserve(queryId, 256L));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ }
+
+ @Test
+ public void testTryReserveZero() {
+ String queryId = "q0";
+ try {
+ pool.tryReserve(queryId, 0L);
+ Assert.fail("Expect IllegalArgumentException");
+ } catch (IllegalArgumentException ignore) {
+ }
+ }
+
+ @Test
+ public void testTryReserveNegative() {
+ String queryId = "q0";
+ try {
+ pool.tryReserve(queryId, -1L);
+ Assert.fail("Expect IllegalArgumentException");
+ } catch (IllegalArgumentException ignore) {
+ }
+ }
+
+ @Test
+ public void testTryReserveAll() {
String queryId = "q0";
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
}
@Test
- public void testReserveZero() {
+ public void testOverTryReserve() {
+ String queryId = "q0";
+ Assert.assertTrue(pool.tryReserve(queryId, 256L));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ Assert.assertFalse(pool.tryReserve(queryId, 512L));
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ }
+
+ @Test
+ public void testReserve() {
+ String queryId = "q0";
+ ListenableFuture<Void> future = pool.reserve(queryId, 256L);
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ }
+
+ @Test
+ public void tesReserveZero() {
String queryId = "q0";
try {
- pool.tryReserve(queryId, 0L);
+ pool.reserve(queryId, 0L);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -55,7 +103,7 @@ public class MemoryPoolTest {
public void testReserveNegative() {
String queryId = "q0";
try {
- pool.tryReserve(queryId, -1L);
+ pool.reserve(queryId, -1L);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -64,19 +112,74 @@ public class MemoryPoolTest {
@Test
public void testReserveAll() {
String queryId = "q0";
- Assert.assertTrue(pool.tryReserve(queryId, 1024L));
- Assert.assertEquals(1024L, pool.getQueryMemoryReservedBytes(queryId));
- Assert.assertEquals(1024L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
+ ListenableFuture<Void> future = pool.reserve(queryId, 512L);
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
}
@Test
public void testOverReserve() {
String queryId = "q0";
- Assert.assertFalse(pool.tryReserve(queryId, 1025L));
+ ListenableFuture<Void> future = pool.reserve(queryId, 256L);
+ Assert.assertTrue(future.isDone());
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ future = pool.reserve(queryId, 512L);
+ Assert.assertFalse(future.isDone());
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+ }
+
+ @Test
+ public void testReserveAndFree() {
+ String queryId = "q0";
+ Assert.assertTrue(pool.reserve(queryId, 512L).isDone());
+ ListenableFuture<Void> future = pool.reserve(queryId, 512L);
+ Assert.assertFalse(future.isDone());
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
+ pool.free(queryId, 512);
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
+ Assert.assertTrue(future.isDone());
+ }
+
+ @Test
+ public void testMultiReserveAndFree() {
+ String queryId = "q0";
+ Assert.assertTrue(pool.reserve(queryId, 256L).isDone());
+ Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(256L, pool.getReservedBytes());
+
+ ListenableFuture<Void> future1 = pool.reserve(queryId, 512L);
+ ListenableFuture<Void> future2 = pool.reserve(queryId, 512L);
+ ListenableFuture<Void> future3 = pool.reserve(queryId, 512L);
+ Assert.assertFalse(future1.isDone());
+ Assert.assertFalse(future2.isDone());
+ Assert.assertFalse(future3.isDone());
+
+ pool.free(queryId, 256L);
+ Assert.assertTrue(future1.isDone());
+ Assert.assertFalse(future2.isDone());
+ Assert.assertFalse(future3.isDone());
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
+
+ pool.free(queryId, 512L);
+ Assert.assertTrue(future2.isDone());
+ Assert.assertFalse(future3.isDone());
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
+
+ pool.free(queryId, 512L);
+ Assert.assertTrue(future3.isDone());
+ Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
+ Assert.assertEquals(512L, pool.getReservedBytes());
+
+ pool.free(queryId, 512L);
Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(0L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
}
@Test
@@ -85,7 +188,6 @@ public class MemoryPoolTest {
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
pool.free(queryId, 256L);
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
@@ -98,7 +200,6 @@ public class MemoryPoolTest {
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
pool.free(queryId, 512L);
Assert.assertEquals(0L, pool.getQueryMemoryReservedBytes(queryId));
@@ -111,11 +212,12 @@ public class MemoryPoolTest {
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
- pool.free(queryId, 256L);
- Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
- Assert.assertEquals(256L, pool.getReservedBytes());
+ try {
+ pool.free(queryId, 0L);
+ Assert.fail("Expect IllegalArgumentException");
+ } catch (IllegalArgumentException ignore) {
+ }
}
@Test
@@ -124,7 +226,6 @@ public class MemoryPoolTest {
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
try {
pool.free(queryId, -1L);
@@ -139,7 +240,6 @@ public class MemoryPoolTest {
Assert.assertTrue(pool.tryReserve(queryId, 512L));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
Assert.assertEquals(512L, pool.getReservedBytes());
- Assert.assertEquals(1024L, pool.getMaxBytes());
try {
pool.free(queryId, 513L);
diff --git a/thrift/src/main/thrift/mpp.thrift b/thrift/src/main/thrift/mpp.thrift
index 00b6b6cdeb..b1cae785d6 100644
--- a/thrift/src/main/thrift/mpp.thrift
+++ b/thrift/src/main/thrift/mpp.thrift
@@ -27,23 +27,34 @@ struct TFragmentInstanceId {
}
struct GetDataBlockRequest {
- 1: required TFragmentInstanceId fragmentInstanceId
- 2: required i64 blockId
+ 1: required TFragmentInstanceId sourceFragmentInstanceId // NOTE(review): presumably the fragment instance that holds the requested blocks - confirm
+ 2: required i32 startSequenceId // NOTE(review): presumably the first sequence id of the requested range - confirm
+ 3: required i32 endSequenceId // NOTE(review): whether this bound is inclusive or exclusive is not visible here - confirm
}
struct GetDataBlockResponse {
1: required list<binary> tsBlocks
}
+struct AcknowledgeDataBlockEvent { // Per the commit message, added to handle retry of GetDataBlockRequest.
+ 1: required TFragmentInstanceId sourceFragmentInstanceId // NOTE(review): presumably the fragment instance whose blocks are acknowledged - confirm
+ 2: required i32 startSequenceId // NOTE(review): presumably the first acknowledged sequence id - confirm
+ 3: required i32 endSequenceId // NOTE(review): inclusive vs. exclusive is not visible here - confirm
+}
+
struct NewDataBlockEvent {
- 1: required TFragmentInstanceId fragmentInstanceId
- 2: required string operatorId
- 3: required i64 blockId
+ 1: required TFragmentInstanceId targetFragmentInstanceId // NOTE(review): presumably the fragment instance being notified - confirm
+ 2: required string targetPlanNodeId // NOTE(review): presumably the plan node inside the target instance that consumes the blocks - confirm
+ 3: required TFragmentInstanceId sourceFragmentInstanceId // NOTE(review): presumably the fragment instance that produced the blocks - confirm
+ 4: required i32 startSequenceId // NOTE(review): presumably the sequence id of the first announced block - confirm
+ 5: required list<i64> blockSizes // NOTE(review): presumably one size per announced block, in sequence order; unit not visible here - confirm
}
struct EndOfDataBlockEvent {
- 1: required TFragmentInstanceId fragmentInstanceId
- 2: required string operatorId
+ 1: required TFragmentInstanceId targetFragmentInstanceId // NOTE(review): presumably the fragment instance being notified - confirm
+ 2: required string targetPlanNodeId // NOTE(review): presumably the plan node inside the target instance - confirm
+ 3: required TFragmentInstanceId sourceFragmentInstanceId // NOTE(review): presumably the fragment instance that has finished producing blocks - confirm
+ 4: required i32 lastSequenceId // NOTE(review): presumably the sequence id of the final block from the source - confirm
}
struct TFragmentInstance {
@@ -111,6 +122,8 @@ service InternalService {
service DataBlockService {
GetDataBlockResponse getDataBlock(GetDataBlockRequest req);
+ void onAcknowledgeDataBlockEvent(AcknowledgeDataBlockEvent e);
+
void onNewDataBlockEvent(NewDataBlockEvent e);
void onEndOfDataBlockEvent(EndOfDataBlockEvent e);