You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/30 00:19:12 UTC
[iotdb] branch master updated: [IOTDB-5680] Implement the basic data loader on MLNode (#9372)
This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new f6f4728cfd [IOTDB-5680] Implement the basic data loader on MLNode (#9372)
f6f4728cfd is described below
commit f6f4728cfd33948817081082622b6e128bcea977
Author: Yong Liu <li...@gmail.com>
AuthorDate: Thu Mar 30 08:19:05 2023 +0800
[IOTDB-5680] Implement the basic data loader on MLNode (#9372)
---
mlnode/iotdb/mlnode/datats/offline/data_source.py | 98 ++++++++++++
mlnode/iotdb/mlnode/datats/offline/dataset.py | 116 +++++++++++++++
mlnode/iotdb/mlnode/datats/utils/__init__.py | 17 +++
mlnode/iotdb/mlnode/datats/utils/timefeatures.py | 173 ++++++++++++++++++++++
4 files changed, 404 insertions(+)
diff --git a/mlnode/iotdb/mlnode/datats/offline/data_source.py b/mlnode/iotdb/mlnode/datats/offline/data_source.py
new file mode 100644
index 0000000000..cd8e9a891c
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/offline/data_source.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pandas as pd
+
+from iotdb.mlnode import serde
+from iotdb.mlnode.client import client_manager
+
+
class DataSource(object):
    """
    Multivariate time series that has been pre-fetched into memory.

    Subclasses load the series by overriding ``_read_data`` and filling in
    ``self.data`` (series values) and ``self.timestamp`` (aligned timestamps).

    Methods:
        get_data: return ``self.data``, the time series values
        get_timestamp: return ``self.timestamp``, the timestamps aligned with the values
    """

    def __init__(self):
        # Nothing is loaded yet; subclasses populate both attributes.
        self.timestamp = None
        self.data = None

    def _read_data(self):
        # Loading hook; the base class has no backing store to read from.
        raise NotImplementedError

    def get_data(self):
        """Return the stored series values (None until a subclass loads them)."""
        return self.data

    def get_timestamp(self):
        """Return the stored timestamps (None until a subclass loads them)."""
        return self.timestamp
+
+
class FileDataSource(DataSource):
    """
    Data source backed by a CSV file.

    The first CSV column is parsed with ``pd.to_datetime`` as the timestamps;
    every remaining column is treated as one variable of the series.

    Args:
        filename: path of the CSV file to load

    Raises:
        RuntimeError: if ``pandas.read_csv`` fails on ``filename``
    """

    def __init__(self, filename: str = None):
        super(FileDataSource, self).__init__()
        self.filename = filename
        self._read_data()

    def _read_data(self):
        try:
            raw_data = pd.read_csv(self.filename)
        except Exception as err:
            # Chain the original pandas/IO error so the real cause is not lost.
            raise RuntimeError(f'Fail to load data with filename: {self.filename}') from err
        cols_data = raw_data.columns[1:]
        self.data = raw_data[cols_data].values
        self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values)
+
+
class ThriftDataSource(DataSource):
    """
    Data source that fetches time series from a DataNode via a borrowed client.

    The fetched result is converted to a DataFrame whose first column is read
    as millisecond epoch timestamps (converted to the 'Asia/Shanghai' time
    zone); the remaining columns are the series values.

    Args:
        query_expressions: passed to ``fetch_timeseries`` as ``queryExpressions``
        query_filter: passed to ``fetch_timeseries`` as ``queryFilter``

    Raises:
        RuntimeError: if a client cannot be borrowed, the fetch fails, or the
            fetched data is empty
    """

    def __init__(self, query_expressions: list = None, query_filter: str = None):
        # Run DataSource.__init__ so data/timestamp start out as None.
        super(ThriftDataSource, self).__init__()
        self.query_expressions = query_expressions
        self.query_filter = query_filter
        self._read_data()

    def _read_data(self):
        try:
            data_client = client_manager.borrow_data_node_client()
        except Exception as err:
            raise RuntimeError('Fail to establish connection with DataNode') from err

        try:
            res = data_client.fetch_timeseries(
                queryExpressions=self.query_expressions,
                queryFilter=self.query_filter,
            )
        except Exception as err:
            raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
                               f' and query filter: {self.query_filter}') from err

        if len(res.tsDataset) == 0:
            raise RuntimeError(f'No data fetched with query filter: {self.query_filter}')

        raw_data = serde.convert_to_df(res.columnNameList,
                                       res.columnTypeList,
                                       res.columnNameIndexMap,
                                       res.tsDataset)
        if raw_data.empty:
            raise RuntimeError(f'Fetched empty data with query expressions: '
                               f'{self.query_expressions} and query filter: {self.query_filter}')
        cols_data = raw_data.columns[1:]
        self.data = raw_data[cols_data].values
        timestamps = pd.to_datetime(raw_data[raw_data.columns[0]].values, unit='ms', utc=True)
        self.timestamp = timestamps.tz_convert('Asia/Shanghai')  # for iotdb
diff --git a/mlnode/iotdb/mlnode/datats/offline/dataset.py b/mlnode/iotdb/mlnode/datats/offline/dataset.py
new file mode 100644
index 0000000000..c71aaf87c5
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/offline/dataset.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import argparse
+
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.datats.offline.data_source import DataSource
+from iotdb.mlnode.datats.utils.timefeatures import time_features
+
+# Currently supports multivariate forecasting only.
+
+
class TimeSeriesDataset(Dataset):
    """
    Row-by-row dataset: each item is the multivariate observation at a single
    time step paired with the embedding of its timestamp.

    Args:
        data_source: source holding the whole multivariate series
        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail

    Returns:
        Random accessible dataset
    """

    def __init__(self, data_source: DataSource, time_embed: str = 'h'):
        self.time_embed = time_embed
        self.data = data_source.get_data()
        stamp_features = time_features(data_source.get_timestamp(), time_embed=self.time_embed)
        # time_features stacks one row per feature; flip to one row per time step.
        self.data_stamp = stamp_features.transpose(1, 0)
        # The last axis of the data indexes the variables.
        self.n_vars = self.data.shape[-1]

    def get_variable_num(self):
        """Return the number of series (variables) in the data source."""
        return self.n_vars

    def __getitem__(self, index):
        return self.data[index], self.data_stamp[index]

    def __len__(self):
        return len(self.data)
+
+
class WindowDataset(TimeSeriesDataset):
    """
    Build Windowed dataset (with each element as multivariable time series
    with a sliding window and corresponding timestamps embedding);
    the sliding step is one unit in the given data source.

    Args:
        data_source: the whole multivariate time series for a while
        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
        input_len: input window size (unit) [1, 2, ... I]
        pred_len: output window size (unit) right after the input window [I+1, I+2, ... I+P]

    Returns:
        Random accessible dataset

    Raises:
        RuntimeError: if input_len, pred_len, or their sum exceeds the number
            of time series points (no complete window could be formed)
    """

    def __init__(self,
                 data_source: DataSource = None,
                 input_len: int = 96,
                 pred_len: int = 96,
                 time_embed: str = 'h'):
        self.input_len = input_len
        self.pred_len = pred_len
        # Load the data first: the length checks below need self.data.
        super(WindowDataset, self).__init__(data_source, time_embed)
        n_points = self.data.shape[0]
        if input_len > n_points:
            raise RuntimeError('input_len should not be larger than the number of time series points')
        if pred_len > n_points:
            raise RuntimeError('pred_len should not be larger than the number of time series points')
        # Each sample spans input_len + pred_len points; otherwise __len__ would go negative.
        if input_len + pred_len > n_points:
            raise RuntimeError('input_len + pred_len should not be larger than the number of time series points')

    def __getitem__(self, index):
        """Return (input window, target window, input stamps, target stamps) starting at index."""
        s_begin = index
        s_end = s_begin + self.input_len
        r_begin = s_end
        r_end = s_end + self.pred_len
        seq_x = self.data[s_begin:s_end]
        seq_y = self.data[r_begin:r_end]
        seq_x_t = self.data_stamp[s_begin:s_end]
        seq_y_t = self.data_stamp[r_begin:r_end]
        return seq_x, seq_y, seq_x_t, seq_y_t

    def __len__(self):
        # Number of start positions for which both windows fit in the data.
        return len(self.data) - self.input_len - self.pred_len + 1
+
+
def get_timeseries_dataset(data_config: argparse.Namespace) -> TimeSeriesDataset:
    """
    Build a TimeSeriesDataset from ``data_config``.

    Not implemented yet: always returns None.
    """
    # TODO (@lcy): init datasource, then init dataset
    return None
+
+
def get_window_dataset(data_config: argparse.Namespace) -> WindowDataset:
    """
    Build a WindowDataset from ``data_config``.

    Not implemented yet: always returns None.
    """
    # TODO (@lcy): init datasource, then init dataset
    return None
diff --git a/mlnode/iotdb/mlnode/datats/utils/__init__.py b/mlnode/iotdb/mlnode/datats/utils/__init__.py
new file mode 100644
index 0000000000..2a1e720805
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/utils/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
new file mode 100644
index 0000000000..bd1681cfbf
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+from pandas.tseries import offsets
+from pandas.tseries.frequencies import to_offset
+
+
class TimeFeature:
    """Base class for timestamp features; subclasses override ``__call__``."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # The base class computes no feature.
        return None

    def __repr__(self):
        return f'{self.__class__.__name__}()'
+
+
class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # second is 0..59: scale onto [0, 1], then shift to [-0.5, 0.5].
        return index.second / 59.0 - 0.5
+
+
class MinuteOfHour(TimeFeature):
    """Encode the minute of the hour as a value in [-0.5, 0.5]."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        minute = index.minute
        return minute / 59.0 - 0.5
+
+
class HourOfDay(TimeFeature):
    """Encode the hour of the day as a value in [-0.5, 0.5]."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        hour = index.hour
        return hour / 23.0 - 0.5
+
+
class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # dayofweek is 0 (Monday) .. 6 (Sunday); map linearly onto [-0.5, 0.5].
        return index.dayofweek / 6.0 - 0.5
+
+
class DayOfMonth(TimeFeature):
    """Encode the day of the month as a value in [-0.5, 0.5]."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Shift to a zero-based day (0..30) before scaling.
        zero_based_day = index.day - 1
        return zero_based_day / 30.0 - 0.5
+
+
class DayOfYear(TimeFeature):
    """Encode the day of the year as a value in [-0.5, 0.5]."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Zero-based day of year (0..365) scaled by 365.
        zero_based_day = index.dayofyear - 1
        return zero_based_day / 365.0 - 0.5
+
+
class MonthOfYear(TimeFeature):
    """Encode the month of the year as a value in [-0.5, 0.5]."""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # Zero-based month (0..11) scaled by 11.
        zero_based_month = index.month - 1
        return zero_based_month / 11.0 - 0.5
+
+
class WeekOfYear(TimeFeature):
    """ISO week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # isocalendar().week is a nullable UInt32 Series indexed by the
        # timestamps; convert it to a plain float ndarray (NaT -> NaN) so it
        # stacks with the other features like a numeric row.
        week = index.isocalendar().week.to_numpy(dtype=np.float64, na_value=np.nan)
        # ISO weeks run 1..53, so the result lies in [-0.5, 0.5].
        return (week - 1) / 52.0 - 0.5
+
+
def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Embedding timestamp by given frequency string

    Args:
        freq_str: frequency string of the form [multiple][granularity] such as '12H', '5min', '1D' etc.

    Returns:
        a list of time features that will be appropriate for the given frequency string.

    Raises:
        RuntimeError: if ``freq_str`` is not a valid pandas frequency, or it
            parses to an offset type with no features defined here
            (e.g. milliseconds).
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    supported_freq_msg = f'''
    Unsupported time embedding frequency ({freq_str})
    The following frequencies are supported (case-insensitive):
    Y - yearly
        alias: A
    M - monthly
    W - weekly
    D - daily
    B - business days
    H - hourly
    T - minutely
        alias: min
    S - secondly
    '''

    try:
        offset = to_offset(freq_str)
    except ValueError as err:
        raise RuntimeError(supported_freq_msg) from err

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    # A valid pandas frequency without a mapping must not fall through and
    # return None (callers iterate over the result).
    raise RuntimeError(supported_freq_msg)
+
+
def time_features(dates, time_embed='h'):
    """
    Compute the time-feature matrix for ``dates``.

    Args:
        dates: timestamps to encode (presumably a pandas DatetimeIndex — confirm at call sites)
        time_embed: frequency string selecting the features, see
            ``time_features_from_frequency_str``

    Returns:
        array stacked with one row per selected feature
    """
    feature_rows = [feature(dates) for feature in time_features_from_frequency_str(time_embed)]
    return np.vstack(feature_rows)
+
+
def data_transform(data_raw: pd.DataFrame, freq='h'):
    """
    Split a raw frame into series values and timestamps.

    Args:
        data_raw: dataframe whose column 0 holds the timestamps
        freq: not used by this function

    Returns:
        (values of every column after the first, the first column as-is)
    """
    header = data_raw.columns
    timestamp_label, value_labels = header[0], header[1:]
    return data_raw[value_labels].values, data_raw[timestamp_label]
+
+
def timestamp_transform(timestamp_raw: pd.DataFrame, freq='h'):
    """
    Convert raw epoch-millisecond timestamps into a time-feature matrix.

    Args:
        timestamp_raw: frame holding the timestamps in milliseconds
            (assumed to be a single column — TODO confirm with callers)
        freq: time-embedding frequency forwarded to ``time_features``

    Returns:
        array with one row per timestamp and one column per time feature
    """
    # squeeze() collapses a one-column frame to 1-D, but a single row would
    # become 0-d (a scalar Timestamp below); atleast_1d keeps it array-like.
    values = np.atleast_1d(timestamp_raw.values.squeeze())
    timestamp = pd.to_datetime(values, unit='ms', utc=True).tz_convert('Asia/Shanghai')
    # time_features' keyword is `time_embed`; passing `freq=` raises TypeError.
    features = time_features(timestamp, time_embed=freq)
    return features.transpose(1, 0)