You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/03/13 17:22:00 UTC

[incubator-mxnet] branch master updated: [TENSOR] Fix DLTensor conversion for int64 (#10083)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a9c1717  [TENSOR] Fix DLTensor conversion for int64 (#10083)
a9c1717 is described below

commit a9c1717f6673a2194a1f82ba7d74fb276f0ef24e
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Tue Mar 13 10:21:53 2018 -0700

    [TENSOR] Fix DLTensor conversion for int64 (#10083)
    
    * [TENSOR] Fix DLTensor conversion for int64
    
    * trigger build
---
 include/mxnet/tensor_blob.h         | 23 +++++++++++++----------
 tests/python/gpu/test_tvm_bridge.py | 21 +++++++++++----------
 2 files changed, 24 insertions(+), 20 deletions(-)

diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 59c1eac..6f604a5 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -322,16 +322,19 @@ class TBlob {
 
  private:
   static DLDataType DTypeTransform(int type_flag) {
-    static std::unordered_map<int, DLDataType>
-      MSHADOW_DTYPE_TO_DLPACK_DTYPE = {
-        {0, {2, 32, 1}},  // Float32
-        {1, {2, 64, 1}},  // Float64
-        {2, {2, 16, 1}},  // Float16
-        {3, {1,  8, 1}},  // UInt8
-        {4, {0, 32, 1}},  // Int32
-        {5, {0,  8, 1}}   // Int8
-      };
-    return MSHADOW_DTYPE_TO_DLPACK_DTYPE[type_flag];
+    switch (type_flag) {
+      case mshadow::kFloat32: return DLDataType{kDLFloat, 32, 1};
+      case mshadow::kFloat64: return DLDataType{kDLFloat, 64, 1};
+      case mshadow::kFloat16: return DLDataType{kDLFloat, 16, 1};
+      case mshadow::kUint8: return DLDataType{kDLUInt, 8, 1};
+      case mshadow::kInt32: return DLDataType{kDLInt, 32, 1};
+      case mshadow::kInt8: return DLDataType{kDLInt, 8, 1};
+      case mshadow::kInt64: return DLDataType{kDLInt, 64, 1};
+      default: {
+        LOG(FATAL) << "Unknown type_flag=" << type_flag;
+        return DLDataType();
+      }
+    }
   }
 
   inline void SetDLTensor(int dev_mask, int dev_id) {
diff --git a/tests/python/gpu/test_tvm_bridge.py b/tests/python/gpu/test_tvm_bridge.py
index 292b9d9..69a713d 100644
--- a/tests/python/gpu/test_tvm_bridge.py
+++ b/tests/python/gpu/test_tvm_bridge.py
@@ -30,13 +30,13 @@ def test_tvm_bridge():
         logging.warn("TVM bridge test skipped because TVM is missing...")
         return
 
-    def check(target):
+    def check(target, dtype):
         shape = (20,)
         scale = tvm.var("scale", dtype="float32")
-        x = tvm.placeholder(shape)
-        y = tvm.placeholder(shape)
+        x = tvm.placeholder(shape, dtype=dtype)
+        y = tvm.placeholder(shape, dtype=dtype)
         z = tvm.compute(shape, lambda i: x[i] + y[i])
-        zz = tvm.compute(shape, lambda *i: z(*i) * scale)
+        zz = tvm.compute(shape, lambda *i: z(*i) * scale.astype(dtype))
         ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0)
         target = tvm.target.create(target)
 
@@ -47,17 +47,18 @@ def test_tvm_bridge():
 
         # get a mxnet version
         mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1])
-        xx = mx.nd.uniform(shape=shape, ctx=ctx)
-        yy = mx.nd.uniform(shape=shape, ctx=ctx)
-        zz = mx.nd.empty(shape=shape, ctx=ctx)
+        xx = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype)
+        yy = mx.nd.uniform(shape=shape, ctx=ctx).astype(dtype)
+        zz = mx.nd.empty(shape=shape, ctx=ctx).astype(dtype)
         # invoke myf: this runs in mxnet engine
         mxf(xx, yy, zz, 10.0)
         np.testing.assert_allclose(
             zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
 
-    check("llvm")
-    check("cuda")
-
+    for tgt in ["llvm", "cuda"]:
+        for dtype in ["int8", "uint8", "int64",
+                      "float32", "float64"]:
+            check(tgt, dtype)
 
 
 if __name__ == "__main__":

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.