You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/23 07:18:52 UTC
[iotdb] 02/02: fix bug & finish
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/MLSQL
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 71fb614bb14cb9fdfbe87a2053628ba5e676fb96
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Thu Mar 23 15:17:35 2023 +0800
fix bug & finish
---
.../iotdb/confignode/persistence/ModelInfo.java | 2 +-
.../procedure/impl/model/CreateModelProcedure.java | 2 +-
.../procedure/impl/model/DropModelProcedure.java | 27 +-------
.../procedure/state/model/DropModelState.java | 1 -
.../procedure/store/ProcedureFactory.java | 4 ++
mlnode/iotdb/mlnode/service.py | 2 +-
.../iotdb/commons/model/ModelInformation.java | 79 ++++++++++++++++++----
.../org/apache/iotdb/db/client/MLNodeClient.java | 18 +++--
.../impl/DataNodeInternalRPCServiceImpl.java | 3 +-
9 files changed, 88 insertions(+), 50 deletions(-)
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 13e4dabe73..3c72e09570 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -84,7 +84,7 @@ public class ModelInfo implements SnapshotProcessor {
return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
.setMessage(errorMessage);
}
- return null;
+ return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
}
public TSStatus dropModel(DropModelPlan plan) {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
index 0a4d306fde..7dff5fe06e 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
@@ -49,7 +49,7 @@ import java.util.Objects;
public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class);
- private static final int RETRY_THRESHOLD = 5;
+ private static final int RETRY_THRESHOLD = 1;
private ModelInformation modelInformation;
private Map<String, String> modelConfigs;
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index 1268368bd6..bfa461a8b2 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -21,19 +21,16 @@ package org.apache.iotdb.confignode.procedure.impl.model;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.common.rpc.thrift.TrainingState;
import org.apache.iotdb.commons.model.exception.ModelManagementException;
import org.apache.iotdb.confignode.client.DataNodeRequestType;
import org.apache.iotdb.confignode.client.sync.SyncDataNodeClientPool;
import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
import org.apache.iotdb.confignode.procedure.state.model.DropModelState;
import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq;
import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
import org.apache.iotdb.db.client.MLNodeClient;
import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
@@ -53,7 +50,7 @@ import java.util.Optional;
public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class);
- private static final int RETRY_THRESHOLD = 5;
+ private static final int RETRY_THRESHOLD = 1;
private String modelId;
@@ -87,25 +84,6 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
break;
case VALIDATED:
- LOGGER.info("Change state of model [{}] to DROPPING", modelId);
-
- ConsensusWriteResponse response =
- env.getConfigManager()
- .getConsensusManager()
- .write(
- new UpdateModelStatePlan(
- new TUpdateModelStateReq(modelId, TrainingState.DROPPING)));
- if (!response.isSuccessful()) {
- throw new ModelManagementException(
- String.format(
- "Failed to drop model [%s], fail to modify model state: %s",
- modelId, response.getErrorMessage()));
- }
-
- setNextState(DropModelState.CONFIG_NODE_DROPPING);
- break;
-
- case CONFIG_NODE_DROPPING:
LOGGER.info("Start to drop model metrics [{}] on Data Nodes", modelId);
Optional<TDataNodeLocation> targetDataNode =
@@ -153,7 +131,8 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
case ML_NODE_DROPPED:
LOGGER.info("Start to drop model [{}] on Config Nodes", modelId);
- response = env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId));
+ ConsensusWriteResponse response =
+ env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId));
if (!response.isSuccessful()) {
throw new ModelManagementException(
String.format(
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
index 5f8c5a6f6e..54e32e86da 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
@@ -22,7 +22,6 @@ package org.apache.iotdb.confignode.procedure.state.model;
public enum DropModelState {
INIT,
VALIDATED,
- CONFIG_NODE_DROPPING,
DATA_NODE_DROPPED,
ML_NODE_DROPPED,
CONFIG_NODE_DROPPED
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
index 48a4cfd997..4c026cb4f2 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
@@ -158,6 +158,10 @@ public class ProcedureFactory implements IProcedureFactory {
return ProcedureType.DEACTIVATE_TEMPLATE_PROCEDURE;
} else if (procedure instanceof UnsetTemplateProcedure) {
return ProcedureType.UNSET_TEMPLATE_PROCEDURE;
+ } else if (procedure instanceof CreateModelProcedure) {
+ return ProcedureType.CREATE_MODEL_PROCEDURE;
+ } else if (procedure instanceof DropModelProcedure) {
+ return ProcedureType.DROP_MODEL_PROCEDURE;
}
return null;
}
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index 8314dc363e..a2c05ea5c3 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -33,7 +33,7 @@ class RPCService(threading.Thread):
super().__init__()
processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port())
- transport_factory = TTransport.TBufferedTransportFactory()
+ transport_factory = TTransport.TFramedTransportFactory()
protocol_factory = TCompactProtocol.TCompactProtocolFactory()
self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory)
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index e6fbf13c95..a8cff6968d 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -24,6 +24,8 @@ import org.apache.iotdb.common.rpc.thrift.TrainingState;
import org.apache.iotdb.tsfile.utils.PublicBAOS;
import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+import javax.annotation.Nullable;
+
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -44,12 +46,12 @@ public class ModelInformation {
private final String modelType;
private final List<String> queryExpressions;
- private final String queryFilter;
+ @Nullable private String queryFilter;
private final boolean isAuto;
private TrainingState trainingState;
- private String bestTrailId;
+ @Nullable private String bestTrailId;
private final Map<String, TrailInformation> trailMap;
public ModelInformation(
@@ -58,11 +60,12 @@ public class ModelInformation {
String modelType,
boolean isAuto,
List<String> queryExpressions,
- String queryFilter) {
+ @Nullable String queryFilter) {
this.modelId = modelId;
this.modelTask = modelTask;
this.modelType = modelType;
this.isAuto = isAuto;
+ this.trainingState = TrainingState.PENDING;
this.queryExpressions = queryExpressions;
this.queryFilter = queryFilter;
this.trailMap = new HashMap<>();
@@ -79,10 +82,18 @@ public class ModelInformation {
this.queryExpressions.add(ReadWriteIOUtils.readString(buffer));
}
- this.queryFilter = ReadWriteIOUtils.readString(buffer);
+ byte isNull = ReadWriteIOUtils.readByte(buffer);
+ if (isNull == 1) {
+ this.queryFilter = ReadWriteIOUtils.readString(buffer);
+ }
+
this.isAuto = ReadWriteIOUtils.readBool(buffer);
this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(buffer));
- this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+
+ isNull = ReadWriteIOUtils.readByte(buffer);
+ if (isNull == 1) {
+ this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+ }
int mapSize = ReadWriteIOUtils.readInt(buffer);
this.trailMap = new HashMap<>();
@@ -103,10 +114,18 @@ public class ModelInformation {
this.queryExpressions.add(ReadWriteIOUtils.readString(stream));
}
- this.queryFilter = ReadWriteIOUtils.readString(stream);
+ byte isNull = ReadWriteIOUtils.readByte(stream);
+ if (isNull == 1) {
+ this.queryFilter = ReadWriteIOUtils.readString(stream);
+ }
+
this.isAuto = ReadWriteIOUtils.readBool(stream);
this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(stream));
- this.bestTrailId = ReadWriteIOUtils.readString(stream);
+
+ isNull = ReadWriteIOUtils.readByte(stream);
+ if (isNull == 1) {
+ this.bestTrailId = ReadWriteIOUtils.readString(stream);
+ }
int mapSize = ReadWriteIOUtils.readInt(stream);
this.trailMap = new HashMap<>();
@@ -128,6 +147,7 @@ public class ModelInformation {
return queryExpressions;
}
+ @Nullable
public String getQueryFilter() {
return queryFilter;
}
@@ -174,10 +194,24 @@ public class ModelInformation {
for (String queryExpression : queryExpressions) {
ReadWriteIOUtils.write(queryExpression, stream);
}
- ReadWriteIOUtils.write(queryFilter, stream);
+
+ if (queryFilter == null) {
+ ReadWriteIOUtils.write((byte) 0, stream);
+ } else {
+ ReadWriteIOUtils.write((byte) 1, stream);
+ ReadWriteIOUtils.write(queryFilter, stream);
+ }
+
ReadWriteIOUtils.write(isAuto, stream);
ReadWriteIOUtils.write(trainingState.ordinal(), stream);
- ReadWriteIOUtils.write(bestTrailId, stream);
+
+ if (bestTrailId == null) {
+ ReadWriteIOUtils.write((byte) 0, stream);
+ } else {
+ ReadWriteIOUtils.write((byte) 1, stream);
+ ReadWriteIOUtils.write(bestTrailId, stream);
+ }
+
ReadWriteIOUtils.write(trailMap.size(), stream);
for (TrailInformation trailInformation : trailMap.values()) {
trailInformation.serialize(stream);
@@ -194,10 +228,22 @@ public class ModelInformation {
ReadWriteIOUtils.write(queryExpression, stream);
}
- ReadWriteIOUtils.write(queryFilter, stream);
+ if (queryFilter == null) {
+ ReadWriteIOUtils.write((byte) 0, stream);
+ } else {
+ ReadWriteIOUtils.write((byte) 1, stream);
+ ReadWriteIOUtils.write(queryFilter, stream);
+ }
+
ReadWriteIOUtils.write(isAuto, stream);
ReadWriteIOUtils.write(trainingState.ordinal(), stream);
- ReadWriteIOUtils.write(bestTrailId, stream);
+
+ if (bestTrailId == null) {
+ ReadWriteIOUtils.write((byte) 0, stream);
+ } else {
+ ReadWriteIOUtils.write((byte) 1, stream);
+ ReadWriteIOUtils.write(bestTrailId, stream);
+ }
ReadWriteIOUtils.write(trailMap.size(), stream);
for (TrailInformation trailInformation : trailMap.values()) {
@@ -222,9 +268,14 @@ public class ModelInformation {
ReadWriteIOUtils.write(Arrays.toString(queryExpressions.toArray(new String[0])), stream);
ReadWriteIOUtils.write(trainingState.toString(), stream);
- TrailInformation bestTrail = trailMap.get(bestTrailId);
- ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
- ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
+ if (bestTrailId != null) {
+ TrailInformation bestTrail = trailMap.get(bestTrailId);
+ ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
+ ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
+ } else {
+ ReadWriteIOUtils.write("UNKNOWN", stream);
+ ReadWriteIOUtils.write("UNKNOWN", stream);
+ }
return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
index 84278d9ba4..1ff54d43b6 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
@@ -29,16 +29,18 @@ import org.apache.iotdb.mlnode.rpc.thrift.TCreateTrainingTaskReq;
import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
-import org.apache.iotdb.rpc.RpcTransportFactory;
+import org.apache.iotdb.rpc.TConfigurationConst;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.tsfile.read.common.block.TsBlock;
import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TCompactProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.layered.TFramedTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -63,9 +65,13 @@ public class MLNodeClient implements AutoCloseable {
try {
long connectionTimeout = ClientPoolProperty.DefaultProperty.WAIT_CLIENT_TIMEOUT_MS;
transport =
- RpcTransportFactory.INSTANCE.getTransport(
- // As there is a try-catch already, we do not need to use TSocket.wrap
- endpoint.getIp(), endpoint.getPort(), (int) connectionTimeout);
+ new TFramedTransport.Factory()
+ .getTransport(
+ new TSocket(
+ TConfigurationConst.defaultTConfiguration,
+ endpoint.getIp(),
+ endpoint.getPort(),
+ (int) connectionTimeout));
if (!transport.isOpen()) {
transport.open();
}
@@ -73,7 +79,7 @@ public class MLNodeClient implements AutoCloseable {
throw new TException(MSG_CONNECTION_FAIL);
}
- TProtocolFactory protocolFactory = new TBinaryProtocol.Factory();
+ TProtocolFactory protocolFactory = new TCompactProtocol.Factory();
client = new IMLNodeRPCService.Client(protocolFactory.getProtocol(transport));
}
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index 60edda2b56..c2fa6249f8 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -875,8 +875,7 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
@Override
public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException {
- // TODO
- throw new TException(new UnsupportedOperationException().getCause());
+ return RpcUtils.SUCCESS_STATUS;
}
@Override