You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/09/24 13:23:17 UTC
[tvm] branch main updated: [PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d3d7e8e [PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)
d3d7e8e is described below
commit d3d7e8eb6c201506dc706a055e16eed189dcdb0b
Author: wangxiang2713 <49...@users.noreply.github.com>
AuthorDate: Fri Sep 24 21:22:55 2021 +0800
[PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)
---
python/tvm/runtime/ndarray.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 27811a9..2b9f7f9 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -165,9 +165,18 @@ class NDArray(NDArrayBase):
source_array.shape, shape
)
)
- source_array = np.ascontiguousarray(
- source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+ numpy_str_map = DataType.NUMPY2STR
+ np_dtype_str = (
+ numpy_str_map[source_array.dtype]
+ if source_array.dtype in numpy_str_map
+ else str(source_array.dtype)
)
+ if (not source_array.flags["C_CONTIGUOUS"]) or (
+ dtype == "bfloat16" or dtype != np_dtype_str
+ ):
+ source_array = np.ascontiguousarray(
+ source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+ )
assert source_array.flags["C_CONTIGUOUS"]
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)