You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/14 01:52:04 UTC

[iotdb] 09/11: finish & fix bugs

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch lmh/MLSQL
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit c2374bcc3ef97b1945b19510a92061a89bc49bc9
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Mar 13 15:03:54 2023 +0800

    finish & fix bugs
---
 .../org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4   |  3 +-
 .../antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4  |  4 --
 .../iotdb/confignode/persistence/ModelInfo.java    | 12 ++++-
 .../iotdb/commons/model/ModelInformation.java      |  3 +-
 .../apache/iotdb/db/client/ConfigNodeClient.java   | 58 ++++++++++++++++++----
 .../db/mpp/common/header/ColumnHeaderConstant.java | 25 ++++++++++
 .../db/mpp/common/header/DatasetHeaderFactory.java |  8 +++
 .../config/executor/ClusterConfigTaskExecutor.java |  7 ++-
 .../config/metadata/model/ShowModelsTask.java      | 45 ++++++++++++++++-
 .../config/metadata/model/ShowTrailsTask.java      | 33 +++++++++++-
 .../iotdb/db/mpp/plan/parser/ASTVisitor.java       |  3 +-
 .../metadata/model/CreateModelStatement.java       | 19 ++++++-
 12 files changed, 196 insertions(+), 24 deletions(-)

diff --git a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
index 109a5b3bbd..e727c4566b 100644
--- a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
+++ b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
@@ -452,8 +452,7 @@ migrateRegion
 // ---- Create Model
 createModel
     : CREATE AUTO? MODEL modelId=identifier
-        WITH MODEL_TASK operator_eq modelTask=attributeValue
-            (COMMA attributePair)*
+        WITH attributePair (COMMA attributePair)*
         BEGIN
             selectStatement
         END
diff --git a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
index 2ddab5fd35..f06d70e318 100644
--- a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
+++ b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
@@ -394,10 +394,6 @@ MODELS
     : M O D E L S
     ;
 
-MODEL_TASK
-    : M O D E L '_' T A S K
-    ;
-
 NODEID
     : N O D E I D
     ;
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 58e3ec4e6b..13e4dabe73 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -122,7 +122,17 @@ public class ModelInfo implements SnapshotProcessor {
   public TrailTableResp showTrail(ShowTrailPlan plan) {
     acquireModelTableLock();
     try {
-      ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
+      String modelId = plan.getModelId();
+      ModelInformation modelInformation = modelTable.getModelInformationById(modelId);
+      if (modelInformation == null) {
+        return new TrailTableResp(
+            new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+                .setMessage(
+                    String.format(
+                        "Failed to show trails of model [%s], this model has not been created.",
+                        modelId)));
+      }
+
       TrailTableResp trailTableResp =
           new TrailTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
       if (plan.isSetTrailId()) {
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index ecccda085e..e6fbf13c95 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -50,7 +50,7 @@ public class ModelInformation {
   private TrainingState trainingState;
 
   private String bestTrailId;
-  private Map<String, TrailInformation> trailMap;
+  private final Map<String, TrailInformation> trailMap;
 
   public ModelInformation(
       String modelId,
@@ -65,6 +65,7 @@ public class ModelInformation {
     this.isAuto = isAuto;
     this.queryExpressions = queryExpressions;
     this.queryFilter = queryFilter;
+    this.trailMap = new HashMap<>();
   }
 
   public ModelInformation(ByteBuffer buffer) {
diff --git a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
index c525d5cd75..ad9ee97782 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
@@ -1919,37 +1919,79 @@ public class ConfigNodeClient implements IConfigNodeRPCService.Iface, ThriftClie
 
   @Override
   public TSStatus createModel(TCreateModelReq req) throws TException {
-    // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    for (int i = 0; i < RETRY_NUM; i++) {
+      try {
+        TSStatus status = client.createModel(req);
+        if (!updateConfigNodeLeader(status)) {
+          return status;
+        }
+      } catch (TException e) {
+        configLeader = null;
+      }
+      waitAndReconnect();
+    }
+    throw new TException(MSG_RECONNECTION_FAIL);
   }
 
   @Override
   public TSStatus dropModel(TDropModelReq req) throws TException {
-    // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    for (int i = 0; i < RETRY_NUM; i++) {
+      try {
+        TSStatus status = client.dropModel(req);
+        if (!updateConfigNodeLeader(status)) {
+          return status;
+        }
+      } catch (TException e) {
+        configLeader = null;
+      }
+      waitAndReconnect();
+    }
+    throw new TException(MSG_RECONNECTION_FAIL);
   }
 
   @Override
   public TShowModelResp showModel(TShowModelReq req) throws TException {
-    // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    for (int i = 0; i < RETRY_NUM; i++) {
+      try {
+        TShowModelResp resp = client.showModel(req);
+        if (!updateConfigNodeLeader(resp.getStatus())) {
+          return resp;
+        }
+      } catch (TException e) {
+        configLeader = null;
+      }
+      waitAndReconnect();
+    }
+    throw new TException(MSG_RECONNECTION_FAIL);
   }
 
   @Override
   public TShowTrailResp showTrail(TShowTrailReq req) throws TException {
-    // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    for (int i = 0; i < RETRY_NUM; i++) {
+      try {
+        TShowTrailResp resp = client.showTrail(req);
+        if (!updateConfigNodeLeader(resp.getStatus())) {
+          return resp;
+        }
+      } catch (TException e) {
+        configLeader = null;
+      }
+      waitAndReconnect();
+    }
+    throw new TException(MSG_RECONNECTION_FAIL);
   }
 
   @Override
   public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
     // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    throw new TException(
+        new UnsupportedOperationException("updateModelInfo is not supported yet"));
   }
 
   @Override
   public TSStatus updateModelState(TUpdateModelStateReq req) throws TException {
     // TODO
-    throw new TException(new UnsupportedOperationException().getCause());
+    throw new TException(
+        new UnsupportedOperationException("updateModelState is not supported yet"));
   }
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
index baaf9a773c..88a22dd1ea 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
@@ -153,6 +153,15 @@ public class ColumnHeaderConstant {
   public static final String ELAPSED_TIME = "ElapsedTime";
   public static final String STATEMENT = "Statement";
 
+  // column names used by the SHOW MODELS / SHOW TRAILS result headers
+  public static final String MODEL_ID = "ModelId";
+  public static final String TRAIL_ID = "TrailId";
+  public static final String MODEL_TASK = "ModelTask";
+  public static final String MODEL_TYPE = "ModelType";
+  public static final String QUERY_BODY = "QueryBody";
+  public static final String HYPERPARAMETER = "Hyperparameter";
+  public static final String MODEL_PATH = "ModelPath";
+
   public static final List<ColumnHeader> lastQueryColumnHeaders =
       ImmutableList.of(
           new ColumnHeader(TIMESERIES, TSDataType.TEXT),
@@ -381,4 +390,20 @@ public class ColumnHeaderConstant {
           new ColumnHeader(DATA_NODE_ID, TSDataType.INT32),
           new ColumnHeader(ELAPSED_TIME, TSDataType.FLOAT),
           new ColumnHeader(STATEMENT, TSDataType.TEXT));
+
+  public static final List<ColumnHeader> showModelsColumnHeaders =
+      ImmutableList.of(
+          new ColumnHeader(MODEL_ID, TSDataType.TEXT),
+          new ColumnHeader(MODEL_TASK, TSDataType.TEXT),
+          new ColumnHeader(MODEL_TYPE, TSDataType.TEXT),
+          new ColumnHeader(QUERY_BODY, TSDataType.TEXT),
+          new ColumnHeader(STATE, TSDataType.TEXT),
+          new ColumnHeader(HYPERPARAMETER, TSDataType.TEXT),
+          new ColumnHeader(MODEL_PATH, TSDataType.TEXT));
+
+  public static final List<ColumnHeader> showTrailsColumnHeaders =
+      ImmutableList.of(
+          new ColumnHeader(TRAIL_ID, TSDataType.TEXT),
+          new ColumnHeader(HYPERPARAMETER, TSDataType.TEXT),
+          new ColumnHeader(MODEL_PATH, TSDataType.TEXT));
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
index 7e8eb19b05..ff0bb03b9b 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
@@ -164,4 +164,12 @@ public class DatasetHeaderFactory {
   public static DatasetHeader getShowQueriesHeader() {
     return new DatasetHeader(ColumnHeaderConstant.showQueriesColumnHeaders, false);
   }
+
+  public static DatasetHeader getShowModelsHeader() {
+    return new DatasetHeader(ColumnHeaderConstant.showModelsColumnHeaders, false);
+  }
+
+  public static DatasetHeader getShowTrailsHeader() {
+    return new DatasetHeader(ColumnHeaderConstant.showTrailsColumnHeaders, false);
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 8730967fd0..c5c8ea7e45 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -1658,16 +1658,19 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor {
     for (Expression expression : analysis.getSelectExpressions()) {
       queryExpressions.add(expression.toString());
     }
-    String queryFilter = analysis.getWhereExpression().toString();
+    Expression whereExpression = analysis.getWhereExpression();
+    String queryFilter = whereExpression == null ? null : whereExpression.toString();
 
     SettableFuture<ConfigTaskResult> future = SettableFuture.create();
     try (ConfigNodeClient client =
         CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.configNodeRegionId)) {
       TCreateModelReq createModelReq = new TCreateModelReq();
       createModelReq.setModelId(createModelStatement.getModelId());
+      createModelReq.setModelTask(createModelStatement.getModelTask());
+      createModelReq.setModelType(createModelStatement.getModelType());
+      createModelReq.setIsAuto(createModelStatement.isAuto());
       createModelReq.setQueryExpressions(queryExpressions);
       createModelReq.setQueryFilter(queryFilter);
-      createModelReq.setIsAuto(createModelStatement.isAuto());
       createModelReq.setModelConfigs(createModelStatement.getAttributes());
       final TSStatus executionStatus = client.createModel(createModelReq);
       if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
index 7f7a3d1db1..7fd719e0fd 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
@@ -19,15 +19,25 @@
 
 package org.apache.iotdb.db.mpp.plan.execution.config.metadata.model;
 
+import org.apache.iotdb.db.mpp.common.header.ColumnHeader;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeader;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeaderFactory;
 import org.apache.iotdb.db.mpp.plan.execution.config.ConfigTaskResult;
 import org.apache.iotdb.db.mpp.plan.execution.config.IConfigTask;
 import org.apache.iotdb.db.mpp.plan.execution.config.executor.IConfigTaskExecutor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.utils.Binary;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 
 import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.stream.Collectors;
 
 public class ShowModelsTask implements IConfigTask {
 
@@ -40,5 +50,32 @@ public class ShowModelsTask implements IConfigTask {
   }
 
+  /**
+   * Converts serialized model entries into a TsBlock matching the SHOW MODELS header and
+   * completes {@code future} with a successful result.
+   */
   public static void buildTsBlock(
-      List<ByteBuffer> modelInfoList, SettableFuture<ConfigTaskResult> future) {}
+      List<ByteBuffer> modelInfoList, SettableFuture<ConfigTaskResult> future) {
+    List<TSDataType> outputDataTypes =
+        ColumnHeaderConstant.showModelsColumnHeaders.stream()
+            .map(ColumnHeader::getColumnType)
+            .collect(Collectors.toList());
+    TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
+    // Each entry holds one serialized string per header column, in header order.
+    int columnCount = outputDataTypes.size();
+    for (ByteBuffer modelInfo : modelInfoList) {
+      builder.getTimeColumnBuilder().writeLong(0L);
+      for (int i = 0; i < columnCount; i++) {
+        // Defensive: write a null cell if the serialized string is null.
+        String value = ReadWriteIOUtils.readString(modelInfo);
+        if (value == null) {
+          builder.getColumnBuilder(i).appendNull();
+        } else {
+          builder.getColumnBuilder(i).writeBinary(Binary.valueOf(value));
+        }
+      }
+      builder.declarePosition();
+    }
+    DatasetHeader datasetHeader = DatasetHeaderFactory.getShowModelsHeader();
+    future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
index 948a8e91fb..a428c27794 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
@@ -19,15 +19,25 @@
 
 package org.apache.iotdb.db.mpp.plan.execution.config.metadata.model;
 
+import org.apache.iotdb.db.mpp.common.header.ColumnHeader;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeader;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeaderFactory;
 import org.apache.iotdb.db.mpp.plan.execution.config.ConfigTaskResult;
 import org.apache.iotdb.db.mpp.plan.execution.config.IConfigTask;
 import org.apache.iotdb.db.mpp.plan.execution.config.executor.IConfigTaskExecutor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.utils.Binary;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
 
 import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.stream.Collectors;
 
 public class ShowTrailsTask implements IConfigTask {
 
@@ -44,5 +54,32 @@ public class ShowTrailsTask implements IConfigTask {
   }
 
+  /**
+   * Converts serialized trail entries into a TsBlock matching the SHOW TRAILS header and
+   * completes {@code future} with a successful result.
+   */
   public static void buildTsBlock(
-      List<ByteBuffer> trailInfoList, SettableFuture<ConfigTaskResult> future) {}
+      List<ByteBuffer> trailInfoList, SettableFuture<ConfigTaskResult> future) {
+    List<TSDataType> outputDataTypes =
+        ColumnHeaderConstant.showTrailsColumnHeaders.stream()
+            .map(ColumnHeader::getColumnType)
+            .collect(Collectors.toList());
+    TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
+    // Each entry holds one serialized string per header column, in header order.
+    int columnCount = outputDataTypes.size();
+    for (ByteBuffer trailInfo : trailInfoList) {
+      builder.getTimeColumnBuilder().writeLong(0L);
+      for (int i = 0; i < columnCount; i++) {
+        // Defensive: write a null cell if the serialized string is null.
+        String value = ReadWriteIOUtils.readString(trailInfo);
+        if (value == null) {
+          builder.getColumnBuilder(i).appendNull();
+        } else {
+          builder.getColumnBuilder(i).writeBinary(Binary.valueOf(value));
+        }
+      }
+      builder.declarePosition();
+    }
+    DatasetHeader datasetHeader = DatasetHeaderFactory.getShowTrailsHeader();
+    future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
+  }
 }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
index 510b6a808d..9d5ff90e58 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
@@ -923,7 +923,8 @@ public class ASTVisitor extends IoTDBSqlParserBaseVisitor<Statement> {
 
     Map<String, String> attributes = new HashMap<>();
     for (IoTDBSqlParser.AttributePairContext attribute : ctx.attributePair()) {
-      attributes.put(parseAttributeKey(attribute.key), parseAttributeValue(attribute.value));
+      attributes.put(
+          parseAttributeKey(attribute.key).toLowerCase(), parseAttributeValue(attribute.value));
     }
     createModelStatement.setAttributes(attributes);
 
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
index 11b552c903..0eecf66558 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
@@ -19,7 +19,9 @@
 
 package org.apache.iotdb.db.mpp.plan.statement.metadata.model;
 
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
 import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.db.exception.sql.SemanticException;
 import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
 import org.apache.iotdb.db.mpp.plan.statement.IConfigStatement;
 import org.apache.iotdb.db.mpp.plan.statement.Statement;
@@ -71,7 +73,44 @@ public class CreateModelStatement extends Statement implements IConfigStatement
     this.queryStatement = queryStatement;
   }
 
-  public void semanticCheck() {}
+  /**
+   * Returns the model task named by the {@code model_task} attribute.
+   *
+   * @throws SemanticException if the attribute is missing or is not a {@link ModelTask} constant
+   */
+  public ModelTask getModelTask() {
+    String modelTask = attributes.get("model_task");
+    if (modelTask == null) {
+      throw new SemanticException("The attribute `model_task` must be specified.");
+    }
+    try {
+      return ModelTask.valueOf(modelTask);
+    } catch (IllegalArgumentException e) {
+      throw new SemanticException("Unsupported model_task: " + modelTask);
+    }
+  }
+
+  /** Returns the {@code model_type} attribute, or {@code null} if it was not specified. */
+  public String getModelType() {
+    return attributes.get("model_type");
+  }
+
+  /**
+   * Checks that the {@code model_task} and {@code model_type} attributes are present and that
+   * {@code model_task} names a valid {@link ModelTask}.
+   *
+   * @throws SemanticException if any check fails
+   */
+  public void semanticCheck() {
+    if (!attributes.containsKey("model_task")) {
+      throw new SemanticException("The attribute `model_task` must be specified.");
+    }
+    if (!attributes.containsKey("model_type")) {
+      throw new SemanticException("The attribute `model_type` must be specified.");
+    }
+    // Reject an unknown task name here instead of with an IllegalArgumentException later.
+    getModelTask();
+  }
 
   @Override
   public List<? extends PartialPath> getPaths() {