You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ha...@apache.org on 2022/04/08 04:26:16 UTC
[iotdb] 02/03: dev
This is an automated email from the ASF dual-hosted git repository.
haonan pushed a commit to branch numpy_type_check
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit 508ab06954cd3a2ff14497b8111c7e54574fa95a
Author: HTHou <hh...@outlook.com>
AuthorDate: Fri Apr 8 12:03:44 2022 +0800
dev
---
client-py/iotdb/utils/IoTDBConstants.py | 11 +++++++++++
client-py/iotdb/utils/NumpyTablet.py | 9 ++++-----
client-py/tests/test_numpy_tablet.py | 32 ++++++++++++++++++++++----------
3 files changed, 37 insertions(+), 15 deletions(-)
diff --git a/client-py/iotdb/utils/IoTDBConstants.py b/client-py/iotdb/utils/IoTDBConstants.py
index 7b992e3d7a..ef66741fec 100644
--- a/client-py/iotdb/utils/IoTDBConstants.py
+++ b/client-py/iotdb/utils/IoTDBConstants.py
@@ -17,6 +17,7 @@
#
from enum import Enum, unique
+import numpy as np
@unique
@@ -36,6 +37,16 @@ class TSDataType(Enum):
def __hash__(self):
return self.value
+ def np_dtype(self):
+ return {
+ TSDataType.BOOLEAN: np.dtype(">?"),
+ TSDataType.FLOAT: np.dtype(">f4"),
+ TSDataType.DOUBLE: np.dtype(">f8"),
+ TSDataType.INT32: np.dtype(">i4"),
+ TSDataType.INT64: np.dtype(">i8"),
+ TSDataType.TEXT: np.dtype("str"),
+ }[self]
+
@unique
class TSEncoding(Enum):
diff --git a/client-py/iotdb/utils/NumpyTablet.py b/client-py/iotdb/utils/NumpyTablet.py
index 21e651d2d6..b81a172a40 100644
--- a/client-py/iotdb/utils/NumpyTablet.py
+++ b/client-py/iotdb/utils/NumpyTablet.py
@@ -17,7 +17,6 @@
#
import struct
-import numpy as np
from iotdb.utils.IoTDBConstants import TSDataType
from iotdb.utils.BitMap import BitMap
@@ -54,11 +53,11 @@ class NumpyTablet(object):
for i in range(len(values)):
values[i] = values[i][index]
- if timestamps.dtype != np.dtype(">i8"):
- timestamps = timestamps.astype(np.dtype(">i8"))
+ if timestamps.dtype != TSDataType.INT64.np_dtype():
+ timestamps = timestamps.astype(TSDataType.INT64.np_dtype())
for i in range(len(values)):
-
-
+ if values[i].dtype != data_types[i].np_dtype():
+ values[i] = values[i].astype(data_types[i].np_dtype())
self.__values = values
self.__timestamps = timestamps
diff --git a/client-py/tests/test_numpy_tablet.py b/client-py/tests/test_numpy_tablet.py
index ea7502ce96..b984193975 100644
--- a/client-py/tests/test_numpy_tablet.py
+++ b/client-py/tests/test_numpy_tablet.py
@@ -82,7 +82,7 @@ def test_sort_numpy_tablet():
"root.sg_test_01.d_01", measurements_, data_types_, values_, timestamps_
)
np_values_unsorted = [
- np.array([False, False, False, True, True], np.dtype('>?')),
+ np.array([False, False, False, True, True], np.dtype(">?")),
np.array([0, 10, 100, 1000, 10000], np.dtype(">i4")),
np.array([1, 11, 111, 1111, 11111], np.dtype(">i8")),
np.array([1.1, 1.25, 188.1, 0, 8.999], np.dtype(">f4")),
@@ -91,12 +91,17 @@ def test_sort_numpy_tablet():
]
np_timestamps_unsorted = np.array([9, 8, 7, 6, 5], np.dtype(">i8"))
np_tablet_ = NumpyTablet(
- "root.sg_test_01.d_01", measurements_, data_types_, np_values_unsorted, np_timestamps_unsorted
+ "root.sg_test_01.d_01",
+ measurements_,
+ data_types_,
+ np_values_unsorted,
+ np_timestamps_unsorted,
)
assert tablet_.get_binary_timestamps() == np_tablet_.get_binary_timestamps()
assert tablet_.get_binary_values() == np_tablet_.get_binary_values()
-def test_numpy_tablet_correct_endian():
+
+def test_numpy_tablet_auto_correct_datatype():
measurements_ = ["s_01", "s_02", "s_03", "s_04", "s_05", "s_06"]
data_types_ = [
@@ -119,17 +124,24 @@ def test_numpy_tablet_correct_endian():
"root.sg_test_01.d_01", measurements_, data_types_, values_, timestamps_
)
np_values_unsorted = [
- np.array([False, False, False, True, True], np.dtype('>?')),
- np.array([0, 10, 100, 1000, 10000], np.dtype(">i4")),
- np.array([1, 11, 111, 1111, 11111], np.dtype(">i8")),
- np.array([1.1, 1.25, 188.1, 0, 8.999], np.dtype(">f4")),
- np.array([10011.1, 101.0, 688.25, 6.25, 776], np.dtype(">f8")),
+ np.array([False, False, False, True, True]),
+ np.array([0, 10, 100, 1000, 10000]),
+ np.array([1, 11, 111, 1111, 11111]),
+ np.array([1.1, 1.25, 188.1, 0, 8.999]),
+ np.array([10011.1, 101.0, 688.25, 6.25, 776]),
np.array(["test09", "test08", "test07", "test06", "test05"]),
]
np_timestamps_unsorted = np.array([9, 8, 7, 6, 5])
+ # numpy creates int and float arrays in native byte order by default (little-endian on most platforms)
assert np_timestamps_unsorted.dtype != np.dtype(">i8")
+ for i in range(1, 4):
+ assert np_values_unsorted[i].dtype != data_types_[i].np_dtype()
np_tablet_ = NumpyTablet(
- "root.sg_test_01.d_01", measurements_, data_types_, np_values_unsorted, np_timestamps_unsorted
+ "root.sg_test_01.d_01",
+ measurements_,
+ data_types_,
+ np_values_unsorted,
+ np_timestamps_unsorted,
)
assert tablet_.get_binary_timestamps() == np_tablet_.get_binary_timestamps()
- assert tablet_.get_binary_values() == np_tablet_.get_binary_values()
\ No newline at end of file
+ assert tablet_.get_binary_values() == np_tablet_.get_binary_values()