You are viewing a plain-text version of this content. The link to the canonical version (originally attached to the word "here") was not preserved in this conversion.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/02/17 07:00:52 UTC
[iotdb] 16/16: implement ModelInfo
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/modelManager
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 7202cba5b6e22c3dc914edc1367c721c00eeaa02
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Fri Feb 17 15:00:10 2023 +0800
implement ModelInfo
---
.../request/write/model/UpdateModelInfoPlan.java | 14 +-
.../consensus/response/ModelTableResp.java | 11 +-
.../consensus/response/TrailTableResp.java | 11 +-
.../iotdb/confignode/manager/ModelManager.java | 11 +-
.../iotdb/confignode/manager/ProcedureManager.java | 5 +-
.../iotdb/confignode/persistence/ModelInfo.java | 43 +++---
.../procedure/impl/model/CreateModelProcedure.java | 28 ++--
.../procedure/state/model/CreateModelState.java | 5 +-
.../commons/model/ForecastTrailInformation.java | 32 -----
.../iotdb/commons/model/ModelHyperparameter.java | 71 +++++++++
.../iotdb/commons/model/ModelInformation.java | 158 ++++++++++++++++++---
.../org/apache/iotdb/commons/model/ModelTable.java | 4 +-
.../iotdb/commons/model/TrailInformation.java | 84 +++++++++--
thrift-commons/src/main/thrift/common.thrift | 12 +-
.../src/main/thrift/confignode.thrift | 26 +++-
15 files changed, 390 insertions(+), 125 deletions(-)
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
index d1eacadae2..52c1100a00 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
@@ -33,6 +33,7 @@ import java.util.Objects;
public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
private String modelId;
+ private String trailId;
private Map<String, String> modelInfo;
public UpdateModelInfoPlan() {
@@ -42,6 +43,7 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
public UpdateModelInfoPlan(TUpdateModelInfoReq updateModelInfoReq) {
super(ConfigPhysicalPlanType.UpdateModelInfo);
this.modelId = updateModelInfoReq.getModelId();
+ this.trailId = updateModelInfoReq.getTrailId();
this.modelInfo = updateModelInfoReq.getModelInfo();
}
@@ -49,6 +51,10 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
return modelId;
}
+ public String getTrailId() {
+ return trailId;
+ }
+
public Map<String, String> getModelInfo() {
return modelInfo;
}
@@ -57,12 +63,14 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
protected void serializeImpl(DataOutputStream stream) throws IOException {
stream.writeShort(getType().getPlanType());
ReadWriteIOUtils.write(modelId, stream);
+ ReadWriteIOUtils.write(trailId, stream);
ReadWriteIOUtils.write(modelInfo, stream);
}
@Override
protected void deserializeImpl(ByteBuffer buffer) throws IOException {
this.modelId = ReadWriteIOUtils.readString(buffer);
+ this.trailId = ReadWriteIOUtils.readString(buffer);
this.modelInfo = ReadWriteIOUtils.readMap(buffer);
}
@@ -78,11 +86,13 @@ public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
return false;
}
UpdateModelInfoPlan that = (UpdateModelInfoPlan) o;
- return modelId.equals(that.modelId) && modelInfo.equals(that.modelInfo);
+ return modelId.equals(that.modelId)
+ && trailId.equals(that.trailId)
+ && modelInfo.equals(that.modelInfo);
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), modelId, modelInfo);
+ return Objects.hash(super.hashCode(), modelId, trailId, modelInfo);
}
}
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
index 82787ca653..6642f76be9 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/ModelTableResp.java
@@ -34,14 +34,21 @@ public class ModelTableResp implements DataSet {
private final TSStatus status;
private final List<ByteBuffer> serializedAllModelInformation;
- public ModelTableResp(TSStatus status, List<ModelInformation> allModelInformation) {
+ public ModelTableResp(TSStatus status) {
this.status = status;
this.serializedAllModelInformation = new ArrayList<>();
- for (ModelInformation modelInformation : allModelInformation) {
+ }
+
+ public void addModelInformation(List<ModelInformation> modelInformationList) throws IOException {
+ for (ModelInformation modelInformation : modelInformationList) {
this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
}
}
+ public void addModelInformation(ModelInformation modelInformation) throws IOException {
+ this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
+ }
+
public TShowModelResp convertToThriftResponse() throws IOException {
return new TShowModelResp(status, serializedAllModelInformation);
}
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
index 21e3fcd230..1f9a6b5acb 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/TrailTableResp.java
@@ -34,10 +34,17 @@ public class TrailTableResp implements DataSet {
private final TSStatus status;
private final List<ByteBuffer> serializedAllTrailInformation;
- public TrailTableResp(TSStatus status, List<TrailInformation> allTrailInformation) {
+ public TrailTableResp(TSStatus status) {
this.status = status;
this.serializedAllTrailInformation = new ArrayList<>();
- for (TrailInformation trailInformation : allTrailInformation) {
+ }
+
+ public void addTrailInformation(TrailInformation trailInformation) throws IOException {
+ this.serializedAllTrailInformation.add(trailInformation.serializeShowTrailResult());
+ }
+
+ public void addTrailInformation(List<TrailInformation> trailInformationList) throws IOException {
+ for (TrailInformation trailInformation : trailInformationList) {
this.serializedAllTrailInformation.add(trailInformation.serializeShowTrailResult());
}
}
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 347c5c0b94..1d268df5b5 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -61,8 +61,15 @@ public class ModelManager {
}
public TSStatus createModel(TCreateModelReq req) {
- ModelInformation modelInformation = new ModelInformation();
- return configManager.getProcedureManager().createModel(modelInformation);
+ ModelInformation modelInformation =
+ new ModelInformation(
+ req.getModelId(),
+ req.getModelTask(),
+ req.getModelType(),
+ req.isIsAuto(),
+ req.getQueryExpressions(),
+ req.getQueryFilter());
+ return configManager.getProcedureManager().createModel(modelInformation, req.getModelConfigs());
}
public TSStatus dropModel(TDropModelReq req) {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
index 7cd0da45cd..7b318b4e87 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
@@ -535,8 +535,9 @@ public class ProcedureManager {
return statusList.get(0);
}
- public TSStatus createModel(ModelInformation modelInformation) {
- long procedureId = executor.submitProcedure(new CreateModelProcedure(modelInformation));
+ public TSStatus createModel(ModelInformation modelInformation, Map<String, String> modelConfigs) {
+ long procedureId =
+ executor.submitProcedure(new CreateModelProcedure(modelInformation, modelConfigs));
List<TSStatus> statusList = new ArrayList<>();
boolean isSucceed =
waitingProcedureFinished(Collections.singletonList(procedureId), statusList);
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index d2c5144d10..bb422b3b1e 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -43,7 +43,6 @@ import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
-import java.util.Collections;
import java.util.concurrent.locks.ReentrantLock;
@ThreadSafe
@@ -98,18 +97,22 @@ public class ModelInfo implements SnapshotProcessor {
public ModelTableResp showModel(ShowModelPlan plan) {
acquireModelTableLock();
try {
+ ModelTableResp modelTableResp =
+ new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
if (plan.isSetModelId()) {
ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
- return new ModelTableResp(
- new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
- modelInformation != null
- ? Collections.singletonList(modelInformation)
- : Collections.emptyList());
+ if (modelInformation != null) {
+ modelTableResp.addModelInformation(modelInformation);
+ }
} else {
- return new ModelTableResp(
- new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
- modelTable.getAllModelInformation());
+ modelTableResp.addModelInformation(modelTable.getAllModelInformation());
}
+ return modelTableResp;
+ } catch (IOException e) {
+ LOGGER.warn("Fail to get ModelTable", e);
+ return new ModelTableResp(
+ new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+ .setMessage(e.getMessage()));
} finally {
releaseModelTableLock();
}
@@ -119,19 +122,23 @@ public class ModelInfo implements SnapshotProcessor {
acquireModelTableLock();
try {
ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
+ TrailTableResp trailTableResp =
+ new TrailTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
if (plan.isSetTrailId()) {
TrailInformation trailInformation =
modelInformation.getTrailInformationById(plan.getTrailId());
- return new TrailTableResp(
- new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
- trailInformation != null
- ? Collections.singletonList(trailInformation)
- : Collections.emptyList());
+ if (trailInformation != null) {
+ trailTableResp.addTrailInformation(trailInformation);
+ }
} else {
- return new TrailTableResp(
- new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()),
- modelInformation.getAllTrailInformation());
+ trailTableResp.addTrailInformation(modelInformation.getAllTrailInformation());
}
+ return trailTableResp;
+ } catch (IOException e) {
+ LOGGER.warn("Fail to get TrailTable", e);
+ return new TrailTableResp(
+ new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+ .setMessage(e.getMessage()));
} finally {
releaseModelTableLock();
}
@@ -142,7 +149,7 @@ public class ModelInfo implements SnapshotProcessor {
try {
String modelId = plan.getModelId();
if (modelTable.containsModel(modelId)) {
- modelTable.updateModel(modelId, plan.getModelInfo());
+ modelTable.updateModel(modelId, plan.getTrailId(), plan.getModelInfo());
}
return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
} finally {
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
index 0eb4b3800b..b003210f94 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.exception.ModelManagementException;
import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.manager.ConfigManager;
import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
@@ -32,6 +31,7 @@ import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
import org.apache.iotdb.confignode.procedure.state.model.CreateModelState;
import org.apache.iotdb.confignode.procedure.store.ProcedureType;
import org.apache.iotdb.consensus.common.response.ConsensusWriteResponse;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -39,6 +39,7 @@ import org.slf4j.LoggerFactory;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.util.Map;
public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
@@ -46,14 +47,16 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
private static final int RETRY_THRESHOLD = 5;
private ModelInformation modelInformation;
+ private Map<String, String> modelConfigs;
public CreateModelProcedure() {
super();
}
- public CreateModelProcedure(ModelInformation modelInformation) {
+ public CreateModelProcedure(ModelInformation modelInformation, Map<String, String> modelConfigs) {
super();
this.modelInformation = modelInformation;
+ this.modelConfigs = modelConfigs;
}
@Override
@@ -90,10 +93,10 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
throw new ModelManagementException(response.getErrorMessage());
}
- setNextState(CreateModelState.CONFIG_NODE_INACTIVE);
+ setNextState(CreateModelState.CONFIG_NODE_ACTIVE);
break;
- case CONFIG_NODE_INACTIVE:
+ case CONFIG_NODE_ACTIVE:
LOGGER.info("Start to train model [{}] on ML Node", modelInformation.getModelId());
if (true) {
@@ -107,16 +110,6 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
break;
case ML_NODE_ACTIVE:
- LOGGER.info("Start to active model [{}] on Config Nodes", modelInformation.getModelId());
- env.getConfigManager()
- .getConsensusManager()
- .write(
- // TODO
- new UpdateModelInfoPlan());
- setNextState(CreateModelState.CONFIG_NODE_ACTIVE);
- break;
-
- case CONFIG_NODE_ACTIVE:
env.getConfigManager().getTriggerManager().getTriggerInfo().releaseTriggerTableLock();
return Flow.NO_MORE_STATE;
}
@@ -160,7 +153,7 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
.write(new DropModelPlan(modelInformation.getModelId()));
break;
- case CONFIG_NODE_INACTIVE:
+ case ML_NODE_ACTIVE:
LOGGER.info(
"Start to [CONFIG_NODE_INACTIVE] rollback of model [{}]",
modelInformation.getModelId());
@@ -197,12 +190,14 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode());
super.serialize(stream);
modelInformation.serialize(stream);
+ ReadWriteIOUtils.write(modelConfigs, stream);
}
@Override
public void deserialize(ByteBuffer byteBuffer) {
super.deserialize(byteBuffer);
modelInformation = ModelInformation.deserialize(byteBuffer);
+ modelConfigs = ReadWriteIOUtils.readMap(byteBuffer);
}
@Override
@@ -211,7 +206,8 @@ public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState
CreateModelProcedure thatProc = (CreateModelProcedure) that;
return thatProc.getProcId() == this.getProcId()
&& thatProc.getState() == this.getState()
- && thatProc.modelInformation.equals(this.modelInformation);
+ && thatProc.modelInformation.equals(this.modelInformation)
+ && thatProc.modelConfigs.equals(this.modelConfigs);
}
return false;
}
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
index 2f9380ce80..304d209817 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/model/CreateModelState.java
@@ -22,7 +22,6 @@ package org.apache.iotdb.confignode.procedure.state.model;
public enum CreateModelState {
INIT,
VALIDATED,
- CONFIG_NODE_INACTIVE,
- ML_NODE_ACTIVE,
- CONFIG_NODE_ACTIVE
+ CONFIG_NODE_ACTIVE,
+ ML_NODE_ACTIVE
}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java
deleted file mode 100644
index f98d7339ea..0000000000
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ForecastTrailInformation.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.commons.model;
-
-import java.nio.ByteBuffer;
-
-public class ForecastTrailInformation extends TrailInformation {
-
- private long outputLen;
-
- @Override
- public ByteBuffer serializeShowTrailResult() {
- return null;
- }
-}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java
new file mode 100644
index 0000000000..151a6b7c59
--- /dev/null
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelHyperparameter.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.commons.model;
+
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+public class ModelHyperparameter {
+
+ private final Map<String, String> keyValueMap;
+
+ public ModelHyperparameter(Map<String, String> keyValueMap) {
+ this.keyValueMap = keyValueMap;
+ }
+
+ public void update(Map<String, String> modelInfo) {
+ this.keyValueMap.putAll(modelInfo);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ for (Map.Entry<String, String> keyValuePair : keyValueMap.entrySet()) {
+ stringBuilder
+ .append(keyValuePair.getKey())
+ .append('=')
+ .append(keyValuePair.getValue())
+ .append('\n');
+ }
+ return stringBuilder.toString();
+ }
+
+ public void serialize(DataOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(keyValueMap, stream);
+ }
+
+ public void serialize(FileOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(keyValueMap, stream);
+ }
+
+ public static ModelHyperparameter deserialize(ByteBuffer buffer) {
+ return new ModelHyperparameter(ReadWriteIOUtils.readMap(buffer));
+ }
+
+ public static ModelHyperparameter deserialize(InputStream stream) throws IOException {
+ return new ModelHyperparameter(ReadWriteIOUtils.readMap(stream));
+ }
+}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 308c1d7f36..bdd9b13fc9 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -22,32 +22,98 @@ package org.apache.iotdb.commons.model;
import org.apache.iotdb.common.rpc.thrift.ModelTask;
import org.apache.iotdb.common.rpc.thrift.TrainingState;
import org.apache.iotdb.tsfile.utils.PublicBAOS;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
+import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
-public class ModelInformation {
+import static org.apache.iotdb.commons.model.TrailInformation.MODEL_PATH;
- private String modelId;
- private ModelTask modelTask;
- private String modelType;
+public class ModelInformation {
- private TrainingState modelState;
+ private final String modelId;
+ private final ModelTask modelTask;
+ private final String modelType;
- private List<String> queryExpressions;
- private String queryFilter;
+ private final List<String> queryExpressions;
+ private final String queryFilter;
- private boolean isAuto;
+ private final boolean isAuto;
+ private TrainingState trainingState;
+ private String bestTrailId;
private Map<String, TrailInformation> trailMap;
- private String bestTrailId;
- private String modelPath;
+ public ModelInformation(
+ String modelId,
+ ModelTask modelTask,
+ String modelType,
+ boolean isAuto,
+ List<String> queryExpressions,
+ String queryFilter) {
+ this.modelId = modelId;
+ this.modelTask = modelTask;
+ this.modelType = modelType;
+ this.isAuto = isAuto;
+ this.queryExpressions = queryExpressions;
+ this.queryFilter = queryFilter;
+ }
+
+ public ModelInformation(ByteBuffer buffer) {
+ this.modelId = ReadWriteIOUtils.readString(buffer);
+ this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(buffer));
+ this.modelType = ReadWriteIOUtils.readString(buffer);
+
+ int listSize = ReadWriteIOUtils.readInt(buffer);
+ this.queryExpressions = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ this.queryExpressions.add(ReadWriteIOUtils.readString(buffer));
+ }
+
+ this.queryFilter = ReadWriteIOUtils.readString(buffer);
+ this.isAuto = ReadWriteIOUtils.readBool(buffer);
+ this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(buffer));
+ this.bestTrailId = ReadWriteIOUtils.readString(buffer);
+
+ int mapSize = ReadWriteIOUtils.readInt(buffer);
+ this.trailMap = new HashMap<>();
+ for (int i = 0; i < mapSize; i++) {
+ TrailInformation trailInformation = TrailInformation.deserialize(buffer);
+ this.trailMap.put(trailInformation.getTrailId(), trailInformation);
+ }
+ }
+
+ public ModelInformation(InputStream stream) throws IOException {
+ this.modelId = ReadWriteIOUtils.readString(stream);
+ this.modelTask = ModelTask.findByValue(ReadWriteIOUtils.readInt(stream));
+ this.modelType = ReadWriteIOUtils.readString(stream);
+
+ int listSize = ReadWriteIOUtils.readInt(stream);
+ this.queryExpressions = new ArrayList<>(listSize);
+ for (int i = 0; i < listSize; i++) {
+ this.queryExpressions.add(ReadWriteIOUtils.readString(stream));
+ }
+
+ this.queryFilter = ReadWriteIOUtils.readString(stream);
+ this.isAuto = ReadWriteIOUtils.readBool(stream);
+ this.trainingState = TrainingState.findByValue(ReadWriteIOUtils.readInt(stream));
+ this.bestTrailId = ReadWriteIOUtils.readString(stream);
+
+ int mapSize = ReadWriteIOUtils.readInt(stream);
+ this.trailMap = new HashMap<>();
+ for (int i = 0; i < mapSize; i++) {
+ TrailInformation trailInformation = TrailInformation.deserialize(stream);
+ this.trailMap.put(trailInformation.getTrailId(), trailInformation);
+ }
+ }
public String getModelId() {
return modelId;
@@ -68,24 +134,80 @@ public class ModelInformation {
return new ArrayList<>(trailMap.values());
}
- public void update(Map<String, String> modelInfo) {}
+ public void update(String trailId, Map<String, String> modelInfo) {
+ if (!trailMap.containsKey(trailId)) {
+ String modelPath = null;
+ if (modelInfo.containsKey(MODEL_PATH)) {
+ modelPath = modelInfo.get(MODEL_PATH);
+ modelInfo.remove(MODEL_PATH);
+ }
+ TrailInformation trailInformation =
+ new TrailInformation(trailId, new ModelHyperparameter(modelInfo), modelPath);
+ trailMap.put(trailId, trailInformation);
+ } else {
+ trailMap.get(trailId).update(modelInfo);
+ }
+ }
- public void serialize(DataOutputStream stream) {}
+ public void serialize(DataOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(modelId, stream);
+ ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+ ReadWriteIOUtils.write(modelType, stream);
+ ReadWriteIOUtils.write(queryExpressions.size(), stream);
+ for (String queryExpression : queryExpressions) {
+ ReadWriteIOUtils.write(queryExpression, stream);
+ }
+ ReadWriteIOUtils.write(queryFilter, stream);
+ ReadWriteIOUtils.write(isAuto, stream);
+ ReadWriteIOUtils.write(trainingState.ordinal(), stream);
+ ReadWriteIOUtils.write(bestTrailId, stream);
+ ReadWriteIOUtils.write(trailMap.size(), stream);
+ for (TrailInformation trailInformation : trailMap.values()) {
+ trailInformation.serialize(stream);
+ }
+ }
- public void serialize(FileOutputStream stream) {}
+ public void serialize(FileOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(modelId, stream);
+ ReadWriteIOUtils.write(modelTask.ordinal(), stream);
+ ReadWriteIOUtils.write(modelType, stream);
- public static ModelInformation deserialize(InputStream stream) {
- return null;
+ ReadWriteIOUtils.write(queryExpressions.size(), stream);
+ for (String queryExpression : queryExpressions) {
+ ReadWriteIOUtils.write(queryExpression, stream);
+ }
+
+ ReadWriteIOUtils.write(queryFilter, stream);
+ ReadWriteIOUtils.write(isAuto, stream);
+ ReadWriteIOUtils.write(trainingState.ordinal(), stream);
+ ReadWriteIOUtils.write(bestTrailId, stream);
+
+ ReadWriteIOUtils.write(trailMap.size(), stream);
+ for (TrailInformation trailInformation : trailMap.values()) {
+ trailInformation.serialize(stream);
+ }
+ }
+
+ public static ModelInformation deserialize(InputStream stream) throws IOException {
+ return new ModelInformation(stream);
}
public static ModelInformation deserialize(ByteBuffer buffer) {
- return null;
+ return new ModelInformation(buffer);
}
- public ByteBuffer serializeShowModelResult() {
+ public ByteBuffer serializeShowModelResult() throws IOException {
PublicBAOS buffer = new PublicBAOS();
DataOutputStream stream = new DataOutputStream(buffer);
-
+ ReadWriteIOUtils.write(modelId, stream);
+ ReadWriteIOUtils.write(modelTask.toString(), stream);
+ ReadWriteIOUtils.write(modelType, stream);
+ ReadWriteIOUtils.write(Arrays.toString(queryExpressions.toArray(new String[0])), stream);
+ ReadWriteIOUtils.write(trainingState.toString(), stream);
+
+ TrailInformation bestTrail = trailMap.get(bestTrailId);
+ ReadWriteIOUtils.write(bestTrail.getModelHyperparameter().toString(), stream);
+ ReadWriteIOUtils.write(bestTrail.getModelPath(), stream);
return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
}
}
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
index 598c26017e..756a57060f 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java
@@ -60,8 +60,8 @@ public class ModelTable {
return null;
}
- public void updateModel(String modelId, Map<String, String> modelInfo) {
- modelInfoMap.get(modelId).update(modelInfo);
+ public void updateModel(String modelId, String trailId, Map<String, String> modelInfo) {
+ modelInfoMap.get(modelId).update(trailId, modelInfo);
}
public void clear() {
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
index 0b35c3a98e..8551534d41 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/TrailInformation.java
@@ -19,23 +19,83 @@
package org.apache.iotdb.commons.model;
-import org.apache.iotdb.common.rpc.thrift.Activation;
-import org.apache.iotdb.common.rpc.thrift.TrainingState;
+import org.apache.iotdb.tsfile.utils.PublicBAOS;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
import java.nio.ByteBuffer;
+import java.util.Map;
-public abstract class TrailInformation {
+public class TrailInformation {
- protected String trailId;
- protected TrainingState trailState;
+ public static final String MODEL_PATH = "model_path";
- protected long batchSize;
- protected double learningRate;
- protected long epochs;
+ private final String trailId;
+ private final ModelHyperparameter modelHyperparameter;
+ private String modelPath;
- protected long dModel;
- protected long dFF;
- protected Activation activation;
+ public TrailInformation(
+ String trailId, ModelHyperparameter modelHyperparameter, String modelPath) {
+ this.trailId = trailId;
+ this.modelHyperparameter = modelHyperparameter;
+ this.modelPath = modelPath;
+ }
- public abstract ByteBuffer serializeShowTrailResult();
+ public void update(Map<String, String> modelInfo) {
+ if (modelInfo.containsKey(MODEL_PATH)) {
+ modelPath = modelInfo.get(MODEL_PATH);
+ modelInfo.remove(MODEL_PATH);
+ }
+ modelHyperparameter.update(modelInfo);
+ }
+
+ public String getTrailId() {
+ return trailId;
+ }
+
+ public ModelHyperparameter getModelHyperparameter() {
+ return modelHyperparameter;
+ }
+
+ public String getModelPath() {
+ return modelPath;
+ }
+
+ public ByteBuffer serializeShowTrailResult() throws IOException {
+ PublicBAOS buffer = new PublicBAOS();
+ DataOutputStream stream = new DataOutputStream(buffer);
+ ReadWriteIOUtils.write(trailId, stream);
+ ReadWriteIOUtils.write(modelHyperparameter.toString(), stream);
+ ReadWriteIOUtils.write(modelPath, stream);
+ return ByteBuffer.wrap(buffer.getBuf(), 0, buffer.size());
+ }
+
+ public void serialize(DataOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(trailId, stream);
+ modelHyperparameter.serialize(stream);
+ ReadWriteIOUtils.write(modelPath, stream);
+ }
+
+ public void serialize(FileOutputStream stream) throws IOException {
+ ReadWriteIOUtils.write(trailId, stream);
+ modelHyperparameter.serialize(stream);
+ ReadWriteIOUtils.write(modelPath, stream);
+ }
+
+ public static TrailInformation deserialize(ByteBuffer buffer) {
+ return new TrailInformation(
+ ReadWriteIOUtils.readString(buffer),
+ ModelHyperparameter.deserialize(buffer),
+ ReadWriteIOUtils.readString(buffer));
+ }
+
+ public static TrailInformation deserialize(InputStream stream) throws IOException {
+ return new TrailInformation(
+ ReadWriteIOUtils.readString(stream),
+ ModelHyperparameter.deserialize(stream),
+ ReadWriteIOUtils.readString(stream));
+ }
}
diff --git a/thrift-commons/src/main/thrift/common.thrift b/thrift-commons/src/main/thrift/common.thrift
index 108010de8b..916f810624 100644
--- a/thrift-commons/src/main/thrift/common.thrift
+++ b/thrift-commons/src/main/thrift/common.thrift
@@ -126,7 +126,10 @@ struct TFilesResp {
// for MLNode
enum TrainingState {
-
+ PENDING,
+ RUNNING,
+ FINISHED,
+ FAILED
}
enum ModelTask {
@@ -137,11 +140,4 @@ enum EvaluateMetric {
MSE,
MAE,
RMSE
-}
-
-enum Activation {
- RELU,
- GELU,
- SIGMOID,
- TANH
}
\ No newline at end of file
diff --git a/thrift-confignode/src/main/thrift/confignode.thrift b/thrift-confignode/src/main/thrift/confignode.thrift
index 0544374411..51dbd39d9c 100644
--- a/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/thrift-confignode/src/main/thrift/confignode.thrift
@@ -668,11 +668,12 @@ struct TUnsetSchemaTemplateReq{
struct TCreateModelReq {
1: required string modelId
- 2: required byte modelTask
- 3: required bool isAuto
- 4: required map<string, string> modelConfigs
- 5: required list<string> queryExpressions
- 6: optional string queryFilter
+ 2: required common.ModelTask modelTask
+ 3: required string modelType
+ 4: required list<string> queryExpressions
+ 5: optional string queryFilter
+ 6: required bool isAuto
+ 7: required map<string, string> modelConfigs
}
struct TDropModelReq {
@@ -700,7 +701,13 @@ struct TShowTrailResp {
struct TUpdateModelInfoReq {
1: required string modelId
- 2: required map<string, string> modelInfo
+ 2: required string trailId
+ 3: required map<string, string> modelInfo
+}
+
+struct TUpdateModelStateReq {
+ 1: required string modelId
+ 2: required common.TrainingState state
}
service IConfigNodeRPCService {
@@ -1250,5 +1257,12 @@ service IConfigNodeRPCService {
* @return SUCCESS_STATUS if the model was removed successfully
*/
common.TSStatus updateModelInfo(TUpdateModelInfoReq req)
+
+ /**
+ * Update the model state
+ *
+ * @return SUCCESS_STATUS if the model state was updated successfully
+ */
+ common.TSStatus updateModelState(TUpdateModelStateReq req)
}