You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this plain-text copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/04 09:09:36 UTC
[iotdb] 03/03: support drop model
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 3dcb848c81879817de772028a4b17880fcf13ffb
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Tue Apr 4 17:08:59 2023 +0800
support drop model
---
.../procedure/impl/model/DropModelProcedure.java | 4 +++-
mlnode/iotdb/mlnode/handler.py | 11 ++++++---
mlnode/iotdb/mlnode/util.py | 2 +-
.../db/mpp/plan/parser/StatementGenerator.java | 21 ++++++++++++++---
.../impl/DataNodeInternalRPCServiceImpl.java | 27 +++++++++++++++++++++-
.../service/thrift/impl/MLNodeRPCServiceImpl.java | 7 +++---
6 files changed, 59 insertions(+), 13 deletions(-)
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index bfa461a8b2..8f41de070d 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
if (getCycles() > RETRY_THRESHOLD) {
setFailure(
new ProcedureException(
- String.format("Fail to drop model [%s] at STATE [%s]", modelId, state)));
+ String.format(
+ "Fail to drop model [%s] at STATE [%s], %s",
+ modelId, state, e.getMessage())));
}
}
}
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 1a6e3eb90a..b4c64d94b1 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -21,6 +21,7 @@ from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.data_access.factory import create_forecast_dataset
from iotdb.mlnode.parser import parse_training_request
from iotdb.mlnode.process.manager import TaskManager
+from iotdb.mlnode.storage import model_storage
from iotdb.mlnode.util import get_status
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
@@ -33,7 +34,11 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
self.__task_manager = TaskManager(pool_num=10) # TODO: add pool num to config
def deleteModel(self, req: TDeleteModelReq):
- return get_status(TSStatusCode.SUCCESS_STATUS, "")
+ try:
+ model_storage.delete_model(req.modelId)
+ return get_status(TSStatusCode.SUCCESS_STATUS)
+ except Exception as e:
+ return get_status(TSStatusCode.FAIL_STATUS, str(e))
def createTrainingTask(self, req: TCreateTrainingTaskReq):
task = None
@@ -50,7 +55,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
# create task & check task config legitimacy
task = self.__task_manager.create_training_task(dataset, model, model_config, task_config)
- return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+ return get_status(TSStatusCode.SUCCESS_STATUS)
except Exception as e:
return get_status(TSStatusCode.FAIL_STATUS, str(e))
finally:
@@ -58,6 +63,6 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
self.__task_manager.submit_training_task(task)
def forecast(self, req: TForecastReq):
- status = get_status(TSStatusCode.SUCCESS_STATUS, "")
+ status = get_status(TSStatusCode.SUCCESS_STATUS)
forecast_result = b'forecast result'
return TForecastResp(status, forecast_result)
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index 5d3a2d670e..5cdc52f01a 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -45,7 +45,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
raise BadNodeUrlError(endpoint_url)
-def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
+def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus:
status = TSStatus(status_code.get_status_code())
status.message = message
return status
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
index 422617162d..f0da4a29f1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
@@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl
import org.apache.iotdb.db.qp.sql.IoTDBSqlParser;
import org.apache.iotdb.db.qp.sql.SqlLexer;
import org.apache.iotdb.db.utils.QueryDataSetUtils;
+import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq;
import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq;
@@ -110,6 +111,9 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import static org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD;
+import static org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX;
+
/** Convert SQL and RPC requests to {@link Statement}. */
public class StatementGenerator {
private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS =
@@ -806,10 +810,10 @@ public class StatementGenerator {
return databasePath;
}
- public static InsertRowStatement createStatement(
- TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws IllegalPathException {
+ public static InsertRowStatement createStatement(TRecordModelMetricsReq recordModelMetricsReq)
+ throws IllegalPathException {
String path =
- prefix
+ ML_METRICS_PATH_PREFIX
+ TsFileConstant.PATH_SEPARATOR
+ recordModelMetricsReq.getModelId()
+ TsFileConstant.PATH_SEPARATOR
@@ -870,4 +874,15 @@ public class StatementGenerator {
queryStatement.setSelectComponent(selectComponent);
return queryStatement;
}
+
+ public static DeleteTimeSeriesStatement createStatement(TDeleteModelMetricsReq req)
+ throws IllegalPathException {
+ String path =
+ ML_METRICS_PATH_PREFIX
+ + TsFileConstant.PATH_SEPARATOR
+ + req.getModelId()
+ + TsFileConstant.PATH_SEPARATOR
+ + MULTI_LEVEL_PATH_WILDCARD;
+ return new DeleteTimeSeriesStatement(Collections.singletonList(new PartialPath(path)));
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index d00903759a..ea1cc229c6 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -102,6 +102,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode;
import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler;
import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition;
import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement;
import org.apache.iotdb.db.pipe.agent.PipeAgent;
import org.apache.iotdb.db.query.control.SessionManager;
import org.apache.iotdb.db.query.control.clientsession.IClientSession;
@@ -196,6 +197,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.TimeZone;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@@ -870,7 +872,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
@Override
public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException {
- return RpcUtils.SUCCESS_STATUS;
+ IClientSession session = new InternalClientSession(req.getModelId());
+ SESSION_MANAGER.registerSession(session);
+ SESSION_MANAGER.supplySession(
+ session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0);
+
+ try {
+ DeleteTimeSeriesStatement deleteTimeSeriesStatement = StatementGenerator.createStatement(req);
+
+ long queryId = SESSION_MANAGER.requestQueryId();
+ ExecutionResult result =
+ COORDINATOR.execute(
+ deleteTimeSeriesStatement,
+ queryId,
+ SESSION_MANAGER.getSessionInfo(session),
+ "",
+ PARTITION_FETCHER,
+ SCHEMA_FETCHER);
+ return result.status;
+ } catch (Exception e) {
+ return onQueryException(e, OperationType.DELETE_TIMESERIES);
+ } finally {
+ SESSION_MANAGER.closeSession(session, COORDINATOR::cleanupQueryExecution);
+ SESSION_MANAGER.removeCurrSession();
+ }
}
private PathPatternTree filterPathPatternTree(PathPatternTree patternTree, String storageGroup) {
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
index 544d4cd04c..74e14f9f3f 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
@@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException;
public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
+ public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp";
+
private static final Logger LOGGER = LoggerFactory.getLogger(MLNodeRPCServiceImpl.class);
private static final SessionManager SESSION_MANAGER = SessionManager.getInstance();
private static final Coordinator COORDINATOR = Coordinator.getInstance();
- private static final String ML_METRICS_STORAGE_GROUP = "root.__system.ml.exp";
-
private final IPartitionFetcher PARTITION_FETCHER;
private final ISchemaFetcher SCHEMA_FETCHER;
@@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
@Override
public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws TException {
try {
- InsertRowStatement insertRowStatement =
- StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP);
+ InsertRowStatement insertRowStatement = StatementGenerator.createStatement(req);
long queryId = SESSION_MANAGER.requestQueryId();
ExecutionResult result =