You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ha...@apache.org on 2022/04/08 04:26:16 UTC

[iotdb] 02/03: dev

This is an automated email from the ASF dual-hosted git repository.

haonan pushed a commit to branch numpy_type_check
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 508ab06954cd3a2ff14497b8111c7e54574fa95a
Author: HTHou <hh...@outlook.com>
AuthorDate: Fri Apr 8 12:03:44 2022 +0800

    dev
---
 client-py/iotdb/utils/IoTDBConstants.py | 11 +++++++++++
 client-py/iotdb/utils/NumpyTablet.py    |  9 ++++-----
 client-py/tests/test_numpy_tablet.py    | 32 ++++++++++++++++++++++----------
 3 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/client-py/iotdb/utils/IoTDBConstants.py b/client-py/iotdb/utils/IoTDBConstants.py
index 7b992e3d7a..ef66741fec 100644
--- a/client-py/iotdb/utils/IoTDBConstants.py
+++ b/client-py/iotdb/utils/IoTDBConstants.py
@@ -17,6 +17,7 @@
 #
 
 from enum import Enum, unique
+import numpy as np
 
 
 @unique
@@ -36,6 +37,16 @@ class TSDataType(Enum):
     def __hash__(self):
         return self.value
 
+    def np_dtype(self):
+        return {
+            TSDataType.BOOLEAN: np.dtype(">?"),
+            TSDataType.FLOAT: np.dtype(">f4"),
+            TSDataType.DOUBLE: np.dtype(">f8"),
+            TSDataType.INT32: np.dtype(">i4"),
+            TSDataType.INT64: np.dtype(">i8"),
+            TSDataType.TEXT: np.dtype("str"),
+        }[self]
+
 
 @unique
 class TSEncoding(Enum):
diff --git a/client-py/iotdb/utils/NumpyTablet.py b/client-py/iotdb/utils/NumpyTablet.py
index 21e651d2d6..b81a172a40 100644
--- a/client-py/iotdb/utils/NumpyTablet.py
+++ b/client-py/iotdb/utils/NumpyTablet.py
@@ -17,7 +17,6 @@
 #
 
 import struct
-import numpy as np
 from iotdb.utils.IoTDBConstants import TSDataType
 from iotdb.utils.BitMap import BitMap
 
@@ -54,11 +53,11 @@ class NumpyTablet(object):
             for i in range(len(values)):
                 values[i] = values[i][index]
 
-        if timestamps.dtype != np.dtype(">i8"):
-            timestamps = timestamps.astype(np.dtype(">i8"))
+        if timestamps.dtype != TSDataType.INT64.np_dtype():
+            timestamps = timestamps.astype(TSDataType.INT64.np_dtype())
         for i in range(len(values)):
-            
-
+            if values[i].dtype != data_types[i].np_dtype():
+                values[i] = values[i].astype(data_types[i].np_dtype())
 
         self.__values = values
         self.__timestamps = timestamps
diff --git a/client-py/tests/test_numpy_tablet.py b/client-py/tests/test_numpy_tablet.py
index ea7502ce96..b984193975 100644
--- a/client-py/tests/test_numpy_tablet.py
+++ b/client-py/tests/test_numpy_tablet.py
@@ -82,7 +82,7 @@ def test_sort_numpy_tablet():
         "root.sg_test_01.d_01", measurements_, data_types_, values_, timestamps_
     )
     np_values_unsorted = [
-        np.array([False, False, False, True, True], np.dtype('>?')),
+        np.array([False, False, False, True, True], np.dtype(">?")),
         np.array([0, 10, 100, 1000, 10000], np.dtype(">i4")),
         np.array([1, 11, 111, 1111, 11111], np.dtype(">i8")),
         np.array([1.1, 1.25, 188.1, 0, 8.999], np.dtype(">f4")),
@@ -91,12 +91,17 @@ def test_sort_numpy_tablet():
     ]
     np_timestamps_unsorted = np.array([9, 8, 7, 6, 5], np.dtype(">i8"))
     np_tablet_ = NumpyTablet(
-        "root.sg_test_01.d_01", measurements_, data_types_, np_values_unsorted, np_timestamps_unsorted
+        "root.sg_test_01.d_01",
+        measurements_,
+        data_types_,
+        np_values_unsorted,
+        np_timestamps_unsorted,
     )
     assert tablet_.get_binary_timestamps() == np_tablet_.get_binary_timestamps()
     assert tablet_.get_binary_values() == np_tablet_.get_binary_values()
 
-def test_numpy_tablet_correct_endian():
+
+def test_numpy_tablet_auto_correct_datatype():
 
     measurements_ = ["s_01", "s_02", "s_03", "s_04", "s_05", "s_06"]
     data_types_ = [
@@ -119,17 +124,24 @@ def test_numpy_tablet_correct_endian():
         "root.sg_test_01.d_01", measurements_, data_types_, values_, timestamps_
     )
     np_values_unsorted = [
-        np.array([False, False, False, True, True], np.dtype('>?')),
-        np.array([0, 10, 100, 1000, 10000], np.dtype(">i4")),
-        np.array([1, 11, 111, 1111, 11111], np.dtype(">i8")),
-        np.array([1.1, 1.25, 188.1, 0, 8.999], np.dtype(">f4")),
-        np.array([10011.1, 101.0, 688.25, 6.25, 776], np.dtype(">f8")),
+        np.array([False, False, False, True, True]),
+        np.array([0, 10, 100, 1000, 10000]),
+        np.array([1, 11, 111, 1111, 11111]),
+        np.array([1.1, 1.25, 188.1, 0, 8.999]),
+        np.array([10011.1, 101.0, 688.25, 6.25, 776]),
         np.array(["test09", "test08", "test07", "test06", "test05"]),
     ]
     np_timestamps_unsorted = np.array([9, 8, 7, 6, 5])
+    # numpy creates int and float arrays in native byte order by default (little-endian on most platforms), not the big-endian dtypes returned by np_dtype()
     assert np_timestamps_unsorted.dtype != np.dtype(">i8")
+    for i in range(1, 4):
+        assert np_values_unsorted[i].dtype != data_types_[i].np_dtype()
     np_tablet_ = NumpyTablet(
-        "root.sg_test_01.d_01", measurements_, data_types_, np_values_unsorted, np_timestamps_unsorted
+        "root.sg_test_01.d_01",
+        measurements_,
+        data_types_,
+        np_values_unsorted,
+        np_timestamps_unsorted,
     )
     assert tablet_.get_binary_timestamps() == np_tablet_.get_binary_timestamps()
-    assert tablet_.get_binary_values() == np_tablet_.get_binary_values()
\ No newline at end of file
+    assert tablet_.get_binary_values() == np_tablet_.get_binary_values()