You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by hu...@apache.org on 2023/03/30 00:19:12 UTC

[iotdb] branch master updated: [IOTDB-5680] Implement the basic data loader on MLNode (#9372)

This is an automated email from the ASF dual-hosted git repository.

hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new f6f4728cfd [IOTDB-5680] Implement the basic data loader on MLNode (#9372)
f6f4728cfd is described below

commit f6f4728cfd33948817081082622b6e128bcea977
Author: Yong Liu <li...@gmail.com>
AuthorDate: Thu Mar 30 08:19:05 2023 +0800

    [IOTDB-5680] Implement the basic data loader on MLNode (#9372)
---
 mlnode/iotdb/mlnode/datats/offline/data_source.py |  98 ++++++++++++
 mlnode/iotdb/mlnode/datats/offline/dataset.py     | 116 +++++++++++++++
 mlnode/iotdb/mlnode/datats/utils/__init__.py      |  17 +++
 mlnode/iotdb/mlnode/datats/utils/timefeatures.py  | 173 ++++++++++++++++++++++
 4 files changed, 404 insertions(+)

diff --git a/mlnode/iotdb/mlnode/datats/offline/data_source.py b/mlnode/iotdb/mlnode/datats/offline/data_source.py
new file mode 100644
index 0000000000..cd8e9a891c
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/offline/data_source.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import pandas as pd
+
+from iotdb.mlnode import serde
+from iotdb.mlnode.client import client_manager
+
+
+class DataSource(object):
+    """
+    Pre-fetched in multi-variate time series in memory
+
+    Methods:
+        get_data: returns self.data, the time series value (Numpy.2DArray)
+        get_timestamp: returns self.timestamp, the aligned timestamp value
+    """
+
+    def __init__(self):
+        self.data = None
+        self.timestamp = None
+
+    def _read_data(self):
+        raise NotImplementedError
+
+    def get_data(self):
+        return self.data
+
+    def get_timestamp(self):
+        return self.timestamp
+
+
+class FileDataSource(DataSource):
+    def __init__(self, filename: str = None):
+        super(FileDataSource, self).__init__()
+        self.filename = filename
+        self._read_data()
+
+    def _read_data(self):
+        try:
+            raw_data = pd.read_csv(self.filename)
+        except Exception:
+            raise RuntimeError(f'Fail to load data with filename: {self.filename}')
+        cols_data = raw_data.columns[1:]
+        self.data = raw_data[cols_data].values
+        self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values)
+
+
+class ThriftDataSource(DataSource):
+    def __init__(self, query_expressions: list = None, query_filter: str = None):
+        super(DataSource, self).__init__()
+        self.query_expressions = query_expressions
+        self.query_filter = query_filter
+        self._read_data()
+
+    def _read_data(self):
+        try:
+            data_client = client_manager.borrow_data_node_client()
+        except Exception:  # is this exception catch needed???
+            raise RuntimeError('Fail to establish connection with DataNode')
+
+        try:
+            res = data_client.fetch_timeseries(
+                queryExpressions=self.query_expressions,
+                queryFilter=self.query_filter,
+            )
+        except Exception:
+            raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
+                               f' and query filter: {self.query_filter}')
+
+        if len(res.tsDataset) == 0:
+            raise RuntimeError(f'No data fetched with query filter: {self.query_filter}')
+
+        raw_data = serde.convert_to_df(res.columnNameList,
+                                       res.columnTypeList,
+                                       res.columnNameIndexMap,
+                                       res.tsDataset)
+        if raw_data.empty:
+            raise RuntimeError(f'Fetched empty data with query expressions: '
+                               f'{self.query_expressions} and query filter: {self.query_filter}')
+        cols_data = raw_data.columns[1:]
+        self.data = raw_data[cols_data].values
+        self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values, unit='ms', utc=True) \
+            .tz_convert('Asia/Shanghai')  # for iotdb
diff --git a/mlnode/iotdb/mlnode/datats/offline/dataset.py b/mlnode/iotdb/mlnode/datats/offline/dataset.py
new file mode 100644
index 0000000000..c71aaf87c5
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/offline/dataset.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import argparse
+
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.datats.offline.data_source import DataSource
+from iotdb.mlnode.datats.utils.timefeatures import time_features
+
+# currently supports multivariate forecasting only
+
+
+class TimeSeriesDataset(Dataset):
+    """
+    Build Row-by-Row dataset (with each element as multivariable time series at
+    the same time and correponding timestamp embedding)
+
+    Args:
+        data_source: the whole multivariate time series for a while
+        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
+
+    Returns:
+        Random accessible dataset
+    """
+
+    def __init__(self, data_source: DataSource, time_embed: str = 'h'):
+        self.time_embed = time_embed
+        self.data = data_source.get_data()
+        self.data_stamp = time_features(data_source.get_timestamp(), time_embed=self.time_embed).transpose(1, 0)
+        self.n_vars = self.data.shape[-1]
+
+    def get_variable_num(self):
+        return self.n_vars  # number of series in data_source
+
+    def __getitem__(self, index):
+        seq = self.data[index]
+        seq_t = self.data_stamp[index]
+        return seq, seq_t
+
+    def __len__(self):
+        return len(self.data)
+
+
+class WindowDataset(TimeSeriesDataset):
+    """
+    Build Windowed dataset (with each element as multivariable time series
+    with a sliding window and corresponding timestamps embedding),
+    the sliding step is one unit in give data source
+
+    Args:
+        data_source: the whole multivariate time series for a while
+        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
+        input_len: input window size (unit) [1, 2, ... I]
+        pred_len: output window size (unit) right after the input window [I+1, I+2, ... I+P]
+
+    Returns:
+        Random accessible dataset
+    """
+
+    def __init__(self,
+                 data_source: DataSource = None,
+                 input_len: int = 96,
+                 pred_len: int = 96,
+                 time_embed: str = 'h'):
+        self.input_len = input_len
+        self.pred_len = pred_len
+        if input_len <= self.data.shape[0]:
+            raise RuntimeError('input_len should not be larger than the number of time series points')
+        if pred_len <= self.data.shape[0]:
+            raise RuntimeError('pred_len should not be larger than the number of time series points')
+        super(WindowDataset, self).__init__(data_source, time_embed)
+
+    def __getitem__(self, index):
+        s_begin = index
+        s_end = s_begin + self.input_len
+        r_begin = s_end
+        r_end = s_end + self.pred_len
+        seq_x = self.data[s_begin:s_end]
+        seq_y = self.data[r_begin:r_end]
+        seq_x_t = self.data_stamp[s_begin:s_end]
+        seq_y_t = self.data_stamp[r_begin:r_end]
+        return seq_x, seq_y, seq_x_t, seq_y_t
+
+    def __len__(self):
+        return len(self.data) - self.input_len - self.pred_len + 1
+
+
+def get_timeseries_dataset(data_config: argparse.Namespace) -> TimeSeriesDataset:
+    """
+    Placeholder for building a TimeSeriesDataset from a parsed configuration.
+
+    Not implemented yet: the body is empty, so the call currently returns None
+    despite the annotated return type.
+    """
+    # TODO (@lcy)
+    # init datasource
+    # init dataset
+    pass
+
+
+def get_window_dataset(data_config: argparse.Namespace) -> WindowDataset:
+    """
+    Placeholder for building a WindowDataset from a parsed configuration.
+
+    Not implemented yet: the body is empty, so the call currently returns None
+    despite the annotated return type.
+    """
+    # TODO (@lcy)
+    # init datasource
+    # init dataset
+    pass
diff --git a/mlnode/iotdb/mlnode/datats/utils/__init__.py b/mlnode/iotdb/mlnode/datats/utils/__init__.py
new file mode 100644
index 0000000000..2a1e720805
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/utils/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
new file mode 100644
index 0000000000..bd1681cfbf
--- /dev/null
+++ b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+from pandas.tseries import offsets
+from pandas.tseries.frequencies import to_offset
+
+
+class TimeFeature:
+    def __init__(self):
+        pass
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        pass
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+class SecondOfMinute(TimeFeature):
+    """Second of minute encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.second / 59.0 - 0.5
+
+
+class MinuteOfHour(TimeFeature):
+    """Minute of hour encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.minute / 59.0 - 0.5
+
+
+class HourOfDay(TimeFeature):
+    """Hour of day encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.hour / 23.0 - 0.5
+
+
+class DayOfWeek(TimeFeature):
+    """Day of week encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return index.dayofweek / 6.0 - 0.5
+
+
+class DayOfMonth(TimeFeature):
+    """Day of month encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.day - 1) / 30.0 - 0.5
+
+
+class DayOfYear(TimeFeature):
+    """Day of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.dayofyear - 1) / 365.0 - 0.5
+
+
+class MonthOfYear(TimeFeature):
+    """Month of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.month - 1) / 11.0 - 0.5
+
+
+class WeekOfYear(TimeFeature):
+    """Week of year encoded as value between [-0.5, 0.5]"""
+
+    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
+        return (index.isocalendar().week - 1) / 52.0 - 0.5
+
+
+def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
+    """
+    Embedding timestamp by given frequency string
+    Args:
+        freq_str: frequency string of the form [multiple][granularity] such as '12H', '5min', '1D' etc.
+    Returns:
+        a list of time features that will be appropriate for the given frequency string.
+    """
+
+    features_by_offsets = {
+        offsets.YearEnd: [],
+        offsets.QuarterEnd: [MonthOfYear],
+        offsets.MonthEnd: [MonthOfYear],
+        offsets.Week: [DayOfMonth, WeekOfYear],
+        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
+        offsets.Minute: [
+            MinuteOfHour,
+            HourOfDay,
+            DayOfWeek,
+            DayOfMonth,
+            DayOfYear,
+        ],
+        offsets.Second: [
+            SecondOfMinute,
+            MinuteOfHour,
+            HourOfDay,
+            DayOfWeek,
+            DayOfMonth,
+            DayOfYear,
+        ],
+    }
+
+    try:
+        offset = to_offset(freq_str)
+
+        for offset_type, feature_classes in features_by_offsets.items():
+            if isinstance(offset, offset_type):
+                return [cls() for cls in feature_classes]
+    except ValueError:
+        supported_freq_msg = f'''
+        Unsupported time embedding frequency ({freq_str})
+        The following frequencies are supported (case-insensitive):
+            Y   - yearly
+                alias: A
+            M   - monthly
+            W   - weekly
+            D   - daily
+            B   - business days
+            H   - hourly
+            T   - minutely
+                alias: min
+            S   - secondly
+        '''
+        raise RuntimeError(supported_freq_msg)
+
+
+def time_features(dates, time_embed='h'):
+    return np.vstack([feat(dates) for feat in time_features_from_frequency_str(time_embed)])
+
+
+def data_transform(data_raw: pd.DataFrame, freq='h'):
+    """
+    data: dataframe, column 0 is the time stamp
+    """
+    columns = data_raw.columns
+    data = data_raw[columns[1:]]
+    data_stamp = data_raw[columns[0]]
+    return data.values, data_stamp
+
+
+def timestamp_transform(timestamp_raw: pd.DataFrame, freq='h'):
+    """
+    """
+    timestamp = pd.to_datetime(timestamp_raw.values.squeeze(), unit='ms', utc=True).tz_convert('Asia/Shanghai')
+    timestamp = time_features(timestamp, freq=freq)
+    timestamp = timestamp.transpose(1, 0)
+    return timestamp