You are viewing a plain-text version of this content. The hyperlink to the canonical version is not preserved in this plain-text rendering.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/23 07:18:52 UTC

[iotdb] 02/02: fix bug & finish

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/MLSQL
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 71fb614bb14cb9fdfbe87a2053628ba5e676fb96
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Thu Mar 23 15:17:35 2023 +0800

    fix bug & finish
---
 .../iotdb/confignode/persistence/ModelInfo.java    |  2 +-
 .../procedure/impl/model/CreateModelProcedure.java |  2 +-
 .../procedure/impl/model/DropModelProcedure.java   | 27 +-------
 .../procedure/state/model/DropModelState.java      |  1 -
 .../procedure/store/ProcedureFactory.java          |  4 ++
 mlnode/iotdb/mlnode/service.py                     |  2 +-
 .../iotdb/commons/model/ModelInformation.java      | 79 ++++++++++++++++++----
 .../org/apache/iotdb/db/client/MLNodeClient.java   | 18 +++--
 .../impl/DataNodeInternalRPCServiceImpl.java       |  3 +-
 9 files changed, 88 insertions(+), 50 deletions(-)

diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 13e4dabe73..3c72e09570 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -84,7 +84,7 @@ public class ModelInfo implements SnapshotProcessor {
       return new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
           .setMessage(errorMessage);
     }
-    return null;
+    return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
   }
 
   public TSStatus dropModel(DropModelPlan plan) {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
index 0a4d306fde..7dff5fe06e 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
@@ -49,7 +49,7 @@ import java.util.Objects;
 public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class);
-  private static final int RETRY_THRESHOLD = 5;
+  private static final int RETRY_THRESHOLD = 1;
 
   private ModelInformation modelInformation;
   private Map<String, String> modelConfigs;
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index 1268368bd6..bfa461a8b2 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -21,19 +21,16 @@ package org.apache.iotdb.confignode.procedure.impl.model;
 
 import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.common.rpc.thrift.TrainingState;
 import org.apache.iotdb.commons.model.exception.ModelManagementException;
 import org.apache.iotdb.confignode.client.DataNodeRequestType;
 import org.apache.iotdb.confignode.client.sync.SyncDataNodeClientPool;
 import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelStatePlan;
 import org.apache.iotdb.confignode.persistence.ModelInfo;
 import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
 import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
 import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
 import org.apache.iotdb.confignode.procedure.state.model.DropModelState;
 import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelStateReq;
 import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
 import org.apache.iotdb.db.client.MLNodeClient;
 import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
@@ -53,7 +50,7 @@ import java.util.Optional;
 public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class);
-  private static final int RETRY_THRESHOLD = 5;
+  private static final int RETRY_THRESHOLD = 1;
 
   private String modelId;
 
@@ -87,25 +84,6 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
           break;
 
         case VALIDATED:
-          LOGGER.info("Change state of model [{}] to DROPPING", modelId);
-
-          ConsensusWriteResponse response =
-              env.getConfigManager()
-                  .getConsensusManager()
-                  .write(
-                      new UpdateModelStatePlan(
-                          new TUpdateModelStateReq(modelId, TrainingState.DROPPING)));
-          if (!response.isSuccessful()) {
-            throw new ModelManagementException(
-                String.format(
-                    "Failed to drop model [%s], fail to modify model state: %s",
-                    modelId, response.getErrorMessage()));
-          }
-
-          setNextState(DropModelState.CONFIG_NODE_DROPPING);
-          break;
-
-        case CONFIG_NODE_DROPPING:
           LOGGER.info("Start to drop model metrics [{}] on Data Nodes", modelId);
 
           Optional<TDataNodeLocation> targetDataNode =
@@ -153,7 +131,8 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
         case ML_NODE_DROPPED:
           LOGGER.info("Start to drop model [{}] on Config Nodes", modelId);
 
-          response = env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId));
+          ConsensusWriteResponse response =
+              env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelId));
           if (!response.isSuccessful()) {
             throw new ModelManagementException(
                 String.format(
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
index 5f8c5a6f6e..54e32e86da 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/DropModelState.java
@@ -22,7 +22,6 @@ package org.apache.iotdb.confignode.procedure.state.model;
 public enum DropModelState {
   INIT,
   VALIDATED,
-  CONFIG_NODE_DROPPING,
   DATA_NODE_DROPPED,
   ML_NODE_DROPPED,
   CONFIG_NODE_DROPPED
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
index 48a4cfd997..4c026cb4f2 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
@@ -158,6 +158,10 @@ public class ProcedureFactory implements IProcedureFactory {
       return ProcedureType.DEACTIVATE_TEMPLATE_PROCEDURE;
     } else if (procedure instanceof UnsetTemplateProcedure) {
       return ProcedureType.UNSET_TEMPLATE_PROCEDURE;
+    } else if (procedure instanceof CreateModelProcedure) {
+      return ProcedureType.CREATE_MODEL_PROCEDURE;
+    } else if (procedure instanceof DropModelProcedure) {
+      return ProcedureType.DROP_MODEL_PROCEDURE;
     }
     return null;
   }
diff --git a/mlnode/iotdb/mlnode/service.py b/mlnode/iotdb/mlnode/service.py
index 8314dc363e..a2c05ea5c3 100644
--- a/mlnode/iotdb/mlnode/service.py
+++ b/mlnode/iotdb/mlnode/service.py
@@ -33,7 +33,7 @@ class RPCService(threading.Thread):
         super().__init__()
         processor = IMLNodeRPCService.Processor(handler=MLNodeRPCServiceHandler())
         transport = TSocket.TServerSocket(host=config.get_mn_rpc_address(), port=config.get_mn_rpc_port())
-        transport_factory = TTransport.TBufferedTransportFactory()
+        transport_factory = TTransport.TFramedTransportFactory()
         protocol_factory = TCompactProtocol.TCompactProtocolFactory()
 
         self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory)
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index e6fbf13c95..a8cff6968d 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -24,6 +24,8 @@ import org.apache.iotdb.common.rpc.thrift.TrainingState;
 import org.apache.iotdb.tsfile.utils.PublicBAOS;
 import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
+import javax.annotation.Nullable;
+
 import java.io.DataOutputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -44,12 +46,12 @@ public class ModelInformation {
   private final String modelType;
 
   private final List<String> queryExpressions;
-  private final String queryFilter;
+  @Nullable private String queryFilter;
 
   private final boolean isAuto;
   private TrainingState trainingState;
 
-  private String bestTrailId;
+  @Nullable private String bestTrailId;
   private final Map<String, TrailInformation> trailMap;
 
   public ModelInformation(
@@ -58,11 +60,12 @@ public class ModelInformation {
       String modelType,
       boolean isAuto,
       List<String> queryExpressions,
-      String queryFilter) {
+      @Nullable String queryFilter) {
     this.modelId = modelId;
     this.modelTask = modelTask;
     this.modelType = modelType;
     this.isAuto = isAuto;
+    this.trainingState = TrainingState.PENDING;
     this.queryExpressions = queryExpressions;
     this.queryFilter = queryFilter;
     this.trailMap = new HashMap<>();
@@ -79,10 +82,18 @@ public class ModelInformation {
       this.queryExpressions.add(ReadWriteIOUtils.readString(buffer));
     }
 
-    this.queryFilter = ReadWriteIOUtils.readString(buffer);
+    byte isNull = ReadWriteIOUtils.readByte(buffer);
+    if (isNull == 1) {
+      this.queryFilter = ReadWriteIOUtils.readString(buffer);
+    }
+
     this.isAuto = ReadWriteIOUtils.readBool(buffer);
     this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(buffer));
-    this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+
+    isNull = ReadWriteIOUtils.readByte(buffer);
+    if (isNull == 1) {
+      this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+    }
 
     int mapSize = ReadWriteIOUtils.readInt(buffer);
     this.trailMap = new HashMap<>();
@@ -103,10 +114,18 @@ public class ModelInformation {
       this.queryExpressions.add(ReadWriteIOUtils.readString(stream));
     }
 
-    this.queryFilter = ReadWriteIOUtils.readString(stream);
+    byte isNull = ReadWriteIOUtils.readByte(stream);
+    if (isNull == 1) {
+      this.queryFilter = ReadWriteIOUtils.readString(stream);
+    }
+
     this.isAuto = ReadWriteIOUtils.readBool(stream);
     this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(stream));
-    this.bestTrailId = ReadWriteIOUtils.readString(stream);
+
+    isNull = ReadWriteIOUtils.readByte(stream);
+    if (isNull == 1) {
+      this.bestTrailId = ReadWriteIOUtils.readString(stream);
+    }
 
     int mapSize = ReadWriteIOUtils.readInt(stream);
     this.trailMap = new HashMap<>();
@@ -128,6 +147,7 @@ public class ModelInformation {
     return queryExpressions;
   }
 
+  @Nullable
   public String getQueryFilter() {
     return queryFilter;
   }
@@ -174,10 +194,24 @@ public class ModelInformation {
     for (String queryExpression : queryExpressions) {
       ReadWriteIOUtils.write(queryExpression, stream);
     }
-    ReadWriteIOUtils.write(queryFilter, stream);
+
+    if (queryFilter == null) {
+      ReadWriteIOUtils.write((byte) 0, stream);
+    } else {
+      ReadWriteIOUtils.write((byte) 1, stream);
+      ReadWriteIOUtils.write(queryFilter, stream);
+    }
+
     ReadWriteIOUtils.write(isAuto, stream);
     ReadWriteIOUtils.write(trainingState.ordinal(), stream);
-    ReadWriteIOUtils.write(bestTrailId, stream);
+
+    if (bestTrailId == null) {
+      ReadWriteIOUtils.write((byte) 0, stream);
+    } else {
+      ReadWriteIOUtils.write((byte) 1, stream);
+      ReadWriteIOUtils.write(bestTrailId, stream);
+    }
+
     ReadWriteIOUtils.write(trailMap.size(), stream);
     for (TrailInformation trailInformation : trailMap.values()) {
       trailInformation.serialize(stream);
@@ -194,10 +228,22 @@ public class ModelInformation {
       ReadWriteIOUtils.write(queryExpression, stream);
     }
 
-    ReadWriteIOUtils.write(queryFilter, stream);
+    if (queryFilter == null) {
+      ReadWriteIOUtils.write((byte) 0, stream);
+    } else {
+      ReadWriteIOUtils.write((byte) 1, stream);
+      ReadWriteIOUtils.write(queryFilter, stream);
+    }
+
     ReadWriteIOUtils.write(isAuto, stream);
     ReadWriteIOUtils.write(trainingState.ordinal(), stream);
-    ReadWriteIOUtils.write(bestTrailId, stream);
+
+    if (bestTrailId == null) {
+      ReadWriteIOUtils.write((byte) 0, stream);
+    } else {
+      ReadWriteIOUtils.write((byte) 1, stream);
+      ReadWriteIOUtils.write(bestTrailId, stream);
+    }
 
     ReadWriteIOUtils.write(trailMap.size(), stream);
     for (TrailInformation trailInformation : trailMap.values()) {
@@ -222,9 +268,14 @@ public class ModelInformation {
     ReadWriteIOUtils.write(Arrays.toString(queryExpressions.toArray(new String[0])), stream);
     ReadWriteIOUtils.write(trainingState.toString(), stream);
 
-    TrailInformation bestTrail = trailMap.get(bestTrailId);
-    ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
-    ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
+    if (bestTrailId != null) {
+      TrailInformation bestTrail = trailMap.get(bestTrailId);
+      ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
+      ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
+    } else {
+      ReadWriteIOUtils.write("UNKNOWN", stream);
+      ReadWriteIOUtils.write("UNKNOWN", stream);
+    }
     return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
   }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
index 84278d9ba4..1ff54d43b6 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/MLNodeClient.java
@@ -29,16 +29,18 @@ import org.apache.iotdb.mlnode.rpc.thrift.TCreateTrainingTaskReq;
 import org.apache.iotdb.mlnode.rpc.thrift.TDeleteModelReq;
 import org.apache.iotdb.mlnode.rpc.thrift.TForecastReq;
 import org.apache.iotdb.mlnode.rpc.thrift.TForecastResp;
-import org.apache.iotdb.rpc.RpcTransportFactory;
+import org.apache.iotdb.rpc.TConfigurationConst;
 import org.apache.iotdb.rpc.TSStatusCode;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
 
 import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TCompactProtocol;
 import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TSocket;
 import org.apache.thrift.transport.TTransport;
 import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.layered.TFramedTransport;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -63,9 +65,13 @@ public class MLNodeClient implements AutoCloseable {
     try {
       long connectionTimeout = ClientPoolProperty.DefaultProperty.WAIT_CLIENT_TIMEOUT_MS;
       transport =
-          RpcTransportFactory.INSTANCE.getTransport(
-              // As there is a try-catch already, we do not need to use TSocket.wrap
-              endpoint.getIp(), endpoint.getPort(), (int) connectionTimeout);
+          new TFramedTransport.Factory()
+              .getTransport(
+                  new TSocket(
+                      TConfigurationConst.defaultTConfiguration,
+                      endpoint.getIp(),
+                      endpoint.getPort(),
+                      (int) connectionTimeout));
       if (!transport.isOpen()) {
         transport.open();
       }
@@ -73,7 +79,7 @@ public class MLNodeClient implements AutoCloseable {
       throw new TException(MSG_CONNECTION_FAIL);
     }
 
-    TProtocolFactory protocolFactory = new TBinaryProtocol.Factory();
+    TProtocolFactory protocolFactory = new TCompactProtocol.Factory();
     client = new IMLNodeRPCService.Client(protocolFactory.getProtocol(transport));
   }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index 60edda2b56..c2fa6249f8 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -875,8 +875,7 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
 
   @Override
   public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException {
-    // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    return RpcUtils.SUCCESS_STATUS;
   }
 
   @Override