You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/09/24 13:23:17 UTC

[tvm] branch main updated: [PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d3d7e8e  [PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)
d3d7e8e is described below

commit d3d7e8eb6c201506dc706a055e16eed189dcdb0b
Author: wangxiang2713 <49...@users.noreply.github.com>
AuthorDate: Fri Sep 24 21:22:55 2021 +0800

    [PYTHON][FFI] Skip numpy.ascontiguousarray if C_CONTIGUOUS == True (#9073)
---
 python/tvm/runtime/ndarray.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 27811a9..2b9f7f9 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -165,9 +165,18 @@ class NDArray(NDArrayBase):
                     source_array.shape, shape
                 )
             )
-        source_array = np.ascontiguousarray(
-            source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+        numpy_str_map = DataType.NUMPY2STR
+        np_dtype_str = (
+            numpy_str_map[source_array.dtype]
+            if source_array.dtype in numpy_str_map
+            else str(source_array.dtype)
         )
+        if (not source_array.flags["C_CONTIGUOUS"]) or (
+            dtype == "bfloat16" or dtype != np_dtype_str
+        ):
+            source_array = np.ascontiguousarray(
+                source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+            )
         assert source_array.flags["C_CONTIGUOUS"]
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)