You are viewing a plain-text version of this message; the link to its canonical version was not preserved in this rendering.
Posted to commits@iotdb.apache.org by ro...@apache.org on 2023/03/31 18:42:47 UTC
[iotdb] branch master updated: [IOTDB-5373] Implement PipeSubTask and PipeExecutor (#9480)
This is an automated email from the ASF dual-hosted git repository.
rong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new c49e307114 [IOTDB-5373] Implement PipeSubTask and PipeExecutor (#9480)
c49e307114 is described below
commit c49e3071140ff8088e834956050b00091bcec486
Author: Itami Sho <42...@users.noreply.github.com>
AuthorDate: Sat Apr 1 02:42:38 2023 +0800
[IOTDB-5373] Implement PipeSubTask and PipeExecutor (#9480)
Co-authored-by: Steve Yurong Su <ro...@apache.org>
---
mlnode/iotdb/mlnode/algorithm/__init__.py | 17 --
mlnode/iotdb/mlnode/algorithm/enums.py | 29 ---
mlnode/iotdb/mlnode/algorithm/factory.py | 128 --------------
mlnode/iotdb/mlnode/algorithm/metric.py | 69 --------
mlnode/iotdb/mlnode/algorithm/models/__init__.py | 17 --
.../mlnode/algorithm/models/forecast/__init__.py | 20 ---
.../mlnode/algorithm/models/forecast/dlinear.py | 161 -----------------
.../mlnode/algorithm/models/forecast/nbeats.py | 169 ------------------
mlnode/iotdb/mlnode/client.py | 23 ++-
mlnode/iotdb/mlnode/constant.py | 10 --
mlnode/iotdb/mlnode/data_access/__init__.py | 17 --
mlnode/iotdb/mlnode/data_access/enums.py | 29 ---
mlnode/iotdb/mlnode/data_access/factory.py | 105 -----------
.../iotdb/mlnode/data_access/offline/__init__.py | 17 --
mlnode/iotdb/mlnode/data_access/offline/dataset.py | 96 ----------
mlnode/iotdb/mlnode/data_access/offline/source.py | 97 -----------
mlnode/iotdb/mlnode/data_access/utils/__init__.py | 17 --
.../iotdb/mlnode/data_access/utils/timefeatures.py | 171 ------------------
mlnode/iotdb/mlnode/exception.py | 15 --
mlnode/iotdb/mlnode/handler.py | 44 ++---
.../iotdb/mlnode/{storage.py => model_storage.py} | 0
mlnode/iotdb/mlnode/parser.py | 194 ---------------------
mlnode/iotdb/mlnode/serde.py | 30 +---
mlnode/iotdb/mlnode/util.py | 19 +-
mlnode/test/test_model_storage.py | 2 +-
mlnode/test/test_parse_training_request.py | 136 ---------------
.../resources/conf/iotdb-common.properties | 21 ++-
.../iotdb/commons/concurrent/ThreadName.java | 7 +-
.../java/org/apache/iotdb/db/conf/IoTDBConfig.java | 11 ++
.../org/apache/iotdb/db/conf/IoTDBDescriptor.java | 9 +
.../db/metadata/mtree/MTreeBelowSGCachedImpl.java | 1 -
.../db/metadata/mtree/MTreeBelowSGMemoryImpl.java | 1 -
.../db/pipe/agent/runtime/PipeRuntimeAgent.java | 15 ++
.../PipeConnectorPluginRuntimeWrapper.java | 44 ++++-
.../PipeProcessorPluginRuntimeWrapper.java | 48 ++++-
.../executor/PipeAssignerSubtaskExecutor.java | 12 +-
.../executor/PipeConnectorSubtaskExecutor.java | 12 +-
.../executor/PipeProcessorSubtaskExecutor.java | 12 +-
.../execution/executor/PipeSubtaskExecutor.java | 122 ++++++++++++-
...kExecutor.java => PipeTaskExecutorManager.java} | 40 +++--
.../scheduler/PipeAssignerSubtaskScheduler.java | 36 ----
.../scheduler/PipeConnectorSubtaskScheduler.java | 36 ----
.../scheduler/PipeProcessorSubtaskScheduler.java | 36 ----
.../execution/scheduler/PipeSubtaskScheduler.java | 33 ----
.../execution/scheduler/PipeTaskScheduler.java | 44 +++--
.../org/apache/iotdb/db/pipe/task/PipeTask.java | 31 +++-
.../DecoratingLock.java} | 26 ++-
.../PipeAssignerSubtask.java | 6 +-
.../PipeConnectorSubtask.java | 13 +-
.../PipeProcessorSubtask.java | 13 +-
.../iotdb/db/pipe/task/callable/PipeSubtask.java | 135 ++++++++++++++
.../pipe/task/metrics/PipeTaskRuntimeRecorder.java | 22 ---
.../db/pipe/task/stage/PipeTaskCollectorStage.java | 20 +--
.../db/pipe/task/stage/PipeTaskConnectorStage.java | 20 +--
.../db/pipe/task/stage/PipeTaskProcessorStage.java | 20 +--
.../iotdb/db/pipe/task/stage/PipeTaskStage.java | 37 +++-
.../executor/PipeAssignerSubtaskExecutorTest.java} | 20 ++-
.../PipeConnectorSubtaskExecutorTest.java} | 24 ++-
.../PipeProcessorSubtaskExecutorTest.java} | 24 ++-
.../executor/PipeSubtaskExecutorTest.java | 158 +++++++++++++++++
60 files changed, 871 insertions(+), 1870 deletions(-)
diff --git a/mlnode/iotdb/mlnode/algorithm/__init__.py b/mlnode/iotdb/mlnode/algorithm/__init__.py
deleted file mode 100644
index 2a1e720805..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/mlnode/iotdb/mlnode/algorithm/enums.py b/mlnode/iotdb/mlnode/algorithm/enums.py
deleted file mode 100644
index 4b05aa4bf8..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/enums.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from enum import Enum
-
-
-class ForecastTaskType(Enum):
- ENDOGENOUS = "endogenous"
- EXOGENOUS = "exogenous"
-
- def __str__(self):
- return self.value
-
- def __eq__(self, other: str) -> bool:
- return self.value == other
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py b/mlnode/iotdb/mlnode/algorithm/factory.py
deleted file mode 100644
index 92cb01a883..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/factory.py
+++ /dev/null
@@ -1,128 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import torch.nn as nn
-
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
-from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
-from iotdb.mlnode.exception import BadConfigValueError
-
-
-# Common configs for all forecasting model with default values
-def _common_config(**kwargs):
- return {
- 'input_len': 96,
- 'pred_len': 96,
- 'input_vars': 1,
- 'output_vars': 1,
- **kwargs
- }
-
-
-# Common forecasting task configs
-support_common_configs = {
- # multivariate forecasting, current support this only
- ForecastTaskType.ENDOGENOUS: _common_config(
- input_vars=1,
- output_vars=1),
-
- # univariate forecasting with observable exogenous variables
- ForecastTaskType.EXOGENOUS: _common_config(
- output_vars=1),
-}
-
-
-def is_model(model_name: str) -> bool:
- """
- Check if a model name exists
- """
- return model_name in support_forecasting_models
-
-
-def list_model() -> list[str]:
- """
- List support forecasting model
- """
- return support_forecasting_models
-
-
-def create_forecast_model(
- model_name,
- forecast_task_type=ForecastTaskType.ENDOGENOUS,
- input_len=96,
- pred_len=96,
- input_vars=1,
- output_vars=1,
- **kwargs,
-) -> [nn.Module, dict]:
- """
- Factory method for all support forecasting models
- the given arguments is common configs shared by all forecasting models
- for specific model configs, see __model_config in `algorithm/models/MODELNAME.py`
-
- Args:
- model_name: see available models by `list_model`
- forecast_task_type: 'm' for multivariate forecasting, 'ms' for covariate forecasting,
- 's' for univariate forecasting
- input_len: time length of model input
- pred_len: time length of model output
- input_vars: number of input series
- output_vars: number of output series
- kwargs: for specific model configs, see returned `model_config` with kwargs=None
-
- Returns:
- model: torch.nn.Module
- model_config: dict of model configurations
- """
- if not is_model(model_name):
- raise BadConfigValueError('model_name', model_name, f'It should be one of {list_model()}')
- if forecast_task_type not in support_common_configs.keys():
- raise BadConfigValueError('forecast_task_type', forecast_task_type,
- f'It should be one of {list(support_common_configs.keys())}')
-
- common_config = support_common_configs[forecast_task_type]
- common_config['input_len'] = input_len
- common_config['pred_len'] = pred_len
- common_config['input_vars'] = input_vars
- common_config['output_vars'] = output_vars
- common_config['forecast_task_type'] = str(forecast_task_type)
-
- if not input_len > 0:
- raise BadConfigValueError('input_len', input_len,
- 'Length of input series should be positive')
- if not pred_len > 0:
- raise BadConfigValueError('pred_len', pred_len,
- 'Length of predicted series should be positive')
- if not input_vars > 0:
- raise BadConfigValueError('input_vars', input_vars,
- 'Number of input variates should be positive')
- if not output_vars > 0:
- raise BadConfigValueError('output_vars', output_vars,
- 'Number of output variates should be positive')
- if forecast_task_type == ForecastTaskType.ENDOGENOUS:
- if input_vars != output_vars:
- raise BadConfigValueError('forecast_task_type', forecast_task_type,
- 'Number of input/output variates should be '
- 'the same in multivariate forecast')
- create_fn = eval(model_name)
- model, model_config = create_fn(
- common_config=common_config,
- **kwargs
- )
- model_config['model_name'] = model_name
-
- return model, model_config
diff --git a/mlnode/iotdb/mlnode/algorithm/metric.py b/mlnode/iotdb/mlnode/algorithm/metric.py
deleted file mode 100644
index e623642191..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/metric.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from abc import abstractmethod
-
-import numpy as np
-
-all_metrics = ['RSE', 'CORR', 'MAE', 'MSE', 'RMSE', 'MAPE', 'MSPE']
-
-
-class Metric(object):
- def __call__(self, pred, ground_truth):
- return self.calculate(pred, ground_truth)
-
- @abstractmethod
- def calculate(self, pred, ground_truth):
- pass
-
-
-class RSE(Metric):
- def calculate(self, pred, ground_truth):
- return np.sqrt(np.sum((ground_truth - pred) ** 2)) / np.sqrt(np.sum((ground_truth - ground_truth.mean()) ** 2))
-
-
-class CORR(Metric):
- def calculate(self, pred, ground_truth):
- u = ((ground_truth - ground_truth.mean(0)) * (pred - pred.mean(0))).sum(0)
- d = np.sqrt(((ground_truth - ground_truth.mean(0)) ** 2 * (pred - pred.mean(0)) ** 2).sum(0))
- return (u / d).mean(-1)
-
-
-class MAE(Metric):
- def calculate(self, pred, ground_truth):
- return np.mean(np.abs(pred - ground_truth))
-
-
-class MSE(Metric):
- def calculate(self, pred, ground_truth):
- return np.mean((pred - ground_truth) ** 2)
-
-
-class RMSE(Metric):
- def calculate(self, pred, ground_truth):
- mse = MSE()
- return np.sqrt(mse(pred, ground_truth))
-
-
-class MAPE(Metric):
- def calculate(self, pred, ground_truth):
- return np.mean(np.abs((pred - ground_truth) / ground_truth))
-
-
-class MSPE(Metric):
- def calculate(self, pred, ground_truth):
- return np.mean(np.square((pred - ground_truth) / ground_truth))
diff --git a/mlnode/iotdb/mlnode/algorithm/models/__init__.py b/mlnode/iotdb/mlnode/algorithm/models/__init__.py
deleted file mode 100644
index 2a1e720805..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/models/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
deleted file mode 100644
index 2abb5faf37..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-
-support_forecasting_models = ['dlinear', 'dlinear_individual', 'nbeats']
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
deleted file mode 100644
index fa9ee04e56..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ /dev/null
@@ -1,161 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import math
-
-import torch
-import torch.nn as nn
-
-from iotdb.mlnode.exception import BadConfigValueError
-
-
-class MovingAverageBlock(nn.Module):
- """ Moving average block to highlight the trend of time series """
-
- def __init__(self, kernel_size, stride):
- super(MovingAverageBlock, self).__init__()
- self.kernel_size = kernel_size
- self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)
-
- def forward(self, x):
- # padding on the both ends of time series
- front = x[:, 0:1, :].repeat(1, self.kernel_size - 1 - math.floor((self.kernel_size - 1) // 2), 1)
- end = x[:, -1:, :].repeat(1, math.floor((self.kernel_size - 1) // 2), 1)
- x = torch.cat([front, x, end], dim=1)
- x = self.avg(x.permute(0, 2, 1))
- x = x.permute(0, 2, 1)
- return x
-
-
-class SeriesDecompositionBlock(nn.Module):
- """ Series decomposition block """
-
- def __init__(self, kernel_size):
- super(SeriesDecompositionBlock, self).__init__()
- self.moving_avg = MovingAverageBlock(kernel_size, stride=1)
-
- def forward(self, x):
- moving_mean = self.moving_avg(x)
- res = x - moving_mean
- return res, moving_mean
-
-
-class DLinear(nn.Module):
- """ Decomposition Linear Model """
-
- def __init__(
- self,
- kernel_size=25,
- input_len=96,
- pred_len=96,
- input_vars=1,
- output_vars=1,
- forecast_type='m', # TODO, support others
- ):
- super(DLinear, self).__init__()
- self.input_len = input_len
- self.pred_len = pred_len
- self.kernel_size = kernel_size
- self.channels = input_vars
-
- # decomposition Kernel Size
- self.decomposition = SeriesDecompositionBlock(kernel_size)
- self.linear_seasonal = nn.Linear(self.input_len, self.pred_len)
- self.linear_trend = nn.Linear(self.input_len, self.pred_len)
-
- def forward(self, x, *args):
- # x: [Batch, Input length, Channel]
- seasonal_init, trend_init = self.decomposition(x)
- seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
-
- seasonal_output = self.linear_seasonal(seasonal_init)
- trend_output = self.linear_trend(trend_init)
-
- x = seasonal_output + trend_output
- return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
-
-
-class DLinearIndividual(nn.Module):
- """ Decomposition Linear Model (individual) """
-
- def __init__(
- self,
- kernel_size=25,
- input_len=96,
- pred_len=96,
- input_vars=1,
- output_vars=1,
- forecast_type='m', # TODO, support others
- ):
- super(DLinearIndividual, self).__init__()
- self.input_len = input_len
- self.pred_len = pred_len
- self.kernel_size = kernel_size
- self.channels = input_vars
-
- self.decomposition = SeriesDecompositionBlock(kernel_size)
- self.Linear_Seasonal = nn.ModuleList(
- [nn.Linear(self.input_len, self.pred_len) for _ in range(self.channels)]
- )
- self.Linear_Trend = nn.ModuleList(
- [nn.Linear(self.input_len, self.pred_len) for _ in range(self.channels)]
- )
-
- def forward(self, x, *args):
- # x: [Batch, Input length, Channel]
- seasonal_init, trend_init = self.decomposition(x)
- seasonal_init, trend_init = seasonal_init.permute(0, 2, 1), trend_init.permute(0, 2, 1)
-
- seasonal_output = torch.zeros([seasonal_init.size(0), seasonal_init.size(1), self.pred_len],
- dtype=seasonal_init.dtype).to(seasonal_init.device)
- trend_output = torch.zeros([trend_init.size(0), trend_init.size(1), self.pred_len],
- dtype=trend_init.dtype).to(trend_init.device)
- for i, linear_season_layer in enumerate(self.Linear_Seasonal):
- seasonal_output[:, i, :] = linear_season_layer(seasonal_init[:, i, :])
- for i, linear_trend_layer in enumerate(self.Linear_Trend):
- trend_output[:, i, :] = linear_trend_layer(trend_init[:, i, :])
-
- x = seasonal_output + trend_output
- return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
-
-
-def _model_config(**kwargs):
- return {
- 'kernel_size': 25,
- **kwargs
- }
-
-
-def dlinear(common_config: dict, kernel_size=25, **kwargs) -> [DLinear, dict]:
- config = _model_config()
- config.update(**common_config)
- if not kernel_size > 0:
- raise BadConfigValueError('kernel_size', kernel_size,
- 'Kernel size of dlinear should larger than 0')
- config['kernel_size'] = kernel_size
- return DLinear(**config), config
-
-
-def dlinear_individual(common_config: dict, kernel_size=25, **kwargs) -> [DLinearIndividual, dict]:
- config = _model_config()
- config.update(**common_config)
- if not kernel_size > 0:
- raise BadConfigValueError('kernel_size', kernel_size,
- 'Kernel size of dlinear_individual should larger than 0')
- config['kernel_size'] = kernel_size
- return DLinearIndividual(**config), config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
deleted file mode 100644
index e3c3ca6a0a..0000000000
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from iotdb.mlnode.exception import BadConfigValueError
-
-
-class GenericBasis(nn.Module):
- """ Generic basis function """
-
- def __init__(self, backcast_size: int, forecast_size: int):
- super().__init__()
- self.backcast_size = backcast_size
- self.forecast_size = forecast_size
-
- def forward(self, theta: torch.Tensor):
- return theta[:, :self.backcast_size], theta[:, -self.forecast_size:]
-
-
-block_dict = {
- 'generic': GenericBasis,
-}
-
-
-class NBeatsBlock(nn.Module):
- """ N-BEATS block which takes a basis function as an argument """
-
- def __init__(self,
- input_size,
- theta_size: int,
- basis_function: nn.Module,
- layers: int,
- layer_size: int):
- """
- N-BEATS block
-
- Args:
- input_size: input sample size
- theta_size: number of parameters for the basis function
- basis_function: basis function which takes the parameters and produces backcast and forecast
- layers: number of layers
- layer_size: layer size
- """
- super().__init__()
- self.layers = nn.ModuleList([nn.Linear(in_features=input_size, out_features=layer_size)] + [
- nn.Linear(in_features=layer_size, out_features=layer_size) for _ in range(layers - 1)])
- self.basis_parameters = nn.Linear(in_features=layer_size, out_features=theta_size)
- self.basis_function = basis_function
-
- def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- block_input = x
- for layer in self.layers:
- block_input = torch.relu(layer(block_input))
- basis_parameters = self.basis_parameters(block_input)
- return self.basis_function(basis_parameters)
-
-
-class NBeatsUnivariate(nn.Module):
- """ N-Beats Model (univariate) """
-
- def __init__(self, blocks: nn.ModuleList):
- super().__init__()
- self.blocks = blocks
-
- def forward(self, x):
- residuals = x
- forecast = None
- for _, block in enumerate(self.blocks):
- backcast, block_forecast = block(residuals)
- residuals = (residuals - backcast)
- if forecast is None:
- forecast = block_forecast
- else:
- forecast += block_forecast
- return forecast
-
-
-class NBeats(nn.Module):
- """ Neural Basis Expansion Analysis Time Series """
-
- def __init__(
- self,
- block_type='generic',
- d_model=128,
- inner_layers=4,
- outer_layers=4,
- input_len=96,
- pred_len=96,
- input_vars=1,
- output_vars=1,
- forecast_type='m', # TODO, support others
- ):
- super(NBeats, self).__init__()
- self.enc_in = input_vars
- self.block = block_dict[block_type]
- self.model = NBeatsUnivariate(
- torch.nn.ModuleList(
- [NBeatsBlock(input_size=input_len,
- theta_size=input_len + pred_len,
- basis_function=self.block(backcast_size=input_len, forecast_size=pred_len),
- layers=inner_layers,
- layer_size=d_model)
- for _ in range(outer_layers)]
- )
- )
-
- def forward(self, x, *args):
- # x: [Batch, Input length, Channel]
- res = []
- for i in range(self.enc_in):
- dec_out = self.model(x[:, :, i])
- res.append(dec_out)
- return torch.stack(res, dim=-1) # to [Batch, Output length, Channel]
-
-
-def _model_config(**kwargs):
- return {
- 'block_type': 'generic',
- 'd_model': 128,
- 'inner_layers': 4,
- 'outer_layers': 4,
- **kwargs
- }
-
-
-"""
-Specific configs for NBeats variants
-"""
-support_model_configs = {
- 'nbeats': _model_config(
- block_type='generic'),
-}
-
-
-def nbeats(common_config: dict, d_model=128, inner_layers=4, outer_layers=4, **kwargs) -> [NBeats, dict]:
- config = _model_config()
- config.update(**common_config)
- if not d_model > 0:
- raise BadConfigValueError('d_model', d_model,
- 'Model dimension (d_model) of nbeats should larger than 0')
- if not inner_layers > 0:
- raise BadConfigValueError('inner_layers', inner_layers,
- 'Number of inner layers of nbeats should larger than 0')
- if not outer_layers > 0:
- raise BadConfigValueError('outer_layers', outer_layers,
- 'Number of outer layers of nbeats should larger than 0')
- config['d_model'] = d_model
- config['inner_layers'] = inner_layers
- config['outer_layers'] = outer_layers
- return NBeats(**config), config
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index 76eb754596..244b6975c9 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -22,9 +22,7 @@ from thrift.Thrift import TException
from thrift.transport import TSocket, TTransport
from iotdb.mlnode.config import config
-from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.log import logger
-from iotdb.mlnode.util import verify_success
from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
from iotdb.thrift.confignode import IConfigNodeRPCService
from iotdb.thrift.confignode.ttypes import TUpdateModelInfoReq
@@ -35,6 +33,16 @@ from iotdb.thrift.datanode.ttypes import (TFetchTimeseriesReq,
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq, TDeleteModelReq
+# status code
+SUCCESS_STATUS = 200
+REDIRECTION_RECOMMEND = 400
+
+
+def verify_success(status: TSStatus, err_msg: str) -> None:
+ if status.code != SUCCESS_STATUS:
+ logger.warn(err_msg + ", error status is ", status)
+ raise RuntimeError(str(status.code) + ": " + status.message)
+
class ClientManager(object):
def __init__(self):
@@ -70,7 +78,7 @@ class MLNodeClient(object):
model_id: str,
is_auto: bool,
model_configs: dict,
- query_expressions: list = [],
+ query_expressions: list[str],
query_filter: str = None) -> None:
req = TCreateTrainingTaskReq(
modelId=model_id,
@@ -116,7 +124,6 @@ class DataNodeClient(object):
transport.open()
except TTransport.TTransportException as e:
logger.exception("TTransportException!", exc_info=e)
- raise e
protocol = TBinaryProtocol.TBinaryProtocol(transport)
self.__client = IDataNodeRPCService.Client(protocol)
@@ -124,7 +131,7 @@ class DataNodeClient(object):
def fetch_timeseries(self,
session_id: int,
statement_id: int,
- query_expressions: list = [],
+ query_expressions: list[str],
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
timeout: int = DEFAULT_TIMEOUT) -> TFetchTimeseriesResp:
@@ -146,8 +153,8 @@ class DataNodeClient(object):
def record_model_metrics(self,
model_id: str,
trial_id: str,
- metrics: list = [],
- values: list = []) -> None:
+ metrics: list[str],
+ values: list[float]) -> None:
req = TRecordModelMetricsReq(
modelId=model_id,
trialId=trial_id,
@@ -235,7 +242,7 @@ class ConfigNodeClient(object):
pass
def __update_config_node_leader(self, status: TSStatus) -> bool:
- if status.code == TSStatusCode.REDIRECTION_RECOMMEND:
+ if status.code == REDIRECTION_RECOMMEND:
if status.redirectNode is not None:
self.__config_leader = status.redirectNode
else:
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 3bffa06526..8a38aa95d8 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -15,19 +15,9 @@
# specific language governing permissions and limitations
# under the License.
#
-from enum import Enum
MLNODE_CONF_DIRECTORY_NAME = "conf"
MLNODE_CONF_FILE_NAME = "iotdb-mlnode.toml"
MLNODE_LOG_CONF_FILE_NAME = "logging_config.ini"
MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
-
-
-class TSStatusCode(Enum):
- SUCCESS_STATUS = 200
- REDIRECTION_RECOMMEND = 400
- FAIL_STATUS = 404
-
- def get_status_code(self) -> int:
- return self.value
diff --git a/mlnode/iotdb/mlnode/data_access/__init__.py b/mlnode/iotdb/mlnode/data_access/__init__.py
deleted file mode 100644
index 2a1e720805..0000000000
--- a/mlnode/iotdb/mlnode/data_access/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/mlnode/iotdb/mlnode/data_access/enums.py b/mlnode/iotdb/mlnode/data_access/enums.py
deleted file mode 100644
index d21a9f69c4..0000000000
--- a/mlnode/iotdb/mlnode/data_access/enums.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from enum import Enum
-
-
-class DatasetType(Enum):
- TIMESERIES = "timeseries"
- WINDOW = "window"
-
- def __str__(self):
- return self.value
-
- def __eq__(self, other: str) -> bool:
- return self.value == other
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py b/mlnode/iotdb/mlnode/data_access/factory.py
deleted file mode 100644
index d0041388a6..0000000000
--- a/mlnode/iotdb/mlnode/data_access/factory.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from torch.utils.data import Dataset
-
-from iotdb.mlnode.data_access.enums import DatasetType
-from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
- WindowDataset)
-from iotdb.mlnode.data_access.offline.source import (FileDataSource,
- ThriftDataSource)
-from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
-
-support_forecasting_dataset = {
- DatasetType.TIMESERIES: TimeSeriesDataset,
- DatasetType.WINDOW: WindowDataset
-}
-
-
-def _dataset_config(**kwargs):
- return {
- 'time_embed': 'h',
- **kwargs
- }
-
-
-support_dataset_configs = {
- DatasetType.TIMESERIES: _dataset_config(),
- DatasetType.WINDOW: _dataset_config(
- input_len=96,
- pred_len=96,
- )
-}
-
-
-def create_forecast_dataset(
- source_type,
- dataset_type,
- **kwargs,
-) -> [Dataset, dict]:
- """
- Factory method for all support dataset
- currently implement WindowDataset, TimeSeriesDataset
- for specific dataset configs, see _dataset_config in `algorithm/models/MODELNAME.py`
-
- Args:
- dataset_type: available choice in support_forecasting_dataset
- source_type: available choice in ['file', 'thrift']
- kwargs: for specific dataset configs, see returned `dataset_config` with kwargs=None
-
- Returns:
- dataset: torch.nn.Module
- dataset_config: dict of dataset configurations
- """
- if dataset_type not in support_forecasting_dataset.keys():
- raise BadConfigValueError('dataset_type', dataset_type,
- f'It should be one of {list(support_forecasting_dataset.keys())}')
-
- if source_type == 'file':
- if 'filename' not in kwargs.keys():
- raise MissingConfigError('filename')
- datasource = FileDataSource(kwargs['filename'])
- elif source_type == 'thrift':
- if 'query_expressions' not in kwargs.keys():
- raise MissingConfigError('query_expressions')
- if 'query_filter' not in kwargs.keys():
- raise MissingConfigError('query_filter')
- datasource = ThriftDataSource(kwargs['query_expressions'], kwargs['query_filter'])
- else:
- raise BadConfigValueError('source_type', source_type, "It should be one of ['file', 'thrift]")
-
- dataset_fn = support_forecasting_dataset[dataset_type]
- dataset_config = support_dataset_configs[dataset_type]
-
- for k, v in kwargs.items():
- if k in dataset_config.keys():
- dataset_config[k] = v
-
- dataset = dataset_fn(datasource, **dataset_config)
-
- if 'input_vars' in kwargs.keys() and dataset.get_variable_num() != kwargs['input_vars']:
- raise BadConfigValueError('input_vars', kwargs['input_vars'],
- f'Variable number of fetched data: ({dataset.get_variable_num()})'
- f' should be consistent with input_vars')
-
- data_config = dataset_config.copy()
- data_config['input_vars'] = dataset.get_variable_num()
- data_config['output_vars'] = dataset.get_variable_num()
- data_config['source_type'] = source_type
- data_config['dataset_type'] = dataset_type
-
- return dataset, data_config
diff --git a/mlnode/iotdb/mlnode/data_access/offline/__init__.py b/mlnode/iotdb/mlnode/data_access/offline/__init__.py
deleted file mode 100644
index 2a1e720805..0000000000
--- a/mlnode/iotdb/mlnode/data_access/offline/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/mlnode/iotdb/mlnode/data_access/offline/dataset.py b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
deleted file mode 100644
index 1a96e81a4a..0000000000
--- a/mlnode/iotdb/mlnode/data_access/offline/dataset.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from torch.utils.data import Dataset
-
-from iotdb.mlnode.data_access.offline.source import DataSource
-from iotdb.mlnode.data_access.utils.timefeatures import time_features
-
-
-class TimeSeriesDataset(Dataset):
- """
- Build Row-by-Row dataset (with each element as multivariable time series at
- the same time and correponding timestamp embedding)
-
- Args:
- data_source: the whole multivariate time series for a while
- time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
-
- Returns:
- Random accessible dataset
- """
-
- def __init__(self, data_source: DataSource, time_embed: str = 'h'):
- self.time_embed = time_embed
- self.data = data_source.get_data()
- self.data_stamp = time_features(data_source.get_timestamp(), time_embed=self.time_embed).transpose(1, 0)
- self.n_vars = self.data.shape[-1]
-
- def get_variable_num(self):
- return self.n_vars # number of series in data_source
-
- def __getitem__(self, index):
- seq = self.data[index]
- seq_t = self.data_stamp[index]
- return seq, seq_t
-
- def __len__(self):
- return len(self.data)
-
-
-class WindowDataset(TimeSeriesDataset):
- """
- Build Windowed dataset (with each element as multivariable time series
- with a sliding window and corresponding timestamps embedding),
- the sliding step is one unit in give data source
-
- Args:
- data_source: the whole multivariate time series for a while
- time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
- input_len: input window size (unit) [1, 2, ... I]
- pred_len: output window size (unit) right after the input window [I+1, I+2, ... I+P]
-
- Returns:
- Random accessible dataset
- """
-
- def __init__(self,
- data_source: DataSource = None,
- input_len: int = 96,
- pred_len: int = 96,
- time_embed: str = 'h'):
- self.input_len = input_len
- self.pred_len = pred_len
- super(WindowDataset, self).__init__(data_source, time_embed)
- if input_len > self.data.shape[0]:
- raise RuntimeError('input_len should not be larger than the number of time series points')
- if pred_len > self.data.shape[0]:
- raise RuntimeError('pred_len should not be larger than the number of time series points')
-
- def __getitem__(self, index):
- s_begin = index
- s_end = s_begin + self.input_len
- r_begin = s_end
- r_end = s_end + self.pred_len
- seq_x = self.data[s_begin:s_end]
- seq_y = self.data[r_begin:r_end]
- seq_x_t = self.data_stamp[s_begin:s_end]
- seq_y_t = self.data_stamp[r_begin:r_end]
- return seq_x, seq_y, seq_x_t, seq_y_t
-
- def __len__(self):
- return len(self.data) - self.input_len - self.pred_len + 1
diff --git a/mlnode/iotdb/mlnode/data_access/offline/source.py b/mlnode/iotdb/mlnode/data_access/offline/source.py
deleted file mode 100644
index a63371ec7a..0000000000
--- a/mlnode/iotdb/mlnode/data_access/offline/source.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import pandas as pd
-
-from iotdb.mlnode import serde
-from iotdb.mlnode.client import client_manager
-
-
-class DataSource(object):
- """
- Pre-fetched in multi-variate time series in memory
-
- Methods:
- get_data: returns self.data, the time series value (Numpy.2DArray)
- get_timestamp: returns self.timestamp, the aligned timestamp value
- """
-
- def __init__(self):
- self.data = None
- self.timestamp = None
- self._read_data()
-
- def _read_data(self):
- raise NotImplementedError
-
- def get_data(self):
- return self.data
-
- def get_timestamp(self):
- return self.timestamp
-
-
-class FileDataSource(DataSource):
- def __init__(self, filename: str = None):
- self.filename = filename
- super(FileDataSource, self).__init__()
-
- def _read_data(self):
- try:
- raw_data = pd.read_csv(self.filename)
- except Exception:
- raise RuntimeError(f'Fail to load data with filename: {self.filename}')
- cols_data = raw_data.columns[1:]
- self.data = raw_data[cols_data].values
- self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values)
-
-
-class ThriftDataSource(DataSource):
- def __init__(self, query_expressions: list = None, query_filter: str = None):
- self.query_expressions = query_expressions
- self.query_filter = query_filter
- super(ThriftDataSource, self).__init__()
-
- def _read_data(self):
- try:
- data_client = client_manager.borrow_data_node_client()
- except Exception:
- raise RuntimeError('Fail to establish connection with DataNode')
-
- try:
- res = data_client.fetch_timeseries(
- queryExpressions=self.query_expressions,
- queryFilter=self.query_filter,
- )
- except Exception:
- raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
- f' and query filter: {self.query_filter}')
-
- if len(res.tsDataset) == 0:
- raise RuntimeError(f'No data fetched with query filter: {self.query_filter}')
-
- raw_data = serde.convert_to_df(res.columnNameList,
- res.columnTypeList,
- res.columnNameIndexMap,
- res.tsDataset)
- if raw_data.empty:
- raise RuntimeError(f'Fetched empty data with query expressions: '
- f'{self.query_expressions} and query filter: {self.query_filter}')
- cols_data = raw_data.columns[1:]
- self.data = raw_data[cols_data].values
- self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values, unit='ms', utc=True) \
- .tz_convert('Asia/Shanghai') # for iotdb
diff --git a/mlnode/iotdb/mlnode/data_access/utils/__init__.py b/mlnode/iotdb/mlnode/data_access/utils/__init__.py
deleted file mode 100644
index 2a1e720805..0000000000
--- a/mlnode/iotdb/mlnode/data_access/utils/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
deleted file mode 100644
index ecd6784ca4..0000000000
--- a/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from typing import List
-
-import numpy as np
-import pandas as pd
-from pandas.tseries import offsets
-from pandas.tseries.frequencies import to_offset
-
-
-class TimeFeature:
- def __init__(self):
- pass
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- pass
-
- def __repr__(self):
- return self.__class__.__name__ + '()'
-
-
-class SecondOfMinute(TimeFeature):
- """Minute of hour encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return index.second / 59.0 - 0.5
-
-
-class MinuteOfHour(TimeFeature):
- """Minute of hour encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return index.minute / 59.0 - 0.5
-
-
-class HourOfDay(TimeFeature):
- """Hour of day encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return index.hour / 23.0 - 0.5
-
-
-class DayOfWeek(TimeFeature):
- """Hour of day encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return index.dayofweek / 6.0 - 0.5
-
-
-class DayOfMonth(TimeFeature):
- """Day of month encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return (index.day - 1) / 30.0 - 0.5
-
-
-class DayOfYear(TimeFeature):
- """Day of year encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return (index.dayofyear - 1) / 365.0 - 0.5
-
-
-class MonthOfYear(TimeFeature):
- """Month of year encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return (index.month - 1) / 11.0 - 0.5
-
-
-class WeekOfYear(TimeFeature):
- """Week of year encoded as value between [-0.5, 0.5]"""
-
- def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
- return (index.isocalendar().week - 1) / 52.0 - 0.5
-
-
-def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
- """
- Embedding timestamp by given frequency string
- Args:
- freq_str: frequency string of the form [multiple][granularity] such as '12H', '5min', '1D' etc.
- Returns:
- a list of time features that will be appropriate for the given frequency string.
- """
-
- features_by_offsets = {
- offsets.YearEnd: [],
- offsets.QuarterEnd: [MonthOfYear],
- offsets.MonthEnd: [MonthOfYear],
- offsets.Week: [DayOfMonth, WeekOfYear],
- offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
- offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
- offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
- offsets.Minute: [
- MinuteOfHour,
- HourOfDay,
- DayOfWeek,
- DayOfMonth,
- DayOfYear,
- ],
- offsets.Second: [
- SecondOfMinute,
- MinuteOfHour,
- HourOfDay,
- DayOfWeek,
- DayOfMonth,
- DayOfYear,
- ],
- }
-
- try:
- offset = to_offset(freq_str)
-
- for offset_type, feature_classes in features_by_offsets.items():
- if isinstance(offset, offset_type):
- return [cls() for cls in feature_classes]
- except ValueError:
- supported_freq_msg = f'''
- Unsupported time embedding frequency ({freq_str})
- The following frequencies are supported (case-insensitive):
- Y - yearly
- alias: A
- M - monthly
- W - weekly
- D - daily
- B - business days
- H - hourly
- T - minutely
- alias: min
- S - secondly
- '''
- raise RuntimeError(supported_freq_msg)
-
-
-def time_features(dates, time_embed='h'):
- return np.vstack([feat(dates) for feat in time_features_from_frequency_str(time_embed)])
-
-
-def data_transform(data_raw: pd.DataFrame, freq='h'):
- """
- data: dataframe, column 0 is the time stamp
- """
- columns = data_raw.columns
- data = data_raw[columns[1:]]
- data_stamp = data_raw[columns[0]]
- return data.values, data_stamp
-
-
-def timestamp_transform(timestamp_raw: pd.DataFrame, freq='h'):
- """
- """
- timestamp = pd.to_datetime(timestamp_raw.values.squeeze(), unit='ms', utc=True).tz_convert('Asia/Shanghai')
- timestamp = time_features(timestamp, freq=freq)
- timestamp = timestamp.transpose(1, 0)
- return timestamp
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index a7b211dbc2..6307909a9a 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -29,18 +29,3 @@ class BadNodeUrlError(_BaseError):
class ModelNotExistError(_BaseError):
def __init__(self, file_path: str):
self.message = "Model path: ({}) not exists".format(file_path)
-
-
-class BadConfigValueError(_BaseError):
- def __init__(self, config_name: str, config_value, hint: str = ''):
- self.message = "Bad value ({0}) for config: ({1}). {2}".format(config_value, config_name, hint)
-
-
-class MissingConfigError(_BaseError):
- def __init__(self, config_name: str):
- self.message = "Missing config: ({})".format(config_name)
-
-
-class WrongTypeConfigError(_BaseError):
- def __init__(self, config_name: str, expected_type: str):
- self.message = "Wrong type for config: ({0}), expected: ({1})".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index e7ff76cbe0..8a36353d47 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -15,19 +15,28 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
-from iotdb.mlnode.algorithm.factory import create_forecast_model
-from iotdb.mlnode.constant import TSStatusCode
-from iotdb.mlnode.data_access.factory import create_forecast_dataset
-from iotdb.mlnode.log import logger
-from iotdb.mlnode.parser import parse_training_request
-from iotdb.mlnode.util import get_status
+from iotdb.thrift.common.ttypes import TSStatus
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
TDeleteModelReq, TForecastReq,
TForecastResp)
+class TSStatusCode(Enum):
+ SUCCESS_STATUS = 200
+
+ def get_status_code(self) -> int:
+ return self.value
+
+
+def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
+ status = TSStatus(status_code.get_status_code())
+ status.message = message
+ return status
+
+
class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
def __init__(self):
pass
@@ -36,28 +45,7 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
return get_status(TSStatusCode.SUCCESS_STATUS, "")
def createTrainingTask(self, req: TCreateTrainingTaskReq):
- # parse request stage (check required config and config type)
- data_config, model_config, task_config = parse_training_request(req)
-
- # create model stage (check model config legitimacy)
- try:
- model, model_config = create_forecast_model(**model_config)
- except Exception as e: # Create model failed
- return get_status(TSStatusCode.FAIL_STATUS, str(e))
- logger.info('model config: ' + str(model_config))
-
- # create data stage (check data config legitimacy)
- try:
- dataset, data_config = create_forecast_dataset(**data_config)
- except Exception as e: # Create data failed
- return get_status(TSStatusCode.FAIL_STATUS, str(e))
- logger.info('data config: ' + str(data_config))
-
- # create task stage (check task config legitimacy)
-
- # submit task stage (check resource and decide pending/start)
-
- return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create training task')
+ return get_status(TSStatusCode.SUCCESS_STATUS, "")
def forecast(self, req: TForecastReq):
status = get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/storage.py b/mlnode/iotdb/mlnode/model_storage.py
similarity index 100%
rename from mlnode/iotdb/mlnode/storage.py
rename to mlnode/iotdb/mlnode/model_storage.py
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
deleted file mode 100644
index 236032b9a0..0000000000
--- a/mlnode/iotdb/mlnode/parser.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-
-import argparse
-import re
-
-from iotdb.mlnode.algorithm.enums import ForecastTaskType
-from iotdb.mlnode.data_access.enums import DatasetType
-from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
-from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
-
-
-class _ConfigParser(argparse.ArgumentParser):
- """
- A parser for parsing configs from configs: dict
- """
-
- def __init__(self):
- super().__init__()
-
- def parse_configs(self, configs):
- """
- Parse configs from a dict
- Args:configs: a dict of all configs which contains all required arguments
- Returns: a dict of parsed configs
- """
- args = self.parse_dict(configs)
- return vars(self.parse_known_args(args)[0])
-
- @staticmethod
- def parse_dict(config_dict):
- """
- Parse a dict of configs to a list of arguments
- Args:config_dict: a dict of configs
- Returns: a list of arguments which can be parsed by argparse
- """
- args = []
- for k, v in config_dict.items():
- args.append("--{}".format(k))
- if isinstance(v, str) and re.match(r'^\[(.*)]$', v):
- v = eval(v)
- v = [str(i) for i in v]
- args.extend(v)
- elif isinstance(v, list):
- args.extend([str(i) for i in v])
- else:
- args.append(v)
- return args
-
- def error(self, message: str):
- """
- Override the error method to raise exceptions instead of exiting
- """
- if message.startswith('the following arguments are required:'):
- missing_arg = re.findall(r': --(\w+)', message)[0]
- raise MissingConfigError(missing_arg)
- elif re.match(r'argument --\w+: invalid \w+ value:', message):
- argument = re.findall(r'argument --(\w+):', message)[0]
- expected_type = re.findall(r'invalid (\w+) value:', message)[0]
- raise WrongTypeConfigError(argument, expected_type)
- else:
- raise Exception(message)
-
-
-""" Argument description:
- - query_expressions: query expressions
- - query_filter: query filter
- - source_type: source type
- - filename: filename
- - dataset_type: dataset type
- - time_embed: freq for time features encoding
- - input_len: input sequence length
- - pred_len: prediction sequence length
- - input_vars: number of input variables
- - output_vars: number of output variables
-"""
-_data_config_parser = _ConfigParser()
-_data_config_parser.add_argument('--source_type', type=str, required=True)
-_data_config_parser.add_argument('--dataset_type', type=DatasetType, required=True)
-_data_config_parser.add_argument('--filename', type=str, default='')
-_data_config_parser.add_argument('--query_expressions', type=str, nargs='*', default=[])
-_data_config_parser.add_argument('--query_filter', type=str, default='')
-_data_config_parser.add_argument('--time_embed', type=str, default='h')
-_data_config_parser.add_argument('--input_len', type=int, default=96)
-_data_config_parser.add_argument('--pred_len', type=int, default=96)
-_data_config_parser.add_argument('--input_vars', type=int, default=1)
-_data_config_parser.add_argument('--output_vars', type=int, default=1)
-
-""" Argument description:
- - model_name: model name
- - input_len: input sequence length
- - pred_len: prediction sequence length
- - input_vars: number of input variables
- - output_vars: number of output variables
- - task_type: task type, options:[M, S, MS];
- M:multivariate predict multivariate,
- S:univariate predict univariate,
- MS:multivariate predict univariate'
- - kernel_size: kernel size
- - block_type: block type
- - d_model: dimension of feature in model
- - inner_layers: number of inner layers
- - outer_layers: number of outer layers
-"""
-_model_config_parser = _ConfigParser()
-_model_config_parser.add_argument('--model_name', type=str, required=True)
-_model_config_parser.add_argument('--input_len', type=int, default=96)
-_model_config_parser.add_argument('--pred_len', type=int, default=96)
-_model_config_parser.add_argument('--input_vars', type=int, default=1)
-_model_config_parser.add_argument('--output_vars', type=int, default=1)
-_model_config_parser.add_argument('--forecast_task_type', type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
- choices=list(ForecastTaskType))
-_model_config_parser.add_argument('--kernel_size', type=int, default=25)
-_model_config_parser.add_argument('--block_type', type=str, default='generic')
-_model_config_parser.add_argument('--d_model', type=int, default=128)
-_model_config_parser.add_argument('--inner_layers', type=int, default=4)
-_model_config_parser.add_argument('--outer_layers', type=int, default=4)
-
-""" Argument description:
- - model_id: model id
- - tuning: whether to tune hyperparameters
- - task_type: task type, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate,
- MS:multivariate predict univariate'
- - task_class: task class
- - input_len: input sequence length
- - pred_len: prediction sequence length
- - input_vars: number of input variables
- - output_vars: number of output variables
- - learning_rate: learning rate
- - batch_size: batch size
- - num_workers: number of workers
- - epochs: number of epochs
- - use_gpu: whether to use gpu
- - use_multi_gpu: whether to use multi-gpu
- - devices: devices to use
- - metric_names: metric to use
-"""
-_task_config_parser = _ConfigParser()
-_task_config_parser.add_argument('--task_class', type=str, required=True)
-_task_config_parser.add_argument('--model_id', type=str, required=True)
-_task_config_parser.add_argument('--tuning', type=bool, default=False)
-_task_config_parser.add_argument('--forecast_task_type', type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
- choices=list(ForecastTaskType))
-_task_config_parser.add_argument('--input_len', type=int, default=96)
-_task_config_parser.add_argument('--pred_len', type=int, default=96)
-_task_config_parser.add_argument('--input_vars', type=int, default=1)
-_task_config_parser.add_argument('--output_vars', type=int, default=1)
-_task_config_parser.add_argument('--learning_rate', type=float, default=0.0001)
-_task_config_parser.add_argument('--batch_size', type=int, default=32)
-_task_config_parser.add_argument('--num_workers', type=int, default=0)
-_task_config_parser.add_argument('--epochs', type=int, default=10)
-_task_config_parser.add_argument('--use_gpu', type=bool, default=False)
-_task_config_parser.add_argument('--gpu', type=int, default=0)
-_task_config_parser.add_argument('--use_multi_gpu', type=bool, default=False)
-_task_config_parser.add_argument('--devices', type=int, nargs='+', default=[0])
-_task_config_parser.add_argument('--metric_names', type=str, nargs='+', default=['MSE', 'MAE'])
-
-
-def parse_training_request(req: TCreateTrainingTaskReq):
- """
- Parse TCreateTrainingTaskReq with given yaml template
- Args:
- req: TCreateTrainingTaskReq
- Returns:
- data_config: configurations related to data
- model_config: configurations related to model
- task_config: configurations related to task
- """
- config = req.modelConfigs
- config.update(model_id=req.modelId)
- config.update(tuning=req.isAuto)
- config.update(query_expressions=req.queryExpressions)
- config.update(query_filter=req.queryFilter)
-
- data_config = _data_config_parser.parse_configs(config)
- model_config = _model_config_parser.parse_configs(config)
- task_config = _task_config_parser.parse_configs(config)
- return data_config, model_config, task_config
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 5e98636e2e..26860faf38 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -15,38 +15,10 @@
# specific language governing permissions and limitations
# under the License.
#
-from enum import Enum
-
import numpy as np
import pandas as pd
-
-class TSDataType(Enum):
- BOOLEAN = 0
- INT32 = 1
- INT64 = 2
- FLOAT = 3
- DOUBLE = 4
- TEXT = 5
-
- # this method is implemented to avoid the issue reported by:
- # https://bugs.python.org/issue30545
- def __eq__(self, other) -> bool:
- return self.value == other.value
-
- def __hash__(self):
- return self.value
-
- def np_dtype(self):
- return {
- TSDataType.BOOLEAN: np.dtype(">?"),
- TSDataType.FLOAT: np.dtype(">f4"),
- TSDataType.DOUBLE: np.dtype(">f8"),
- TSDataType.INT32: np.dtype(">i4"),
- TSDataType.INT64: np.dtype(">i8"),
- TSDataType.TEXT: np.dtype("str"),
- }[self]
-
+from iotdb.utils.IoTDBConstants import TSDataType
TIMESTAMP_STR = "Time"
START_INDEX = 2
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index d67ba1290d..8932479c4a 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -15,19 +15,20 @@
# specific language governing permissions and limitations
# under the License.
#
-
-from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.exception import BadNodeUrlError
from iotdb.mlnode.log import logger
-from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
+from iotdb.thrift.common.ttypes import TEndPoint
def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
""" Parse TEndPoint from a given endpoint url.
+
Args:
endpoint_url: an endpoint url, format: ip:port
+
Returns:
TEndPoint
+
Raises:
BadNodeUrlError
"""
@@ -44,15 +45,3 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
except ValueError as e:
logger.warning("Illegal endpoint url format: {} ({})".format(endpoint_url, e))
raise BadNodeUrlError(endpoint_url)
-
-
-def get_status(status_code: TSStatusCode, message: str) -> TSStatus:
- status = TSStatus(status_code.get_status_code())
- status.message = message
- return status
-
-
-def verify_success(status: TSStatus, err_msg: str) -> None:
- if status.code != TSStatusCode.SUCCESS_STATUS:
- logger.warn(err_msg + ", error status is ", status)
- raise RuntimeError(str(status.code) + ": " + status.message)
diff --git a/mlnode/test/test_model_storage.py b/mlnode/test/test_model_storage.py
index 3750c49c2c..99857db37e 100644
--- a/mlnode/test/test_model_storage.py
+++ b/mlnode/test/test_model_storage.py
@@ -23,7 +23,7 @@ import time
import torch.nn as nn
from iotdb.mlnode.config import config
-from iotdb.mlnode.storage import model_storage
+from iotdb.mlnode.model_storage import model_storage
class TestModel(nn.Module):
diff --git a/mlnode/test/test_parse_training_request.py b/mlnode/test/test_parse_training_request.py
deleted file mode 100644
index ec318ae60d..0000000000
--- a/mlnode/test/test_parse_training_request.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from iotdb.mlnode.parser import parse_training_request
-from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
-
-
-def test_parse_training_request():
- model_id = 'mid_etth1_dlinear_default'
- is_auto = False
- model_configs = {
- 'task_class': 'forecast_training_task',
- 'source_type': 'thrift',
- 'dataset_type': 'window',
- 'filename': 'ETTh1.csv',
- 'time_embed': 'h',
- 'input_len': 96,
- 'pred_len': 96,
- 'model_name': 'dlinear',
- 'input_vars': 7,
- 'output_vars': 7,
- 'forecast_type': 'm',
- 'kernel_size': 25,
- 'learning_rate': 1e-3,
- 'batch_size': 32,
- 'num_workers': 0,
- 'epochs': 10,
- 'metric_names': ['MSE', 'MAE']
- }
- query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
- query_filter = '0,1501516800000'
- req = TCreateTrainingTaskReq(
- modelId=str(model_id),
- isAuto=is_auto,
- modelConfigs={k: str(v) for k, v in model_configs.items()},
- queryExpressions=[str(query) for query in query_expressions],
- queryFilter=str(query_filter),
- )
- data_config, model_config, task_config = parse_training_request(req)
- for config in model_configs:
- if config in data_config:
- assert data_config[config] == model_configs[config]
- if config in model_config:
- assert model_config[config] == model_configs[config]
- if config in task_config:
- assert task_config[config] == model_configs[config]
-
-
-def test_missing_argument():
- # missing model_name
- model_id = 'mid_etth1_dlinear_default'
- is_auto = False
- model_configs = {
- 'task_class': 'forecast_training_task',
- 'source_type': 'thrift',
- 'dataset_type': 'window',
- 'filename': 'ETTh1.csv',
- 'time_embed': 'h',
- 'input_len': 96,
- 'pred_len': 96,
- 'input_vars': 7,
- 'output_vars': 7,
- 'forecast_type': 'm',
- 'kernel_size': 25,
- 'learning_rate': 1e-3,
- 'batch_size': 32,
- 'num_workers': 0,
- 'epochs': 10,
- 'metric_names': ['MSE', 'MAE']
- }
- query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
- query_filter = '0,1501516800000'
- req = TCreateTrainingTaskReq(
- modelId=str(model_id),
- isAuto=is_auto,
- modelConfigs={k: str(v) for k, v in model_configs.items()},
- queryExpressions=[str(query) for query in query_expressions],
- queryFilter=str(query_filter),
- )
- try:
- data_config, model_config, task_config = parse_training_request(req)
- except Exception as e:
- assert e.message == 'Missing config: (model_name)'
-
-
-def test_wrong_argument_type():
- model_id = 'mid_etth1_dlinear_default'
- is_auto = False
- model_configs = {
- 'task_class': 'forecast_training_task',
- 'source_type': 'thrift',
- 'dataset_type': 'window',
- 'filename': 'ETTh1.csv',
- 'time_embed': 'h',
- 'input_len': 96.7,
- 'pred_len': 96,
- 'model_name': 'dlinear',
- 'input_vars': 7,
- 'output_vars': 7,
- 'forecast_type': 'm',
- 'kernel_size': 25,
- 'learning_rate': 1e-3,
- 'batch_size': 32,
- 'num_workers': 0,
- 'epochs': 10,
- 'metric_names': ['MSE', 'MAE']
- }
- query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
- query_filter = '0,1501516800000'
- req = TCreateTrainingTaskReq(
- modelId=str(model_id),
- isAuto=is_auto,
- modelConfigs={k: str(v) for k, v in model_configs.items()},
- queryExpressions=[str(query) for query in query_expressions],
- queryFilter=str(query_filter),
- )
- try:
- data_config, model_config, task_config = parse_training_request(req)
- except Exception as e:
- message = "Wrong type for config: ({})".format('input_len')
- message += ", expected: ({})".format('int')
- assert e.message == message
diff --git a/node-commons/src/assembly/resources/conf/iotdb-common.properties b/node-commons/src/assembly/resources/conf/iotdb-common.properties
index 02bae820b8..4b67e5308d 100644
--- a/node-commons/src/assembly/resources/conf/iotdb-common.properties
+++ b/node-commons/src/assembly/resources/conf/iotdb-common.properties
@@ -892,15 +892,6 @@ cluster_name=defaultCluster
### Continuous Query Configuration
####################
-# Uncomment the following field to configure the pipe lib directory.
-# For Window platform
-# If its prefix is a drive specifier followed by "\\", or if its prefix is "\\\\", then the path is
-# absolute. Otherwise, it is relative.
-# pipe_lib_dir=ext\\pipe
-# For Linux platform
-# If its prefix is "/", then the path is absolute. Otherwise, it is relative.
-# pipe_lib_dir=ext/pipe
-
# The number of threads in the scheduled thread pool that submit continuous query tasks periodically
# Datatype: int
# continuous_query_submit_thread_count=2
@@ -913,6 +904,18 @@ cluster_name=defaultCluster
### PIPE Configuration
####################
+# Uncomment the following field to configure the pipe lib directory.
+# For Windows platform
+# If its prefix is a drive specifier followed by "\\", or if its prefix is "\\\\", then the path is
+# absolute. Otherwise, it is relative.
+# pipe_lib_dir=ext\\pipe
+# For Linux platform
+# If its prefix is "/", then the path is absolute. Otherwise, it is relative.
+# pipe_lib_dir=ext/pipe
+
+# The maximum number of threads that can be used to execute the pipe subtasks in PipeSubtaskExecutor.
+# pipe_max_thread_num = 5
+
# White IP list of Sync client.
# Please use the form of IPv4 network segment to present the range of IP, for example: 192.168.0.0/16
# If there are more than one IP segment, please separate them by commas
diff --git a/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java b/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
index 9b33ae30ee..f8da5262ab 100644
--- a/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
+++ b/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
@@ -65,7 +65,12 @@ public enum ThreadName {
SCHEMA_REGION_RELEASE_PROCESSOR("SchemaRegion-Release-Task-Processor"),
SCHEMA_RELEASE_MONITOR("Schema-Release-Task-Monitor"),
SCHEMA_REGION_FLUSH_PROCESSOR("SchemaRegion-Flush-Task-Processor"),
- SCHEMA_FLUSH_MONITOR("Schema-Flush-Task-Monitor");
+ SCHEMA_FLUSH_MONITOR("Schema-Flush-Task-Monitor"),
+ PIPE_ASSIGNER_EXECUTOR_POOL("Pipe-Assigner-Executor-Pool"),
+ PIPE_PROCESSOR_EXECUTOR_POOL("Pipe-Processor-Executor-Pool"),
+ PIPE_CONNECTOR_EXECUTOR_POOL("Pipe-Connector-Executor-Pool"),
+ PIPE_SUBTASK_CALLBACK_EXECUTOR_POOL("Pipe-SubTask-Callback-Executor-Pool"),
+ ;
private final String name;
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index b0f97c2288..dd4ef4a016 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -1057,6 +1057,9 @@ public class IoTDBConfig {
// customizedProperties, this should be empty by default.
private Properties customizedProperties = new Properties();
+ /** The maximum number of threads that can be used to execute subtasks in PipeSubtaskExecutor */
+ private int pipeMaxThreadNum = 5;
+
IoTDBConfig() {}
public float getUdfMemoryBudgetInMB() {
@@ -3663,4 +3666,12 @@ public class IoTDBConfig {
public int getModeMapSizeThreshold() {
return modeMapSizeThreshold;
}
+
+  /**
+   * Sets how many worker threads each pipe subtask executor is created with.
+   *
+   * @param pipeMaxThreadNum the new thread count (loaded from "pipe_max_thread_num")
+   */
+  public void setPipeSubtaskExecutorMaxThreadNum(int pipeMaxThreadNum) {
+    this.pipeMaxThreadNum = pipeMaxThreadNum;
+  }
+
+  /** @return the worker thread count for each pipe subtask executor (defaults to 5) */
+  public int getPipeSubtaskExecutorMaxThreadNum() {
+    return this.pipeMaxThreadNum;
+  }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
index b6961fab44..d43d6b3e03 100644
--- a/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
+++ b/server/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
@@ -1788,6 +1788,15 @@ public class IoTDBDescriptor {
private void loadPipeProps(Properties properties) {
conf.setPipeDir(properties.getProperty("pipe_lib_dir", conf.getPipeDir()));
+
+    // Parse into a local first so conf never holds a non-positive thread count. An invalid
+    // setting keeps the current value instead. trim() is needed because Properties keeps
+    // trailing whitespace in values, which Integer.parseInt would reject.
+    final int pipeMaxThreadNum =
+        Integer.parseInt(
+            properties
+                .getProperty(
+                    "pipe_max_thread_num",
+                    Integer.toString(conf.getPipeSubtaskExecutorMaxThreadNum()))
+                .trim());
+    if (pipeMaxThreadNum > 0) {
+      conf.setPipeSubtaskExecutorMaxThreadNum(pipeMaxThreadNum);
+    }
}
private void loadCQProps(Properties properties) {
diff --git a/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGCachedImpl.java b/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGCachedImpl.java
index 276161bce5..9425a102a9 100644
--- a/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGCachedImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGCachedImpl.java
@@ -890,7 +890,6 @@ public class MTreeBelowSGCachedImpl {
entityMNode.setSchemaTemplateId(templateId);
store.updateMNode(entityMNode.getAsMNode());
- regionStatistics.activateTemplate(templateId);
} finally {
unPinPath(cur);
}
diff --git a/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGMemoryImpl.java b/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGMemoryImpl.java
index 58838178cc..9d03f3075e 100644
--- a/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGMemoryImpl.java
+++ b/server/src/main/java/org/apache/iotdb/db/metadata/mtree/MTreeBelowSGMemoryImpl.java
@@ -810,7 +810,6 @@ public class MTreeBelowSGMemoryImpl {
}
entityMNode.setUseTemplate(true);
entityMNode.setSchemaTemplateId(templateId);
- regionStatistics.activateTemplate(templateId);
}
public long countPathsUsingTemplate(PartialPath pathPattern, int templateId)
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/agent/runtime/PipeRuntimeAgent.java b/server/src/main/java/org/apache/iotdb/db/pipe/agent/runtime/PipeRuntimeAgent.java
index cbfe53be8b..55e3b0c1ec 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/agent/runtime/PipeRuntimeAgent.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/agent/runtime/PipeRuntimeAgent.java
@@ -19,8 +19,23 @@
package org.apache.iotdb.db.pipe.agent.runtime;
+import org.apache.iotdb.db.pipe.task.callable.PipeSubtask;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
public class PipeRuntimeAgent {
+ private static final Logger LOGGER = LoggerFactory.getLogger(PipeRuntimeAgent.class);
+
+  /**
+   * Reports a subtask that has exhausted its retries. At present this only logs the task ID and
+   * the last failure cause; terminating the task is still a TODO.
+   *
+   * @param subtask the subtask that kept failing
+   */
+  public void report(PipeSubtask subtask) {
+    // TODO: terminate the task by the given taskID
+    final String failedTaskID = subtask.getTaskID();
+    LOGGER.warn(
+        "Failed to execute task {} after many retries, last failed cause by {}",
+        failedTaskID,
+        subtask.getLastFailedCause());
+  }
+
///////////////////////// Singleton Instance Holder /////////////////////////
private PipeRuntimeAgent() {}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/core/connector/PipeConnectorPluginRuntimeWrapper.java b/server/src/main/java/org/apache/iotdb/db/pipe/core/connector/PipeConnectorPluginRuntimeWrapper.java
index c1c5be1166..1981f6b87f 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/core/connector/PipeConnectorPluginRuntimeWrapper.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/core/connector/PipeConnectorPluginRuntimeWrapper.java
@@ -20,12 +20,54 @@
package org.apache.iotdb.db.pipe.core.connector;
import org.apache.iotdb.pipe.api.PipeConnector;
+import org.apache.iotdb.pipe.api.event.Event;
+import org.apache.iotdb.pipe.api.event.deletion.DeletionEvent;
+import org.apache.iotdb.pipe.api.event.insertion.TabletInsertionEvent;
+import org.apache.iotdb.pipe.api.event.insertion.TsFileInsertionEvent;
+import org.apache.iotdb.pipe.api.exception.PipeException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Queue;
public class PipeConnectorPluginRuntimeWrapper {
+  private static final Logger LOGGER =
+      LoggerFactory.getLogger(PipeConnectorPluginRuntimeWrapper.class);
+
+  /** Events waiting to be transferred; at most one is consumed per executeForAWhile() call. */
+  private final Queue<Event> inputEventQueue;
private final PipeConnector pipeConnector;
- public PipeConnectorPluginRuntimeWrapper(PipeConnector pipeConnector) {
+  public PipeConnectorPluginRuntimeWrapper(
+      Queue<Event> inputEventQueue, PipeConnector pipeConnector) {
+    this.inputEventQueue = inputEventQueue;
this.pipeConnector = pipeConnector;
}
+
+  // TODO: for a while
+  /**
+   * Takes at most one event from the input queue and hands it to the matching
+   * PipeConnector#transfer overload. Returns immediately when the queue has no event.
+   *
+   * @throws PipeException wrapping any failure, including an unsupported event type
+   */
+  public void executeForAWhile() {
+    // poll() already reports emptiness by returning null. A separate isEmpty() check would be
+    // a check-then-act race if another thread drains the queue in between.
+    final Event event = inputEventQueue.poll();
+    if (event == null) {
+      return;
+    }
+
+    try {
+      if (event instanceof TabletInsertionEvent) {
+        pipeConnector.transfer((TabletInsertionEvent) event);
+      } else if (event instanceof TsFileInsertionEvent) {
+        pipeConnector.transfer((TsFileInsertionEvent) event);
+      } else if (event instanceof DeletionEvent) {
+        pipeConnector.transfer((DeletionEvent) event);
+      } else {
+        throw new RuntimeException("Unsupported event type: " + event.getClass().getName());
+      }
+    } catch (Exception e) {
+      // Route the failure through the server logger rather than printStackTrace() to stderr.
+      LOGGER.warn("Error occurred during executing PipeConnector#transfer.", e);
+      throw new PipeException(
+          "Error occurred during executing PipeConnector#transfer, perhaps need to check whether the implementation of PipeConnector is correct according to the pipe-api description.",
+          e);
+    }
+  }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/core/processor/PipeProcessorPluginRuntimeWrapper.java b/server/src/main/java/org/apache/iotdb/db/pipe/core/processor/PipeProcessorPluginRuntimeWrapper.java
index 3153d2ae0e..a5b56bc68e 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/core/processor/PipeProcessorPluginRuntimeWrapper.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/core/processor/PipeProcessorPluginRuntimeWrapper.java
@@ -20,12 +20,58 @@
package org.apache.iotdb.db.pipe.core.processor;
import org.apache.iotdb.pipe.api.PipeProcessor;
+import org.apache.iotdb.pipe.api.collector.EventCollector;
+import org.apache.iotdb.pipe.api.event.Event;
+import org.apache.iotdb.pipe.api.event.deletion.DeletionEvent;
+import org.apache.iotdb.pipe.api.event.insertion.TabletInsertionEvent;
+import org.apache.iotdb.pipe.api.event.insertion.TsFileInsertionEvent;
+import org.apache.iotdb.pipe.api.exception.PipeException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Queue;
public class PipeProcessorPluginRuntimeWrapper {
+  private static final Logger LOGGER =
+      LoggerFactory.getLogger(PipeProcessorPluginRuntimeWrapper.class);
+
+  /** Events waiting to be processed; at most one is consumed per executeForAWhile() call. */
+  private final Queue<Event> inputEventQueue;
private final PipeProcessor pipeProcessor;
+  /** Passed as the collector argument to every PipeProcessor#process call. */
+  private final EventCollector outputEventCollector;
- public PipeProcessorPluginRuntimeWrapper(PipeProcessor pipeProcessor) {
+  public PipeProcessorPluginRuntimeWrapper(
+      Queue<Event> inputEventQueue,
+      PipeProcessor pipeProcessor,
+      EventCollector outputEventCollector) {
+    this.inputEventQueue = inputEventQueue;
this.pipeProcessor = pipeProcessor;
+    this.outputEventCollector = outputEventCollector;
+  }
+
+  /**
+   * Takes at most one event from the input queue and hands it to the matching
+   * PipeProcessor#process overload together with the output collector. Returns immediately when
+   * the queue has no event.
+   *
+   * @throws PipeException wrapping any failure, including an unsupported event type
+   */
+  public void executeForAWhile() {
+    // poll() already reports emptiness by returning null. A separate isEmpty() check would be
+    // a check-then-act race if another thread drains the queue in between.
+    final Event event = inputEventQueue.poll();
+    if (event == null) {
+      return;
+    }
+
+    try {
+      if (event instanceof TabletInsertionEvent) {
+        pipeProcessor.process((TabletInsertionEvent) event, outputEventCollector);
+      } else if (event instanceof TsFileInsertionEvent) {
+        pipeProcessor.process((TsFileInsertionEvent) event, outputEventCollector);
+      } else if (event instanceof DeletionEvent) {
+        pipeProcessor.process((DeletionEvent) event, outputEventCollector);
+      } else {
+        throw new RuntimeException("Unsupported event type: " + event.getClass().getName());
+      }
+    } catch (Exception e) {
+      // Route the failure through the server logger rather than printStackTrace() to stderr.
+      LOGGER.warn("Error occurred during executing PipeProcessor#process.", e);
+      throw new PipeException(
+          "Error occurred during executing PipeProcessor#process, perhaps need to check whether the implementation of PipeProcessor is correct according to the pipe-api description.",
+          e);
+    }
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
index cc3dd43987..4bfe8f0bb9 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
@@ -19,4 +19,14 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public class PipeAssignerSubtaskExecutor implements PipeSubtaskExecutor {}
+import org.apache.iotdb.commons.concurrent.ThreadName;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+
+/**
+ * Subtask executor for assigner subtasks. Its worker pool uses the PIPE_ASSIGNER_EXECUTOR_POOL
+ * thread name and is sized by the configured pipe subtask thread count.
+ */
+public class PipeAssignerSubtaskExecutor extends PipeSubtaskExecutor {
+
+  PipeAssignerSubtaskExecutor() {
+    super(configuredThreadNum(), ThreadName.PIPE_ASSIGNER_EXECUTOR_POOL);
+  }
+
+  /** Reads the worker thread count from the server configuration. */
+  private static int configuredThreadNum() {
+    return IoTDBDescriptor.getInstance().getConfig().getPipeSubtaskExecutorMaxThreadNum();
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutor.java
index 98eaf31d1b..33ba3a4210 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutor.java
@@ -19,4 +19,14 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public class PipeConnectorSubtaskExecutor implements PipeSubtaskExecutor {}
+import org.apache.iotdb.commons.concurrent.ThreadName;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+
+/**
+ * Subtask executor for connector subtasks. Its worker pool uses the PIPE_CONNECTOR_EXECUTOR_POOL
+ * thread name and is sized by the configured pipe subtask thread count.
+ */
+public class PipeConnectorSubtaskExecutor extends PipeSubtaskExecutor {
+
+  PipeConnectorSubtaskExecutor() {
+    super(configuredThreadNum(), ThreadName.PIPE_CONNECTOR_EXECUTOR_POOL);
+  }
+
+  /** Reads the worker thread count from the server configuration. */
+  private static int configuredThreadNum() {
+    return IoTDBDescriptor.getInstance().getConfig().getPipeSubtaskExecutorMaxThreadNum();
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutor.java
index c61871fafe..e3a1e1a7ec 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutor.java
@@ -19,4 +19,14 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public class PipeProcessorSubtaskExecutor implements PipeSubtaskExecutor {}
+import org.apache.iotdb.commons.concurrent.ThreadName;
+import org.apache.iotdb.db.conf.IoTDBDescriptor;
+
+/**
+ * Subtask executor for processor subtasks. Its worker pool uses the PIPE_PROCESSOR_EXECUTOR_POOL
+ * thread name and is sized by the configured pipe subtask thread count.
+ */
+public class PipeProcessorSubtaskExecutor extends PipeSubtaskExecutor {
+
+  PipeProcessorSubtaskExecutor() {
+    super(configuredThreadNum(), ThreadName.PIPE_PROCESSOR_EXECUTOR_POOL);
+  }
+
+  /** Reads the worker thread count from the server configuration. */
+  private static int configuredThreadNum() {
+    return IoTDBDescriptor.getInstance().getConfig().getPipeSubtaskExecutorMaxThreadNum();
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java
index 7d97605dff..05956f0a90 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java
@@ -19,4 +19,124 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public interface PipeSubtaskExecutor {}
+import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory;
+import org.apache.iotdb.commons.concurrent.ThreadName;
+import org.apache.iotdb.commons.utils.TestOnly;
+import org.apache.iotdb.db.pipe.task.callable.PipeSubtask;
+
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+
+/**
+ * Base class of the pipe subtask executors. Each instance owns a worker pool of {@code
+ * corePoolSize} threads, wrapped as a ListeningExecutorService, plus a registry of subtasks keyed
+ * by task ID. A single-threaded callback executor is static and shared by all instances.
+ * Registered subtasks are bound to both executors and, once started, submit themselves.
+ */
+@NotThreadSafe
+public abstract class PipeSubtaskExecutor {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(PipeSubtaskExecutor.class);
+
+  /** Shared by all executor instances; handed to every registered subtask via bindExecutors. */
+  private static final ExecutorService subtaskCallbackListeningExecutor =
+      IoTDBThreadPoolFactory.newSingleThreadExecutor(
+          ThreadName.PIPE_SUBTASK_CALLBACK_EXECUTOR_POOL.getName());
+
+  private final ListeningExecutorService subtaskWorkerThreadPoolExecutor;
+
+  /** Registered subtasks keyed by PipeSubtask#getTaskID(). */
+  private final Map<String, PipeSubtask> registeredIdSubtaskMapper;
+
+  /** Thread count the worker pool was created with. */
+  private int corePoolSize;
+
+  protected PipeSubtaskExecutor(int corePoolSize, ThreadName threadName) {
+    subtaskWorkerThreadPoolExecutor =
+        MoreExecutors.listeningDecorator(
+            IoTDBThreadPoolFactory.newFixedThreadPool(corePoolSize, threadName.getName()));
+
+    registeredIdSubtaskMapper = new ConcurrentHashMap<>();
+
+    this.corePoolSize = corePoolSize;
+  }
+
+  /////////////////////// subtask management ///////////////////////
+
+  /**
+   * Registers a subtask and binds it to this executor's worker pool and the shared callback
+   * executor. A subtask whose ID is already registered is ignored with a warning.
+   */
+  public final void register(PipeSubtask subtask) {
+    // putIfAbsent makes the duplicate check and the insertion one atomic step, so two concurrent
+    // registrations of the same ID cannot both succeed.
+    if (registeredIdSubtaskMapper.putIfAbsent(subtask.getTaskID(), subtask) != null) {
+      LOGGER.warn("The subtask {} is already registered.", subtask.getTaskID());
+      return;
+    }
+
+    subtask.bindExecutors(subtaskWorkerThreadPoolExecutor, subtaskCallbackListeningExecutor);
+  }
+
+  /**
+   * Lets a registered subtask submit itself and triggers its first submission. If the subtask is
+   * unknown or already submitting itself, this only logs.
+   */
+  public final void start(String subTaskID) {
+    // A single lookup avoids the containsKey/get gap in which a concurrent deregister() could
+    // remove the entry and leave get() returning null.
+    final PipeSubtask subtask = registeredIdSubtaskMapper.get(subTaskID);
+    if (subtask == null) {
+      LOGGER.warn("The subtask {} is not registered.", subTaskID);
+      return;
+    }
+
+    if (subtask.isSubmittingSelf()) {
+      LOGGER.info("The subtask {} is already running.", subTaskID);
+    } else {
+      subtask.allowSubmittingSelf();
+      subtask.submitSelf();
+      LOGGER.info("The subtask {} is started to submit self.", subTaskID);
+    }
+  }
+
+  /** Stops a registered subtask from submitting itself again; unknown IDs are only logged. */
+  public final void stop(String subTaskID) {
+    final PipeSubtask subtask = registeredIdSubtaskMapper.get(subTaskID);
+    if (subtask == null) {
+      LOGGER.warn("The subtask {} is not registered.", subTaskID);
+      return;
+    }
+
+    subtask.disallowSubmittingSelf();
+  }
+
+  /** Stops the subtask (see stop) and removes it from the registry. */
+  public final void deregister(String subTaskID) {
+    stop(subTaskID);
+
+    registeredIdSubtaskMapper.remove(subTaskID);
+  }
+
+  @TestOnly
+  public final boolean isRegistered(String subTaskID) {
+    return registeredIdSubtaskMapper.containsKey(subTaskID);
+  }
+
+  @TestOnly
+  public final int getRegisteredSubtaskNumber() {
+    return registeredIdSubtaskMapper.size();
+  }
+
+  /////////////////////// executor management ///////////////////////
+
+  /**
+   * Disallows self-submission for every registered subtask, then shuts down this instance's
+   * worker pool. The static callback executor is left running because other instances share it.
+   * Calling this again after shutdown is a no-op.
+   */
+  public final void shutdown() {
+    if (isShutdown()) {
+      return;
+    }
+
+    // stop all subtasks before shutting down the executor
+    for (PipeSubtask subtask : registeredIdSubtaskMapper.values()) {
+      subtask.disallowSubmittingSelf();
+    }
+
+    subtaskWorkerThreadPoolExecutor.shutdown();
+  }
+
+  public final boolean isShutdown() {
+    return subtaskWorkerThreadPoolExecutor.isShutdown();
+  }
+
+  /**
+   * Not supported yet.
+   *
+   * @throws UnsupportedOperationException always
+   */
+  public final void adjustExecutorThreadNumber(int threadNum) {
+    // Leave corePoolSize untouched: the pool is not resized, so recording threadNum would make
+    // getExecutorThreadNumber() report a size the pool never had.
+    throw new UnsupportedOperationException("Not implemented yet.");
+  }
+
+  /** @return the thread count the worker pool was created with */
+  public final int getExecutorThreadNumber() {
+    return corePoolSize;
+  }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutor.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutorManager.java
similarity index 59%
rename from server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutor.java
rename to server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutorManager.java
index 4437fb119d..8698a23d86 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutor.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeTaskExecutorManager.java
@@ -19,30 +19,48 @@
package org.apache.iotdb.db.pipe.execution.executor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
/**
* PipeTaskExecutor is responsible for executing the pipe tasks, and it is scheduled by the
* PipeTaskScheduler. It is a singleton class.
*/
-public class PipeTaskExecutor {
+public class PipeTaskExecutorManager {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(PipeTaskExecutorManager.class);
- private final PipeAssignerSubtaskExecutor assignerSubtaskExecutor =
- new PipeAssignerSubtaskExecutor();
- private final PipeProcessorSubtaskExecutor processorSubtaskExecutor =
- new PipeProcessorSubtaskExecutor();
- private final PipeConnectorSubtaskExecutor connectorSubtaskExecutor =
- new PipeConnectorSubtaskExecutor();
+ private final PipeAssignerSubtaskExecutor assignerSubtaskExecutor;
+ private final PipeProcessorSubtaskExecutor processorSubtaskExecutor;
+ private final PipeConnectorSubtaskExecutor connectorSubtaskExecutor;
+
+  /** @return the executor that runs assigner subtasks */
+  public PipeAssignerSubtaskExecutor getAssignerSubtaskExecutor() {
+    return this.assignerSubtaskExecutor;
+  }
+
+  /** @return the executor that runs processor subtasks */
+  public PipeProcessorSubtaskExecutor getProcessorSubtaskExecutor() {
+    return this.processorSubtaskExecutor;
+  }
+
+  /** @return the executor that runs connector subtasks */
+  public PipeConnectorSubtaskExecutor getConnectorSubtaskExecutor() {
+    return this.connectorSubtaskExecutor;
+  }
///////////////////////// Singleton Instance Holder /////////////////////////
- private PipeTaskExecutor() {}
+  private PipeTaskExecutorManager() {
+    // One executor per subtask kind, each creating its own worker pool.
+    this.assignerSubtaskExecutor = new PipeAssignerSubtaskExecutor();
+    this.processorSubtaskExecutor = new PipeProcessorSubtaskExecutor();
+    this.connectorSubtaskExecutor = new PipeConnectorSubtaskExecutor();
+  }
private static class PipeTaskExecutorHolder {
- private static PipeTaskExecutor instance = null;
+    private static PipeTaskExecutorManager instance = null;
}
- public static PipeTaskExecutor setupAndGetInstance() {
+  /**
+   * Returns the singleton manager, creating it on the first call. Construction creates three
+   * subtask executors, each with its own thread pool. The method is synchronized so that
+   * concurrent first calls cannot each build a manager and leak the surplus pools.
+   */
+  public static synchronized PipeTaskExecutorManager setupAndGetInstance() {
if (PipeTaskExecutorHolder.instance == null) {
- PipeTaskExecutorHolder.instance = new PipeTaskExecutor();
+      PipeTaskExecutorHolder.instance = new PipeTaskExecutorManager();
}
return PipeTaskExecutorHolder.instance;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeAssignerSubtaskScheduler.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeAssignerSubtaskScheduler.java
deleted file mode 100644
index 2cab31d737..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeAssignerSubtaskScheduler.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.pipe.execution.scheduler;
-
-import org.apache.iotdb.db.pipe.task.runnable.PipeSubtask;
-
-public class PipeAssignerSubtaskScheduler implements PipeSubtaskScheduler {
- @Override
- public void createSubtask(String subtaskId, PipeSubtask subtask) {}
-
- @Override
- public void dropSubtask(String subtaskId) {}
-
- @Override
- public void startSubtask(String subtaskId) {}
-
- @Override
- public void stopSubtask(String subtaskId) {}
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeConnectorSubtaskScheduler.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeConnectorSubtaskScheduler.java
deleted file mode 100644
index c53c6b040d..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeConnectorSubtaskScheduler.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.pipe.execution.scheduler;
-
-import org.apache.iotdb.db.pipe.task.runnable.PipeSubtask;
-
-public class PipeConnectorSubtaskScheduler implements PipeSubtaskScheduler {
- @Override
- public void createSubtask(String subtaskId, PipeSubtask subtask) {}
-
- @Override
- public void dropSubtask(String subtaskId) {}
-
- @Override
- public void startSubtask(String subtaskId) {}
-
- @Override
- public void stopSubtask(String subtaskId) {}
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeProcessorSubtaskScheduler.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeProcessorSubtaskScheduler.java
deleted file mode 100644
index 9f5df481b8..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeProcessorSubtaskScheduler.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.pipe.execution.scheduler;
-
-import org.apache.iotdb.db.pipe.task.runnable.PipeSubtask;
-
-public class PipeProcessorSubtaskScheduler implements PipeSubtaskScheduler {
- @Override
- public void createSubtask(String subtaskId, PipeSubtask subtask) {}
-
- @Override
- public void dropSubtask(String subtaskId) {}
-
- @Override
- public void startSubtask(String subtaskId) {}
-
- @Override
- public void stopSubtask(String subtaskId) {}
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeSubtaskScheduler.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeSubtaskScheduler.java
deleted file mode 100644
index c87f949103..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeSubtaskScheduler.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.pipe.execution.scheduler;
-
-import org.apache.iotdb.db.pipe.task.runnable.PipeSubtask;
-
-public interface PipeSubtaskScheduler {
-
- void createSubtask(String subtaskId, PipeSubtask subtask);
-
- void dropSubtask(String subtaskId);
-
- void startSubtask(String subtaskId);
-
- void stopSubtask(String subtaskId);
-}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeTaskScheduler.java b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeTaskScheduler.java
index cda2cc9466..188bbac0e4 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeTaskScheduler.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/execution/scheduler/PipeTaskScheduler.java
@@ -19,34 +19,48 @@
package org.apache.iotdb.db.pipe.execution.scheduler;
-import org.apache.iotdb.db.pipe.task.PipeTask;
+import org.apache.iotdb.db.pipe.execution.executor.PipeTaskExecutorManager;
/**
- * PipeTaskScheduler is responsible for scheduling the pipe tasks. It takes the pipe tasks and
- * executes them in the PipeTaskExecutor. It is a singleton class.
+ * PipeTaskScheduler is a singleton class that manages the numbers of threads used by
+ * PipeTaskExecutors dynamically.
*/
public class PipeTaskScheduler {
- private final PipeSubtaskScheduler assignerSubtaskScheduler;
- private final PipeSubtaskScheduler processorSubtaskScheduler;
- private final PipeSubtaskScheduler connectorSubtaskScheduler;
+ private final PipeTaskExecutorManager pipeTaskExecutorManager =
+ PipeTaskExecutorManager.setupAndGetInstance();
- public void createPipeTask(PipeTask pipeTask) {}
+ public void adjustAssignerSubtaskExecutorThreadNum(int threadNum) {
+ // TODO: make it configurable by setting different parameters
+ pipeTaskExecutorManager.getAssignerSubtaskExecutor().adjustExecutorThreadNumber(threadNum);
+ }
- public void dropPipeTask(String pipeName) {}
+ public int getAssignerSubtaskExecutorThreadNum() {
+ return pipeTaskExecutorManager.getAssignerSubtaskExecutor().getExecutorThreadNumber();
+ }
- public void startPipeTask(String pipeName) {}
+ public void adjustConnectorSubtaskExecutorThreadNum(int threadNum) {
+ // TODO: make it configurable by setting different parameters
+ pipeTaskExecutorManager.getConnectorSubtaskExecutor().adjustExecutorThreadNumber(threadNum);
+ }
- public void stopPipeTask(String pipeName) {}
+ public int getConnectorSubtaskExecutorThreadNum() {
+ return pipeTaskExecutorManager.getConnectorSubtaskExecutor().getExecutorThreadNumber();
+ }
- ///////////////////////// Singleton Instance Holder /////////////////////////
+ public void adjustProcessorSubtaskExecutorThreadNum(int threadNum) {
+ // TODO: make it configurable by setting different parameters
+ pipeTaskExecutorManager.getProcessorSubtaskExecutor().adjustExecutorThreadNumber(threadNum);
+ }
- private PipeTaskScheduler() {
- assignerSubtaskScheduler = new PipeAssignerSubtaskScheduler();
- processorSubtaskScheduler = new PipeProcessorSubtaskScheduler();
- connectorSubtaskScheduler = new PipeConnectorSubtaskScheduler();
+ public int getProcessorSubtaskExecutorThreadNum() {
+ return pipeTaskExecutorManager.getProcessorSubtaskExecutor().getExecutorThreadNumber();
}
+ ///////////////////////// Singleton Instance Holder /////////////////////////
+
+ private PipeTaskScheduler() {}
+
private static class PipeTaskSchedulerHolder {
private static PipeTaskScheduler instance = null;
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/PipeTask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/PipeTask.java
index e3a2819deb..416c2c17fa 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/PipeTask.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/PipeTask.java
@@ -19,7 +19,6 @@
package org.apache.iotdb.db.pipe.task;
-import org.apache.iotdb.db.pipe.task.metrics.PipeTaskRuntimeRecorder;
import org.apache.iotdb.db.pipe.task.stage.PipeTaskStage;
public class PipeTask {
@@ -30,8 +29,6 @@ public class PipeTask {
private final PipeTaskStage processorStage;
private final PipeTaskStage connectorStage;
- private final PipeTaskRuntimeRecorder runtimeRecorder;
-
public PipeTask(
String pipeName,
PipeTaskStage collectorStage,
@@ -42,7 +39,33 @@ public class PipeTask {
this.collectorStage = collectorStage;
this.processorStage = processorStage;
this.connectorStage = connectorStage;
+ }
+
+ public void create() {
+ collectorStage.create();
+ processorStage.create();
+ connectorStage.create();
+ }
+
+ public void drop() {
+ collectorStage.drop();
+ processorStage.drop();
+ connectorStage.drop();
+ }
+
+ public void start() {
+ collectorStage.start();
+ processorStage.start();
+ connectorStage.start();
+ }
+
+ public void stop() {
+ collectorStage.stop();
+ processorStage.stop();
+ connectorStage.stop();
+ }
- runtimeRecorder = new PipeTaskRuntimeRecorder();
+ public String getPipeName() {
+ return pipeName;
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeSubtask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/DecoratingLock.java
similarity index 60%
rename from server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeSubtask.java
rename to server/src/main/java/org/apache/iotdb/db/pipe/task/callable/DecoratingLock.java
index daebd15e47..54f562416b 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeSubtask.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/DecoratingLock.java
@@ -17,20 +17,28 @@
* under the License.
*/
-package org.apache.iotdb.db.pipe.task.runnable;
+package org.apache.iotdb.db.pipe.task.callable;
-import org.apache.iotdb.commons.concurrent.WrappedRunnable;
+import java.util.concurrent.atomic.AtomicBoolean;
-public abstract class PipeSubtask extends WrappedRunnable {
+public class DecoratingLock {
+ private final AtomicBoolean isDecorating = new AtomicBoolean(false);
- private final String taskID;
+ public void waitForDecorated() {
+ while (isDecorating.get()) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
- public PipeSubtask(String taskID) {
- super();
- this.taskID = taskID;
+ public void markAsDecorating() {
+ isDecorating.set(true);
}
- public String getTaskID() {
- return taskID;
+ public void markAsDecorated() {
+ isDecorating.set(false);
}
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeAssignerSubtask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeAssignerSubtask.java
similarity index 89%
rename from server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeAssignerSubtask.java
rename to server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeAssignerSubtask.java
index 5890801ea6..64d421cf8c 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeAssignerSubtask.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeAssignerSubtask.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.iotdb.db.pipe.task.runnable;
+package org.apache.iotdb.db.pipe.task.callable;
public class PipeAssignerSubtask extends PipeSubtask {
@@ -26,5 +26,7 @@ public class PipeAssignerSubtask extends PipeSubtask {
}
@Override
- public void runMayThrow() throws Throwable {}
+ protected void executeForAWhile() {
+ // do nothing
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeConnectorSubtask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeConnectorSubtask.java
similarity index 68%
rename from server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeConnectorSubtask.java
rename to server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeConnectorSubtask.java
index b199607241..dd1ec6db5d 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeConnectorSubtask.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeConnectorSubtask.java
@@ -17,14 +17,21 @@
* under the License.
*/
-package org.apache.iotdb.db.pipe.task.runnable;
+package org.apache.iotdb.db.pipe.task.callable;
+
+import org.apache.iotdb.db.pipe.core.connector.PipeConnectorPluginRuntimeWrapper;
public class PipeConnectorSubtask extends PipeSubtask {
- public PipeConnectorSubtask(String taskID) {
+ private final PipeConnectorPluginRuntimeWrapper pipeConnector;
+
+ public PipeConnectorSubtask(String taskID, PipeConnectorPluginRuntimeWrapper pipeConnector) {
super(taskID);
+ this.pipeConnector = pipeConnector;
}
@Override
- public void runMayThrow() throws Throwable {}
+ protected void executeForAWhile() {
+ pipeConnector.executeForAWhile();
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeProcessorSubtask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeProcessorSubtask.java
similarity index 68%
rename from server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeProcessorSubtask.java
rename to server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeProcessorSubtask.java
index cfdf0123e1..b13dbc4c18 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/runnable/PipeProcessorSubtask.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeProcessorSubtask.java
@@ -17,14 +17,21 @@
* under the License.
*/
-package org.apache.iotdb.db.pipe.task.runnable;
+package org.apache.iotdb.db.pipe.task.callable;
+
+import org.apache.iotdb.db.pipe.core.processor.PipeProcessorPluginRuntimeWrapper;
public class PipeProcessorSubtask extends PipeSubtask {
- public PipeProcessorSubtask(String taskID) {
+ private final PipeProcessorPluginRuntimeWrapper pipeProcessor;
+
+ public PipeProcessorSubtask(String taskID, PipeProcessorPluginRuntimeWrapper pipeProcessor) {
super(taskID);
+ this.pipeProcessor = pipeProcessor;
}
@Override
- public void runMayThrow() throws Throwable {}
+ protected void executeForAWhile() {
+ pipeProcessor.executeForAWhile();
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeSubtask.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeSubtask.java
new file mode 100644
index 0000000000..983d70b32b
--- /dev/null
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/callable/PipeSubtask.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.pipe.task.callable;
+
+import org.apache.iotdb.db.pipe.agent.runtime.PipeRuntimeAgent;
+
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public abstract class PipeSubtask implements FutureCallback<Void>, Callable<Void> {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(PipeSubtask.class);
+
+ private final String taskID;
+
+ private ListeningExecutorService subtaskWorkerThreadPoolExecutor;
+ private ExecutorService subtaskCallbackListeningExecutor;
+
+ private final DecoratingLock callbackDecoratingLock = new DecoratingLock();
+
+ private static final int MAX_RETRY_TIMES = 5;
+ private final AtomicInteger retryCount = new AtomicInteger(0);
+
+ private Throwable lastFailedCause;
+
+ private final AtomicBoolean shouldStopSubmittingSelf = new AtomicBoolean(true);
+
+ public PipeSubtask(String taskID) {
+ super();
+ this.taskID = taskID;
+ }
+
+ public void bindExecutors(
+ ListeningExecutorService subtaskWorkerThreadPoolExecutor,
+ ExecutorService subtaskCallbackListeningExecutor) {
+ this.subtaskWorkerThreadPoolExecutor = subtaskWorkerThreadPoolExecutor;
+ this.subtaskCallbackListeningExecutor = subtaskCallbackListeningExecutor;
+ }
+
+ @Override
+ public Void call() throws Exception {
+ executeForAWhile();
+
+ // wait for the callable to be decorated by Futures.addCallback in the executorService
+ // to make sure that the callback can be submitted again on success or failure.
+ callbackDecoratingLock.waitForDecorated();
+
+ return null;
+ }
+
+ protected abstract void executeForAWhile() throws Exception;
+
+ @Override
+ public void onSuccess(Void result) {
+ retryCount.set(0);
+ submitSelf();
+ }
+
+ @Override
+ public void onFailure(@NotNull Throwable throwable) {
+ if (retryCount.get() < MAX_RETRY_TIMES) {
+ retryCount.incrementAndGet();
+ submitSelf();
+ } else {
+ LOGGER.warn(
+ "Subtask {} failed, has been retried for {} times, last failed because of {}",
+ taskID,
+ retryCount,
+ throwable);
+ lastFailedCause = throwable;
+ PipeRuntimeAgent.setupAndGetInstance().report(this);
+ }
+ }
+
+ public void submitSelf() {
+ if (shouldStopSubmittingSelf.get()) {
+ return;
+ }
+
+ callbackDecoratingLock.markAsDecorating();
+ try {
+ final ListenableFuture<Void> nextFuture = subtaskWorkerThreadPoolExecutor.submit(this);
+ Futures.addCallback(nextFuture, this, subtaskCallbackListeningExecutor);
+ } finally {
+ callbackDecoratingLock.markAsDecorated();
+ }
+ }
+
+ public void allowSubmittingSelf() {
+ shouldStopSubmittingSelf.set(false);
+ }
+
+ public void disallowSubmittingSelf() {
+ shouldStopSubmittingSelf.set(true);
+ }
+
+ public boolean isSubmittingSelf() {
+ return !shouldStopSubmittingSelf.get();
+ }
+
+ public String getTaskID() {
+ return taskID;
+ }
+
+ public Throwable getLastFailedCause() {
+ return lastFailedCause;
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/metrics/PipeTaskRuntimeRecorder.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/metrics/PipeTaskRuntimeRecorder.java
deleted file mode 100644
index 6f09614958..0000000000
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/metrics/PipeTaskRuntimeRecorder.java
+++ /dev/null
@@ -1,22 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.pipe.task.metrics;
-
-public class PipeTaskRuntimeRecorder {}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskCollectorStage.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskCollectorStage.java
index 930a06cc94..1962555dae 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskCollectorStage.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskCollectorStage.java
@@ -19,19 +19,13 @@
package org.apache.iotdb.db.pipe.task.stage;
-import org.apache.iotdb.pipe.api.exception.PipeException;
+import org.apache.iotdb.db.pipe.execution.executor.PipeAssignerSubtaskExecutor;
+import org.apache.iotdb.db.pipe.task.callable.PipeAssignerSubtask;
-public class PipeTaskCollectorStage implements PipeTaskStage {
+public class PipeTaskCollectorStage extends PipeTaskStage {
- @Override
- public void create() throws PipeException {}
-
- @Override
- public void start() throws PipeException {}
-
- @Override
- public void stop() throws PipeException {}
-
- @Override
- public void drop() throws PipeException {}
+ protected PipeTaskCollectorStage(
+ PipeAssignerSubtaskExecutor executor, PipeAssignerSubtask subtask) {
+ super(executor, subtask);
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskConnectorStage.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskConnectorStage.java
index fddb99fb5e..5b7f48e16b 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskConnectorStage.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskConnectorStage.java
@@ -19,19 +19,13 @@
package org.apache.iotdb.db.pipe.task.stage;
-import org.apache.iotdb.pipe.api.exception.PipeException;
+import org.apache.iotdb.db.pipe.execution.executor.PipeConnectorSubtaskExecutor;
+import org.apache.iotdb.db.pipe.task.callable.PipeConnectorSubtask;
-public class PipeTaskConnectorStage implements PipeTaskStage {
+public class PipeTaskConnectorStage extends PipeTaskStage {
- @Override
- public void create() throws PipeException {}
-
- @Override
- public void start() throws PipeException {}
-
- @Override
- public void stop() throws PipeException {}
-
- @Override
- public void drop() throws PipeException {}
+ protected PipeTaskConnectorStage(
+ PipeConnectorSubtaskExecutor executor, PipeConnectorSubtask subtask) {
+ super(executor, subtask);
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskProcessorStage.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskProcessorStage.java
index 5a22721a8c..6f2e058193 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskProcessorStage.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskProcessorStage.java
@@ -19,19 +19,13 @@
package org.apache.iotdb.db.pipe.task.stage;
-import org.apache.iotdb.pipe.api.exception.PipeException;
+import org.apache.iotdb.db.pipe.execution.executor.PipeProcessorSubtaskExecutor;
+import org.apache.iotdb.db.pipe.task.callable.PipeProcessorSubtask;
-public class PipeTaskProcessorStage implements PipeTaskStage {
+public class PipeTaskProcessorStage extends PipeTaskStage {
- @Override
- public void create() throws PipeException {}
-
- @Override
- public void start() throws PipeException {}
-
- @Override
- public void stop() throws PipeException {}
-
- @Override
- public void drop() throws PipeException {}
+ protected PipeTaskProcessorStage(
+ PipeProcessorSubtaskExecutor executor, PipeProcessorSubtask subtask) {
+ super(executor, subtask);
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskStage.java b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskStage.java
index 09ae67ef76..3cbe0b569b 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskStage.java
+++ b/server/src/main/java/org/apache/iotdb/db/pipe/task/stage/PipeTaskStage.java
@@ -19,35 +19,62 @@
package org.apache.iotdb.db.pipe.task.stage;
+import org.apache.iotdb.db.pipe.execution.executor.PipeSubtaskExecutor;
+import org.apache.iotdb.db.pipe.task.callable.PipeSubtask;
import org.apache.iotdb.pipe.api.exception.PipeException;
-public interface PipeTaskStage {
+public abstract class PipeTaskStage {
+
+ protected final PipeSubtaskExecutor executor;
+ protected final PipeSubtask subtask;
+
+ protected PipeTaskStage(PipeSubtaskExecutor executor, PipeSubtask subtask) {
+ this.executor = executor;
+ this.subtask = subtask;
+ }
/**
* Create a pipe task stage.
*
* @throws PipeException if failed to create a pipe task stage.
*/
- void create() throws PipeException;
+ public final void create() throws PipeException {
+ executor.register(subtask);
+ }
/**
* Start a pipe task stage.
*
* @throws PipeException if failed to start a pipe task stage.
*/
- void start() throws PipeException;
+ public final void start() throws PipeException {
+ executor.start(subtask.getTaskID());
+ }
/**
* Stop a pipe task stage.
*
* @throws PipeException if failed to stop a pipe task stage.
*/
- void stop() throws PipeException;
+ public final void stop() throws PipeException {
+ executor.stop(subtask.getTaskID());
+ }
/**
* Drop a pipe task stage.
*
* @throws PipeException if failed to drop a pipe task stage.
*/
- void drop() throws PipeException;
+ public final void drop() throws PipeException {
+ executor.deregister(subtask.getTaskID());
+ }
+
+ /**
+ * Get the pipe subtask.
+ *
+ * @return the pipe subtask.
+ */
+ public final PipeSubtask getSubtask() {
+ return subtask;
+ }
}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutorTest.java
similarity index 63%
copy from server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java
copy to server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutorTest.java
index 7d97605dff..43bb0673f7 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutor.java
+++ b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutorTest.java
@@ -19,4 +19,22 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public interface PipeSubtaskExecutor {}
+import org.apache.iotdb.db.pipe.task.callable.PipeAssignerSubtask;
+
+import org.junit.Before;
+import org.mockito.Mockito;
+
+public class PipeAssignerSubtaskExecutorTest extends PipeSubtaskExecutorTest {
+
+ @Before
+ public void setUp() throws Exception {
+ executor = new PipeAssignerSubtaskExecutor();
+
+ subtask =
+ Mockito.spy(
+ new PipeAssignerSubtask("PipeAssignerSubtaskExecutorTest") {
+ @Override
+ public void executeForAWhile() {}
+ });
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutorTest.java
similarity index 55%
copy from server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
copy to server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutorTest.java
index cc3dd43987..694ee93b92 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
+++ b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeConnectorSubtaskExecutorTest.java
@@ -19,4 +19,26 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public class PipeAssignerSubtaskExecutor implements PipeSubtaskExecutor {}
+import org.apache.iotdb.db.pipe.core.connector.PipeConnectorPluginRuntimeWrapper;
+import org.apache.iotdb.db.pipe.task.callable.PipeConnectorSubtask;
+
+import org.junit.Before;
+import org.mockito.Mockito;
+
+import static org.mockito.Mockito.mock;
+
+public class PipeConnectorSubtaskExecutorTest extends PipeSubtaskExecutorTest {
+
+ @Before
+ public void setUp() throws Exception {
+ executor = new PipeConnectorSubtaskExecutor();
+
+ subtask =
+ Mockito.spy(
+ new PipeConnectorSubtask(
+ "PipeConnectorSubtaskExecutorTest", mock(PipeConnectorPluginRuntimeWrapper.class)) {
+ @Override
+ public void executeForAWhile() {}
+ });
+ }
+}
diff --git a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutorTest.java
similarity index 55%
copy from server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
copy to server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutorTest.java
index cc3dd43987..c4446ae5b4 100644
--- a/server/src/main/java/org/apache/iotdb/db/pipe/execution/executor/PipeAssignerSubtaskExecutor.java
+++ b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeProcessorSubtaskExecutorTest.java
@@ -19,4 +19,26 @@
package org.apache.iotdb.db.pipe.execution.executor;
-public class PipeAssignerSubtaskExecutor implements PipeSubtaskExecutor {}
+import org.apache.iotdb.db.pipe.core.processor.PipeProcessorPluginRuntimeWrapper;
+import org.apache.iotdb.db.pipe.task.callable.PipeProcessorSubtask;
+
+import org.junit.Before;
+import org.mockito.Mockito;
+
+import static org.mockito.Mockito.mock;
+
+public class PipeProcessorSubtaskExecutorTest extends PipeSubtaskExecutorTest {
+
+ @Before
+ public void setUp() throws Exception {
+ executor = new PipeProcessorSubtaskExecutor();
+
+ subtask =
+ Mockito.spy(
+ new PipeProcessorSubtask(
+ "PipeProcessorSubtaskExecutorTest", mock(PipeProcessorPluginRuntimeWrapper.class)) {
+ @Override
+ public void executeForAWhile() {}
+ });
+ }
+}
diff --git a/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutorTest.java b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutorTest.java
new file mode 100644
index 0000000000..8b4238d4c2
--- /dev/null
+++ b/server/src/test/java/org/apache/iotdb/db/pipe/execution/executor/PipeSubtaskExecutorTest.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.pipe.execution.executor;
+
+import org.apache.iotdb.db.pipe.task.callable.PipeSubtask;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+public abstract class PipeSubtaskExecutorTest {
+
+ protected PipeSubtaskExecutor executor;
+ protected PipeSubtask subtask;
+
+ @After
+ public void tearDown() throws Exception {
+ executor.shutdown();
+ Assert.assertTrue(executor.isShutdown());
+ }
+
+ @Test
+ public void testRegister() {
+ Assert.assertFalse(executor.isRegistered(subtask.getTaskID()));
+ Assert.assertEquals(0, executor.getRegisteredSubtaskNumber());
+
+ // test register a subtask which is not in the map
+ executor.register(subtask);
+ Assert.assertTrue(executor.isRegistered(subtask.getTaskID()));
+ Assert.assertEquals(1, executor.getRegisteredSubtaskNumber());
+
+ // test register a subtask which is in the map
+ executor.register(subtask);
+ Assert.assertTrue(executor.isRegistered(subtask.getTaskID()));
+ Assert.assertEquals(1, executor.getRegisteredSubtaskNumber());
+ }
+
+ @Test
+ public void testStart() throws Exception {
+ // test start a subtask which is not in the map
+ executor.start(subtask.getTaskID());
+ try {
+ Thread.sleep(20);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ verify(subtask, times(0)).call();
+
+ // test start a subtask which is in the map
+ executor.register(subtask);
+ executor.start(subtask.getTaskID());
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ verify(subtask, atLeast(10)).call();
+ Assert.assertTrue(subtask.isSubmittingSelf());
+
+ // test start a subtask which is in the map and is already running
+ executor.start(subtask.getTaskID());
+ Assert.assertTrue(subtask.isSubmittingSelf());
+ }
+
+ @Test
+ public void testStop() {
+ // test stop a subtask which is not in the map
+ executor.stop(subtask.getTaskID());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+
+ // test stop a subtask which is in the map
+ executor.register(subtask);
+ try {
+ Thread.sleep(20);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ executor.stop(subtask.getTaskID());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+
+ // test stop a running subtask
+ executor.start(subtask.getTaskID());
+ try {
+ Thread.sleep(20);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ executor.stop(subtask.getTaskID());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+
+ // test stop a stopped subtask
+ executor.stop(subtask.getTaskID());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+ }
+
+ @Test
+ public void testDeregister() {
+ // test unregister a subtask which is not in the map
+ executor.deregister(subtask.getTaskID());
+ Assert.assertEquals(0, executor.getRegisteredSubtaskNumber());
+
+ // test unregister a subtask which is in the map
+ executor.register(subtask);
+ Assert.assertEquals(1, executor.getRegisteredSubtaskNumber());
+
+ // test unregister a running subtask
+ executor.start(subtask.getTaskID());
+ try {
+ Thread.sleep(20);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ executor.deregister(subtask.getTaskID());
+ Assert.assertEquals(0, executor.getRegisteredSubtaskNumber());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+
+ // test unregister an unregistered subtask
+ executor.deregister(subtask.getTaskID());
+ Assert.assertEquals(0, executor.getRegisteredSubtaskNumber());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+ }
+
+ @Test
+ public void testShutdown() {
+ // test shutdown a running executor
+ executor.start(subtask.getTaskID());
+ executor.shutdown();
+
+ Assert.assertTrue(executor.isShutdown());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+
+ // test shutdown a stopped executor
+ executor.shutdown();
+ Assert.assertTrue(executor.isShutdown());
+ Assert.assertFalse(subtask.isSubmittingSelf());
+ }
+}