You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/14 01:52:04 UTC
[iotdb] 09/11: finish & fix bugs
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch lmh/MLSQL
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit c2374bcc3ef97b1945b19510a92061a89bc49bc9
Author: Minghui Liu <li...@foxmail.com>
AuthorDate: Mon Mar 13 15:03:54 2023 +0800
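finish & fix bugs: implement the model-management RPCs (createModel, dropModel,
showModel, showTrail) in ConfigNodeClient, build the SHOW MODELS / SHOW TRAILS
result sets, and fix the following:
- CREATE MODEL now takes model_task and model_type as ordinary WITH attributes
  (the MODEL_TASK keyword is removed), and semantic checking requires both;
- CREATE MODEL no longer fails when its query has no WHERE clause;
- ModelInformation.trailMap is now initialised in the constructor;
- showing the trails of a model that has not been created returns an error
  status instead of dereferencing a missing model.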
---
.../org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 | 3 +-
.../antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 | 4 --
.../iotdb/confignode/persistence/ModelInfo.java | 12 ++++-
.../iotdb/commons/model/ModelInformation.java | 3 +-
.../apache/iotdb/db/client/ConfigNodeClient.java | 58 ++++++++++++++++++----
.../db/mpp/common/header/ColumnHeaderConstant.java | 25 ++++++++++
.../db/mpp/common/header/DatasetHeaderFactory.java | 8 +++
.../config/executor/ClusterConfigTaskExecutor.java | 7 ++-
.../config/metadata/model/ShowModelsTask.java | 45 ++++++++++++++++-
.../config/metadata/model/ShowTrailsTask.java | 33 +++++++++++-
.../iotdb/db/mpp/plan/parser/ASTVisitor.java | 3 +-
.../metadata/model/CreateModelStatement.java | 19 ++++++-
12 files changed, 196 insertions(+), 24 deletions(-)
diff --git a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
index 109a5b3bbd..e727c4566b 100644
--- a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
+++ b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4
@@ -452,8 +452,7 @@ migrateRegion
// ---- Create Model
createModel
: CREATE AUTO? MODEL modelId=identifier
- WITH MODEL_TASK operator_eq modelTask=attributeValue
- (COMMA attributePair)*
+ WITH attributePair (COMMA attributePair)*
BEGIN
selectStatement
END
diff --git a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
index 2ddab5fd35..f06d70e318 100644
--- a/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
+++ b/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4
@@ -394,10 +394,6 @@ MODELS
: M O D E L S
;
-MODEL_TASK
- : M O D E L '_' T A S K
- ;
-
NODEID
: N O D E I D
;
diff --git a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index 58e3ec4e6b..13e4dabe73 100644
--- a/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -122,7 +122,17 @@ public class ModelInfo implements SnapshotProcessor {
public TrailTableResp showTrail(ShowTrailPlan plan) {
acquireModelTableLock();
try {
- ModelInformation modelInformation = modelTable.getModelInformationById(plan.getModelId());
+ String modelId = plan.getModelId();
+ ModelInformation modelInformation = modelTable.getModelInformationById(modelId);
+ if (modelInformation == null) {
+ return new TrailTableResp(
+ new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
+ .setMessage(
+ String.format(
+ "Failed to show trails of model [%s], this model has not been created.",
+ modelId)));
+ }
+
TrailTableResp trailTableResp =
new TrailTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
if (plan.isSetTrailId()) {
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index ecccda085e..e6fbf13c95 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -50,7 +50,7 @@ public class ModelInformation {
private TrainingState trainingState;
private String bestTrailId;
- private Map<String, TrailInformation> trailMap;
+ private final Map<String, TrailInformation> trailMap;
public ModelInformation(
String modelId,
@@ -65,6 +65,7 @@ public class ModelInformation {
this.isAuto = isAuto;
this.queryExpressions = queryExpressions;
this.queryFilter = queryFilter;
+ this.trailMap = new HashMap<>();
}
public ModelInformation(ByteBuffer buffer) {
diff --git a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
index c525d5cd75..ad9ee97782 100644
--- a/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
+++ b/server/src/main/java/org/apache/iotdb/db/client/ConfigNodeClient.java
@@ -1919,37 +1919,75 @@ public class ConfigNodeClient implements IConfigNodeRPCService.Iface, ThriftClie
@Override
public TSStatus createModel(TCreateModelReq req) throws TException {
- // TODO
- throw new TException(new UnsupportedOperationException().getCause());
+ for (int i = 0; i < RETRY_NUM; i++) {
+ try {
+ TSStatus status = client.createModel(req);
+ if (!updateConfigNodeLeader(status)) {
+ return status;
+ }
+ } catch (TException e) {
+ configLeader = null;
+ }
+ waitAndReconnect();
+ }
+ throw new TException(MSG_RECONNECTION_FAIL);
}
@Override
public TSStatus dropModel(TDropModelReq req) throws TException {
- // TODO
- throw new TException(new UnsupportedOperationException().getCause());
+ for (int i = 0; i < RETRY_NUM; i++) {
+ try {
+ TSStatus status = client.dropModel(req);
+ if (!updateConfigNodeLeader(status)) {
+ return status;
+ }
+ } catch (TException e) {
+ configLeader = null;
+ }
+ waitAndReconnect();
+ }
+ throw new TException(MSG_RECONNECTION_FAIL);
}
@Override
public TShowModelResp showModel(TShowModelReq req) throws TException {
- // TODO
- throw new TException(new UnsupportedOperationException().getCause());
+ for (int i = 0; i < RETRY_NUM; i++) {
+ try {
+ TShowModelResp resp = client.showModel(req);
+ if (!updateConfigNodeLeader(resp.getStatus())) {
+ return resp;
+ }
+ } catch (TException e) {
+ configLeader = null;
+ }
+ waitAndReconnect();
+ }
+ throw new TException(MSG_RECONNECTION_FAIL);
}
@Override
public TShowTrailResp showTrail(TShowTrailReq req) throws TException {
- // TODO
- throw new TException(new UnsupportedOperationException().getCause());
+ for (int i = 0; i < RETRY_NUM; i++) {
+ try {
+ TShowTrailResp resp = client.showTrail(req);
+ if (!updateConfigNodeLeader(resp.getStatus())) {
+ return resp;
+ }
+ } catch (TException e) {
+ configLeader = null;
+ }
+ waitAndReconnect();
+ }
+ throw new TException(MSG_RECONNECTION_FAIL);
}
@Override
public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
- // TODO
throw new TException(new UnsupportedOperationException().getCause());
}
@Override
public TSStatus updateModelState(TUpdateModelStateReq req) throws TException {
- // TODO
throw new TException(new UnsupportedOperationException().getCause());
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
index baaf9a773c..88a22dd1ea 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/ColumnHeaderConstant.java
@@ -153,6 +153,15 @@ public class ColumnHeaderConstant {
public static final String ELAPSED_TIME = "ElapsedTime";
public static final String STATEMENT = "Statement";
+ // column names for show models/trails
+ public static final String MODEL_ID = "ModelId";
+ public static final String TRAIL_ID = "TrailId";
+ public static final String MODEL_TASK = "ModelTask";
+ public static final String MODEL_TYPE = "ModelType";
+ public static final String QUERY_BODY = "QueryBody";
+ public static final String HYPERPARAMETER = "Hyperparameter";
+ public static final String MODEL_PATH = "ModelPath";
+
public static final List<ColumnHeader> lastQueryColumnHeaders =
ImmutableList.of(
new ColumnHeader(TIMESERIES, TSDataType.TEXT),
@@ -381,4 +390,20 @@ public class ColumnHeaderConstant {
new ColumnHeader(DATA_NODE_ID, TSDataType.INT32),
new ColumnHeader(ELAPSED_TIME, TSDataType.FLOAT),
new ColumnHeader(STATEMENT, TSDataType.TEXT));
+
+ public static final List<ColumnHeader> showModelsColumnHeaders =
+ ImmutableList.of(
+ new ColumnHeader(MODEL_ID, TSDataType.TEXT),
+ new ColumnHeader(MODEL_TASK, TSDataType.TEXT),
+ new ColumnHeader(MODEL_TYPE, TSDataType.TEXT),
+ new ColumnHeader(QUERY_BODY, TSDataType.TEXT),
+ new ColumnHeader(STATE, TSDataType.TEXT),
+ new ColumnHeader(HYPERPARAMETER, TSDataType.TEXT),
+ new ColumnHeader(MODEL_PATH, TSDataType.TEXT));
+
+ public static final List<ColumnHeader> showTrailsColumnHeaders =
+ ImmutableList.of(
+ new ColumnHeader(TRAIL_ID, TSDataType.TEXT),
+ new ColumnHeader(HYPERPARAMETER, TSDataType.TEXT),
+ new ColumnHeader(MODEL_PATH, TSDataType.TEXT));
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
index 7e8eb19b05..ff0bb03b9b 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/common/header/DatasetHeaderFactory.java
@@ -164,4 +164,12 @@ public class DatasetHeaderFactory {
public static DatasetHeader getShowQueriesHeader() {
return new DatasetHeader(ColumnHeaderConstant.showQueriesColumnHeaders, false);
}
+
+ public static DatasetHeader getShowModelsHeader() {
+ return new DatasetHeader(ColumnHeaderConstant.showModelsColumnHeaders, false);
+ }
+
+ public static DatasetHeader getShowTrailsHeader() {
+ return new DatasetHeader(ColumnHeaderConstant.showTrailsColumnHeaders, false);
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 8730967fd0..c5c8ea7e45 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -1658,16 +1658,19 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor {
for (Expression expression : analysis.getSelectExpressions()) {
queryExpressions.add(expression.toString());
}
- String queryFilter = analysis.getWhereExpression().toString();
+ Expression whereExpression = analysis.getWhereExpression();
+ String queryFilter = whereExpression == null ? null : whereExpression.toString();
SettableFuture<ConfigTaskResult> future = SettableFuture.create();
try (ConfigNodeClient client =
CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.configNodeRegionId)) {
TCreateModelReq createModelReq = new TCreateModelReq();
createModelReq.setModelId(createModelStatement.getModelId());
+ createModelReq.setModelTask(createModelStatement.getModelTask());
+ createModelReq.setModelType(createModelStatement.getModelType());
+ createModelReq.setIsAuto(createModelStatement.isAuto());
createModelReq.setQueryExpressions(queryExpressions);
createModelReq.setQueryFilter(queryFilter);
- createModelReq.setIsAuto(createModelStatement.isAuto());
createModelReq.setModelConfigs(createModelStatement.getAttributes());
final TSStatus executionStatus = client.createModel(createModelReq);
if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
index 7f7a3d1db1..7fd719e0fd 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowModelsTask.java
@@ -19,15 +19,25 @@
package org.apache.iotdb.db.mpp.plan.execution.config.metadata.model;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeader;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeader;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeaderFactory;
import org.apache.iotdb.db.mpp.plan.execution.config.ConfigTaskResult;
import org.apache.iotdb.db.mpp.plan.execution.config.IConfigTask;
import org.apache.iotdb.db.mpp.plan.execution.config.executor.IConfigTaskExecutor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.utils.Binary;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import java.nio.ByteBuffer;
import java.util.List;
+import java.util.stream.Collectors;
public class ShowModelsTask implements IConfigTask {
@@ -40,5 +50,38 @@ public class ShowModelsTask implements IConfigTask {
}
public static void buildTsBlock(
- List<ByteBuffer> modelInfoList, SettableFuture<ConfigTaskResult> future) {}
+ List<ByteBuffer> modelInfoList, SettableFuture<ConfigTaskResult> future) {
+ List<TSDataType> outputDataTypes =
+ ColumnHeaderConstant.showModelsColumnHeaders.stream()
+ .map(ColumnHeader::getColumnType)
+ .collect(Collectors.toList());
+ TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
+ for (ByteBuffer modelInfo : modelInfoList) {
+ builder.getTimeColumnBuilder().writeLong(0L);
+ builder
+ .getColumnBuilder(0)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(1)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(2)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(3)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(4)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(5)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder
+ .getColumnBuilder(6)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(modelInfo)));
+ builder.declarePosition();
+ }
+ DatasetHeader datasetHeader = DatasetHeaderFactory.getShowModelsHeader();
+ future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
index 948a8e91fb..a428c27794 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/execution/config/metadata/model/ShowTrailsTask.java
@@ -19,15 +19,25 @@
package org.apache.iotdb.db.mpp.plan.execution.config.metadata.model;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeader;
+import org.apache.iotdb.db.mpp.common.header.ColumnHeaderConstant;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeader;
+import org.apache.iotdb.db.mpp.common.header.DatasetHeaderFactory;
import org.apache.iotdb.db.mpp.plan.execution.config.ConfigTaskResult;
import org.apache.iotdb.db.mpp.plan.execution.config.IConfigTask;
import org.apache.iotdb.db.mpp.plan.execution.config.executor.IConfigTaskExecutor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;
+import org.apache.iotdb.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.iotdb.tsfile.utils.Binary;
+import org.apache.iotdb.tsfile.utils.ReadWriteIOUtils;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import java.nio.ByteBuffer;
import java.util.List;
+import java.util.stream.Collectors;
public class ShowTrailsTask implements IConfigTask {
@@ -44,5 +54,26 @@ public class ShowTrailsTask implements IConfigTask {
}
public static void buildTsBlock(
- List<ByteBuffer> trailInfoList, SettableFuture<ConfigTaskResult> future) {}
+ List<ByteBuffer> trailInfoList, SettableFuture<ConfigTaskResult> future) {
+ List<TSDataType> outputDataTypes =
+ ColumnHeaderConstant.showTrailsColumnHeaders.stream()
+ .map(ColumnHeader::getColumnType)
+ .collect(Collectors.toList());
+ TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
+ for (ByteBuffer trailInfo : trailInfoList) {
+ builder.getTimeColumnBuilder().writeLong(0L);
+ builder
+ .getColumnBuilder(0)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(trailInfo)));
+ builder
+ .getColumnBuilder(1)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(trailInfo)));
+ builder
+ .getColumnBuilder(2)
+ .writeBinary(Binary.valueOf(ReadWriteIOUtils.readString(trailInfo)));
+ builder.declarePosition();
+ }
+ DatasetHeader datasetHeader = DatasetHeaderFactory.getShowTrailsHeader();
+ future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
index 510b6a808d..9d5ff90e58 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/parser/ASTVisitor.java
@@ -923,7 +923,8 @@ public class ASTVisitor extends IoTDBSqlParserBaseVisitor<Statement> {
Map<String, String> attributes = new HashMap<>();
for (IoTDBSqlParser.AttributePairContext attribute : ctx.attributePair()) {
- attributes.put(parseAttributeKey(attribute.key), parseAttributeValue(attribute.value));
+ attributes.put(
+ parseAttributeKey(attribute.key).toLowerCase(), parseAttributeValue(attribute.value));
}
createModelStatement.setAttributes(attributes);
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
index 11b552c903..0eecf66558 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/plan/statement/metadata/model/CreateModelStatement.java
@@ -19,7 +19,9 @@
package org.apache.iotdb.db.mpp.plan.statement.metadata.model;
+import org.apache.iotdb.common.rpc.thrift.ModelTask;
import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.mpp.plan.analyze.QueryType;
import org.apache.iotdb.db.mpp.plan.statement.IConfigStatement;
import org.apache.iotdb.db.mpp.plan.statement.Statement;
@@ -71,7 +73,22 @@ public class CreateModelStatement extends Statement implements IConfigStatement
this.queryStatement = queryStatement;
}
- public void semanticCheck() {}
+ public ModelTask getModelTask() {
+ return ModelTask.valueOf(attributes.get("model_task"));
+ }
+
+ public String getModelType() {
+ return attributes.get("model_type");
+ }
+
+ public void semanticCheck() {
+ if (!attributes.containsKey("model_task")) {
+ throw new SemanticException("The attribute `model_task` must be specified.");
+ }
+ if (!attributes.containsKey("model_type")) {
+ throw new SemanticException("The attribute `model_type` must be specified.");
+ }
+ }
@Override
public List<? extends PartialPath> getPaths() {