You are viewing a plain-text version of this content. The canonical link to it (shown as "here" in the original page) was not preserved in this copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/04/04 09:09:36 UTC

[iotdb] 03/03: support drop model

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch mlnode/test
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 3dcb848c81879817de772028a4b17880fcf13ffb
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Tue Apr 4 17:08:59 2023 +0800

    support drop model
---
 .../procedure/impl/model/DropModelProcedure.java   |  4 +++-
 mlnode/iotdb/mlnode/handler.py                     | 11 ++++++---
 mlnode/iotdb/mlnode/util.py                        |  2 +-
 .../db/mpp/plan/parser/StatementGenerator.java     | 21 ++++++++++++++---
 .../impl/DataNodeInternalRPCServiceImpl.java       | 27 +++++++++++++++++++++-
 .../service/thrift/impl/MLNodeRPCServiceImpl.java  |  7 +++---
 6 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index bfa461a8b2..8f41de070d 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -156,7 +156,9 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
         if (getCycles() > RETRY_THRESHOLD) {
           setFailure(
               new ProcedureException(
-                  String.format("Fail to drop model [%s] at STATE [%s]", modelId, state)));
+                  String.format(
+                      "Fail to drop model [%s] at STATE [%s], %s",
+                      modelId, state, e.getMessage())));
         }
       }
     }
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index 1a6e3eb90a..b4c64d94b1 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -21,6 +21,7 @@ from iotdb.mlnode.constant import TSStatusCode
 from iotdb.mlnode.data_access.factory import create_forecast_dataset
 from iotdb.mlnode.parser import parse_training_request
 from iotdb.mlnode.process.manager import TaskManager
+from iotdb.mlnode.storage import model_storage
 from iotdb.mlnode.util import get_status
 from iotdb.thrift.mlnode import IMLNodeRPCService
 from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
@@ -33,7 +34,11 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
         self.__task_manager = TaskManager(pool_num=10)  # TODO: add pool num to config
 
     def deleteModel(self, req: TDeleteModelReq):
-        return get_status(TSStatusCode.SUCCESS_STATUS, "")
+        try:
+            model_storage.delete_model(req.modelId)
+            return get_status(TSStatusCode.SUCCESS_STATUS)
+        except Exception as e:
+            return get_status(TSStatusCode.FAIL_STATUS, str(e))
 
     def createTrainingTask(self, req: TCreateTrainingTaskReq):
         task = None
@@ -50,7 +55,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
             # create task & check task config legitimacy
             task = self.__task_manager.create_training_task(dataset, model, model_config, task_config)
 
-            return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+            return get_status(TSStatusCode.SUCCESS_STATUS)
         except Exception as e:
             return get_status(TSStatusCode.FAIL_STATUS, str(e))
         finally:
@@ -58,6 +63,6 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
             self.__task_manager.submit_training_task(task)
 
     def forecast(self, req: TForecastReq):
-        status = get_status(TSStatusCode.SUCCESS_STATUS, "")
+        status = get_status(TSStatusCode.SUCCESS_STATUS)
         forecast_result = b'forecast result'
         return TForecastResp(status, forecast_result)
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index 5d3a2d670e..5cdc52f01a 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -45,7 +45,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
         raise BadNodeUrlError(endpoint_url)
 
 
-def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
+def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus:
     status = TSStatus(status_code.get_status_code())
     status.message = message
     return status
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
index 422617162d..f0da4a29f1 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/StatementGenerator.java
@@ -66,6 +66,7 @@ import org.apache.iotdb.db.mpp.plan.statement.metadata.template.UnsetSchemaTempl
 import org.apache.iotdb.db.qp.sql.IoTDBSqlParser;
 import org.apache.iotdb.db.qp.sql.SqlLexer;
 import org.apache.iotdb.db.utils.QueryDataSetUtils;
+import org.apache.iotdb.mpp.rpc.thrift.TDeleteModelMetricsReq;
 import org.apache.iotdb.mpp.rpc.thrift.TFetchTimeseriesReq;
 import org.apache.iotdb.mpp.rpc.thrift.TRecordModelMetricsReq;
 import org.apache.iotdb.service.rpc.thrift.TSAggregationQueryReq;
@@ -110,6 +111,9 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.iotdb.commons.conf.IoTDBConstant.MULTI_LEVEL_PATH_WILDCARD;
+import static org.apache.iotdb.db.service.thrift.impl.MLNodeRPCServiceImpl.ML_METRICS_PATH_PREFIX;
+
 /** Convert SQL and RPC requests to {@link Statement}. */
 public class StatementGenerator {
   private static final PerformanceOverviewMetrics PERFORMANCE_OVERVIEW_METRICS =
@@ -806,10 +810,10 @@ public class StatementGenerator {
     return databasePath;
   }
 
-  public static InsertRowStatement createStatement(
-      TRecordModelMetricsReq recordModelMetricsReq, String prefix) throws IllegalPathException {
+  public static InsertRowStatement createStatement(TRecordModelMetricsReq recordModelMetricsReq)
+      throws IllegalPathException {
     String path =
-        prefix
+        ML_METRICS_PATH_PREFIX
             + TsFileConstant.PATH_SEPARATOR
             + recordModelMetricsReq.getModelId()
             + TsFileConstant.PATH_SEPARATOR
@@ -870,4 +874,15 @@ public class StatementGenerator {
     queryStatement.setSelectComponent(selectComponent);
     return queryStatement;
   }
+
+  public static DeleteTimeSeriesStatement createStatement(TDeleteModelMetricsReq req)
+      throws IllegalPathException {
+    String path =
+        ML_METRICS_PATH_PREFIX
+            + TsFileConstant.PATH_SEPARATOR
+            + req.getModelId()
+            + TsFileConstant.PATH_SEPARATOR
+            + MULTI_LEVEL_PATH_WILDCARD;
+    return new DeleteTimeSeriesStatement(Collections.singletonList(new PartialPath(path)));
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
index d00903759a..ea1cc229c6 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -102,6 +102,7 @@ import org.apache.iotdb.db.mpp.plan.planner.plan.node.write.DeleteDataNode;
 import org.apache.iotdb.db.mpp.plan.scheduler.load.LoadTsFileScheduler;
 import org.apache.iotdb.db.mpp.plan.statement.component.WhereCondition;
 import org.apache.iotdb.db.mpp.plan.statement.crud.QueryStatement;
+import org.apache.iotdb.db.mpp.plan.statement.metadata.DeleteTimeSeriesStatement;
 import org.apache.iotdb.db.pipe.agent.PipeAgent;
 import org.apache.iotdb.db.query.control.SessionManager;
 import org.apache.iotdb.db.query.control.clientsession.IClientSession;
@@ -196,6 +197,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.TimeZone;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
@@ -870,7 +872,30 @@ public class DataNodeInternalRPCServiceImpl implements IDataNodeRPCService.Iface
 
   @Override
   public TSStatus deleteModelMetrics(TDeleteModelMetricsReq req) throws TException {
-    return RpcUtils.SUCCESS_STATUS;
+    IClientSession session = new InternalClientSession(req.getModelId());
+    SESSION_MANAGER.registerSession(session);
+    SESSION_MANAGER.supplySession(
+        session, "MLNode", TimeZone.getDefault().getID(), ClientVersion.V_1_0);
+
+    try {
+      DeleteTimeSeriesStatement deleteTimeSeriesStatement = StatementGenerator.createStatement(req);
+
+      long queryId = SESSION_MANAGER.requestQueryId();
+      ExecutionResult result =
+          COORDINATOR.execute(
+              deleteTimeSeriesStatement,
+              queryId,
+              SESSION_MANAGER.getSessionInfo(session),
+              "",
+              PARTITION_FETCHER,
+              SCHEMA_FETCHER);
+      return result.status;
+    } catch (Exception e) {
+      return onQueryException(e, OperationType.DELETE_TIMESERIES);
+    } finally {
+      SESSION_MANAGER.closeSession(session, COORDINATOR::cleanupQueryExecution);
+      SESSION_MANAGER.removeCurrSession();
+    }
   }
 
   private PathPatternTree filterPathPatternTree(PathPatternTree patternTree, String storageGroup) {
diff --git a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
index 544d4cd04c..74e14f9f3f 100644
--- a/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/service/thrift/impl/MLNodeRPCServiceImpl.java
@@ -60,14 +60,14 @@ import static org.apache.iotdb.db.utils.ErrorHandlingUtils.onQueryException;
 
 public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
 
+  public static final String ML_METRICS_PATH_PREFIX = "root.__system.ml.exp";
+
   private static final Logger LOGGER = LoggerFactory.getLogger(MLNodeRPCServiceImpl.class);
 
   private static final SessionManager SESSION_MANAGER = SessionManager.getInstance();
 
   private static final Coordinator COORDINATOR = Coordinator.getInstance();
 
-  private static final String ML_METRICS_STORAGE_GROUP = "root.__system.ml.exp";
-
   private final IPartitionFetcher PARTITION_FETCHER;
 
   private final ISchemaFetcher SCHEMA_FETCHER;
@@ -176,8 +176,7 @@ public class MLNodeRPCServiceImpl implements IMLNodeRPCServiceWithHandler {
   @Override
   public TSStatus recordModelMetrics(TRecordModelMetricsReq req) throws TException {
     try {
-      InsertRowStatement insertRowStatement =
-          StatementGenerator.createStatement(req, ML_METRICS_STORAGE_GROUP);
+      InsertRowStatement insertRowStatement = StatementGenerator.createStatement(req);
 
       long queryId = SESSION_MANAGER.requestQueryId();
       ExecutionResult result =