You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/22 03:33:35 UTC

[GitHub] tqchen closed pull request #12047: [MXNET-779]Add DLPack Transformation API

tqchen closed pull request #12047: [MXNET-779]Add DLPack Transformation API
URL: https://github.com/apache/incubator-mxnet/pull/12047
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 00439962a94..a01cc6a7794 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -93,6 +93,8 @@ typedef void *CudaModuleHandle;
 typedef void *CudaKernelHandle;
 /*! \brief handle to a Profile object (domain, duration, counter, etc.) */
 typedef void *ProfileHandle;
+/*! \brief handle to DLManagedTensor*/
+typedef void *DLManagedTensorHandle;
 
 typedef void (*ExecutorMonitorCallback)(const char*,
                                         NDArrayHandle,
@@ -746,6 +748,40 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
  */
 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
                                void **out_pdata);
+/*!
+* \brief Create a reference view of an NDArray,
+*  represented as a DLManagedTensor.
+*  Notice: MXNet uses asynchronous execution. Please call MXNDArrayWaitToRead or
+*          MXNDArrayWaitToWrite before calling MXNDArrayToDLPack.
+* \param handle the handle to the ndarray
+* \param out_dlpack pointer holder to get pointer of DLManagedTensor
+* \return 0 when success, -1 when failure happens
+*/
+MXNET_DLL int MXNDArrayToDLPack(NDArrayHandle handle,
+                                       DLManagedTensorHandle *out_dlpack);
+
+/*!
+* \brief Create a NDArray backed by a dlpack tensor.
+*
+* This allows us to create a NDArray using the memory
+* allocated by an external deep learning framework
+* that is DLPack compatible.
+*
+* The memory is retained until the NDArray goes out of scope.
+*
+* \param dlpack the pointer of the input DLManagedTensor
+* \param out_handle pointer holder to get pointer of NDArray
+* \return 0 when success, -1 when failure happens
+*/
+MXNET_DLL int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                                  NDArrayHandle *out_handle);
+/*!
+ * \brief Delete a dlpack tensor
+ * \param dlpack the pointer of the input DLManagedTensor
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack);
+
 /*!
  * \brief get the type of the data in NDArray
  * \param handle the handle to the narray
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 6141a4da78e..afae5dcfcff 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -116,6 +116,26 @@ class NDArray {
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
   }
+
+  /*!
+   * \brief construct a static NDArray that shares data with a TBlob and carries a custom deleter
+   *  Use with caution: allocate ONLY ONE NDArray for each TBlob, and
+   *  make sure the memory region stays available throughout the life of the NDArray
+   * \param data the memory content of static data
+   * \param dev_id the device id this tensor sits at
+   * \param deleter the function pointer of custom deleter
+   */
+  NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
+      : ptr_(new Chunk(data, dev_id),
+        [deleter](Chunk *p) {
+          deleter();    // call custom deleter
+          delete p;     // delete Chunk object
+        }),
+        shape_(data.shape_),
+        dtype_(data.type_flag_), storage_type_(kDefaultStorage),
+        entry_({nullptr, 0, 0}) {
+  }
+
   /*! \brief create ndarray from shared memory */
   NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
       : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
@@ -523,6 +543,26 @@ class NDArray {
     return ret;
   }
 
+  /*!
+   * \brief Create a reference view of an NDArray,
+   *  represented as a DLManagedTensor.
+   * \return A DLManagedTensor
+   */
+  DLManagedTensor* ToDLPack() const;
+
+  /*!
+   * \brief Create a NDArray backed by a dlpack tensor.
+   *
+   * This allows us to create a NDArray using the memory
+   * allocated by an external deep learning framework
+   * that is DLPack compatible.
+   *
+   * The memory is retained until the NDArray goes out of scope.
+   *
+   * \return The created NDArray view.
+   */
+  static NDArray FromDLPack(const DLManagedTensor* tensor);
+
   /*!
    * \brief Update ndarray chunk storage handles using existing ndarray storage handles
    * Also update the aux_handle, aux_shapes and aux_types.
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 6f604a5bb8d..496e8c7cfce 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -104,6 +104,39 @@ class TBlob {
       : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
     SetDLTensor(dev_mask, dev_id);
   }
+  /*!
+   * \brief constructor that constructs a TBlob from a DLTensor
+   * \param dltensor the source DLTensor (must be compact if strides are given)
+   */
+  explicit TBlob(const DLTensor &dltensor)
+      : dptr_(dltensor.data),
+        shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)),
+        type_flag_(DLDataTypeTransform(dltensor.dtype)),
+        dltensor_(dltensor) {
+    // compactness check for DLTensor
+    if (dltensor.strides != nullptr) {
+      // check strides
+      const int &ndim = dltensor.ndim;
+      const int64_t *shape = dltensor.shape;
+      const int64_t *strides = dltensor.strides;
+      if (ndim >= 1) {
+        bool err = false;
+        if (strides[ndim - 1] != 1) {
+          err = true;
+        } else {
+          for (int i = ndim - 2; i >= 0; --i) {
+            if (strides[i] != shape[i + 1] * strides[i + 1]) {
+              err = true;
+              break;
+            }
+          }
+        }
+        if (err) {
+          LOG(FATAL) << "Unsupported DLPack because MXNet only support compact tensor now";
+        }
+      }
+    }
+  }
   /*!
    * \brief constructor from tensor
    * \param src source tensor
@@ -336,6 +369,36 @@ class TBlob {
       }
     }
   }
+  static int DLDataTypeTransform(DLDataType dldata_type) {
+    if (dldata_type.lanes != 1) {
+      LOG(FATAL) << "Unsupported DLDataType whose lanes != 1";
+    }
+    switch (dldata_type.code) {
+      case kDLFloat:
+        switch (dldata_type.bits) {
+          case 16: return mshadow::kFloat16;
+          case 32: return mshadow::kFloat32;
+          case 64: return mshadow::kFloat64;
+        }
+        break;
+      case kDLUInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kUint8;
+        }
+        break;
+      case kDLInt:
+        switch (dldata_type.bits) {
+          case 8: return mshadow::kInt8;
+          case 32: return mshadow::kInt32;
+          case 64: return mshadow::kInt64;
+        }
+        break;
+    }
+    LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code
+               << ", " << dldata_type.bits
+               << ", " << dldata_type.lanes << "}";
+    return mshadow::kFloat32;
+  }
 
   inline void SetDLTensor(int dev_mask, int dev_id) {
     dltensor_.data = dptr_;
@@ -343,7 +406,7 @@ class TBlob {
     dltensor_.ndim = shape_.ndim();
     dltensor_.dtype = DTypeTransform(type_flag_);
     dltensor_.shape = shape_.data();
-    dltensor_.strides = NULL;
+    dltensor_.strides = nullptr;
     dltensor_.byte_offset = 0;
   }
 
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 89e1c9e087b..84b9e5831c6 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -232,6 +232,7 @@ def _load_lib():
 CudaModuleHandle = ctypes.c_void_p
 CudaKernelHandle = ctypes.c_void_p
 ProfileHandle = ctypes.c_void_p
+DLPackHandle = ctypes.c_void_p
 
 
 #----------------------------
@@ -726,3 +727,6 @@ def write_all_str(module_file, module_all_list):
     module_op_file.close()
     write_all_str(module_internal_file, module_internal_all)
     module_internal_file.close()
+
+ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
+ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index d6d619f30ca..fabf42e1dc6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -34,8 +34,8 @@
 from functools import reduce # pylint: disable=redefined-builtin
 import numpy as np
 from ..base import _LIB, numeric_types, integer_types
-from ..base import c_array, c_array_buf, c_handle_array, mx_real_t
-from ..base import mx_uint, NDArrayHandle, check_call
+from ..base import c_str, c_array, c_array_buf, c_handle_array, mx_real_t
+from ..base import mx_uint, NDArrayHandle, check_call, DLPackHandle
 from ..base import ctypes2buffer
 from ..context import Context, current_context
 from . import _internal
@@ -46,7 +46,8 @@
            "ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
            "imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
            "maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
-           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]
+           "power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
+           "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
 
 _STORAGE_TYPE_UNDEFINED = -1
 _STORAGE_TYPE_DEFAULT = 0
@@ -178,7 +179,6 @@ class NDArray(NDArrayBase):
     # See C++ side of definition(kTVMNDArrayTypeCode) at include/mxmet/tensor_blob.h
     _tvm_tcode = 19
     # pylint: disable= no-member, undefined-variable
-
     @property
     def _tvm_handle(self):
         return self.handle.value
@@ -2205,6 +2205,52 @@ def tostype(self, stype):
         """
         return op.cast_storage(self, stype=stype)
 
+    def to_dlpack_for_read(self):
+        """Returns a reference view of the NDArray, represented as a DLManagedTensor,
+        after all previous write operations on the current array have finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of NDArray that represents as DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> y = mx.nd.to_dlpack_for_read(x)
+        >>> type(y)
+        <class 'PyCapsule'>
+        >>> z = mx.nd.from_dlpack(y)
+        >>> z
+        [[1. 1. 1.]
+         [1. 1. 1.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_read(self)
+
+    def to_dlpack_for_write(self):
+        """Returns a reference view of the NDArray, represented as a DLManagedTensor,
+        after all previous read/write operations on the current array have finished.
+
+        Returns
+        -------
+        PyCapsule (the pointer of DLManagedTensor)
+            a reference view of NDArray that represents as DLManagedTensor.
+
+        Examples
+        --------
+        >>> x = mx.nd.ones((2,3))
+        >>> w = mx.nd.to_dlpack_for_write(x)
+        >>> type(w)
+        <class 'PyCapsule'>
+        >>> u = mx.nd.from_dlpack(w)
+        >>> u += 1
+        >>> x
+        [[2. 2. 2.]
+         [2. 2. 2.]]
+        <NDArray 2x3 @cpu(0)>
+        """
+        return to_dlpack_for_write(self)
 
 def _get_indexing_dispatch_code(key):
     """Returns a dispatch code for calling basic or advanced indexing functions."""
@@ -3851,3 +3897,128 @@ def histogram(a, bins=10, range=None):
         return _internal._histogram(data=a, bin_cnt=bins, range=range)
     raise ValueError("bins argument should be either an integer or an NDArray")
     # pylint: enable= no-member, protected-access, redefined-builtin
+
+PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
+_c_str_dltensor = c_str('dltensor')
+_c_str_used_dltensor = c_str('used_dltensor')
+
+def _dlpack_deleter(pycapsule):
+    pycapsule = ctypes.c_void_p(pycapsule)
+    if ctypes.pythonapi.PyCapsule_IsValid(pycapsule, _c_str_dltensor):
+        ptr = ctypes.c_void_p(
+            ctypes.pythonapi.PyCapsule_GetPointer(pycapsule, _c_str_dltensor))
+        check_call(_LIB.MXNDArrayCallDLPackDeleter(ptr))
+
+_c_dlpack_deleter = PyCapsuleDestructor(_dlpack_deleter)
+
+def to_dlpack_for_read(data):
+    """Returns a reference view of the NDArray, represented as a DLManagedTensor,
+       after all previous write operations on the current array have finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.nd.from_dlpack(y)
+    >>> z
+    [[1. 1. 1.]
+     [1. 1. 1.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    data.wait_to_read()
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
+
+def to_dlpack_for_write(data):
+    """Returns a reference view of the NDArray, represented as a DLManagedTensor,
+       after all previous read/write operations on the current array have finished.
+
+    Parameters
+    ----------
+    data: NDArray
+        input data.
+
+    Returns
+    -------
+    PyCapsule (the pointer of DLManagedTensor)
+        a reference view of NDArray that represents as DLManagedTensor.
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> w = mx.nd.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.nd.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    [[2. 2. 2.]
+     [2. 2. 2.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    check_call(_LIB.MXNDArrayWaitToWrite(data.handle))
+    dlpack = DLPackHandle()
+    check_call(_LIB.MXNDArrayToDLPack(data.handle, ctypes.byref(dlpack)))
+    return ctypes.pythonapi.PyCapsule_New(dlpack, _c_str_dltensor, _c_dlpack_deleter)
+
+def from_dlpack(dlpack):
+    """Returns a NDArray backed by a dlpack tensor.
+
+    Parameters
+    ----------
+    dlpack: PyCapsule (the pointer of DLManagedTensor)
+        input data
+
+    Returns
+    -------
+    NDArray
+        a NDArray backed by a dlpack tensor
+
+    Examples
+    --------
+    >>> x = mx.nd.ones((2,3))
+    >>> y = mx.nd.to_dlpack_for_read(x)
+    >>> type(y)
+    <class 'PyCapsule'>
+    >>> z = mx.nd.from_dlpack(y)
+    >>> type(z)
+    <class 'mxnet.ndarray.ndarray.NDArray'>
+    >>> z
+    [[ 1.  1.  1.]
+     [ 1.  1.  1.]]
+    <NDArray 2x3 @cpu(0)>
+
+    >>> w = mx.nd.to_dlpack_for_write(x)
+    >>> type(w)
+    <class 'PyCapsule'>
+    >>> u = mx.nd.from_dlpack(w)
+    >>> u += 1
+    >>> x
+    [[2. 2. 2.]
+     [2. 2. 2.]]
+    <NDArray 2x3 @cpu(0)>
+    """
+    handle = NDArrayHandle()
+    dlpack = ctypes.py_object(dlpack)
+    assert ctypes.pythonapi.PyCapsule_IsValid(dlpack, _c_str_dltensor), ValueError(
+        'Invalid DLPack Tensor. DLTensor capsules can be consumed only once.')
+    dlpack_handle = ctypes.c_void_p(ctypes.pythonapi.PyCapsule_GetPointer(dlpack, _c_str_dltensor))
+    check_call(_LIB.MXNDArrayFromDLPack(dlpack_handle, ctypes.byref(handle)))
+    # Rename the consumed PyCapsule so that it cannot be passed to from_dlpack again
+    ctypes.pythonapi.PyCapsule_SetName(dlpack, _c_str_used_dltensor)
+    # Clear the capsule destructor; the new NDArray now calls the DLPack deleter
+    ctypes.pythonapi.PyCapsule_SetDestructor(dlpack, None)
+    return NDArray(handle=handle)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1ef3f0fca9f..56e318097a3 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -500,6 +500,31 @@ int MXNDArrayGetData(NDArrayHandle handle,
   API_END();
 }
 
+int MXNDArrayToDLPack(NDArrayHandle handle,
+                      DLManagedTensorHandle *out_dlpack) {
+  API_BEGIN();
+  NDArray *arr = static_cast<NDArray*>(handle);
+  *out_dlpack = arr->ToDLPack();
+  API_END();
+}
+
+int MXNDArrayFromDLPack(DLManagedTensorHandle dlpack,
+                        NDArrayHandle *out_handle) {
+  API_BEGIN();
+  *out_handle = new NDArray(NDArray::FromDLPack(
+              static_cast<DLManagedTensor*>(dlpack)));
+  API_END();
+}
+
+int MXNDArrayCallDLPackDeleter(DLManagedTensorHandle dlpack) {
+  API_BEGIN();
+  if (dlpack != nullptr) {
+    DLManagedTensor *p_dlpack = static_cast<DLManagedTensor*>(dlpack);
+    p_dlpack->deleter(p_dlpack);
+  }
+  API_END();
+}
+
 int MXNDArrayGetDType(NDArrayHandle handle,
                      int *out_dtype) {
   API_BEGIN();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 853838a87f4..5bcb1c2bf48 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -312,6 +312,34 @@ NDArray NDArray::data_ndarray() const {
   return ret;
 }
 
+struct NDArrayDLManager {
+    NDArray handle;  // ref NDArray
+    DLManagedTensor tensor;
+};
+
+DLManagedTensor* NDArray::ToDLPack() const {
+  NDArrayDLManager* dlmanager(new NDArrayDLManager);
+  dlmanager->handle = *this;
+  if (!is_none()) {
+    dlmanager->tensor.dl_tensor = data().dltensor();
+  }
+  dlmanager->tensor.manager_ctx = dlmanager;
+  dlmanager->tensor.deleter = [](DLManagedTensor* dlmanager){
+    delete static_cast<NDArrayDLManager*>(dlmanager->manager_ctx);
+  };
+  return &(dlmanager->tensor);
+}
+
+NDArray NDArray::FromDLPack(const DLManagedTensor* tensor) {
+  const DLTensor &dl_tensor = tensor->dl_tensor;
+  auto deleter = [tensor](){
+    if (tensor->deleter != nullptr) {
+      tensor->deleter(const_cast<DLManagedTensor*>(tensor));
+    }
+  };
+  return NDArray(TBlob(dl_tensor), dl_tensor.ctx.device_id, deleter);
+}
+
 bool NDArray::fresh_out_grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return false;
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index c48801ec1ce..e5fc19b190a 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1439,6 +1439,37 @@ def test_ndarray_cpu_shared_ctx():
     res = mx.nd.zeros((1, 2, 3), ctx=ctx)
     assert(res.context == ctx)
 
+@with_seed()
+def test_dlpack():
+    for dtype in [np.float32, np.int32]:
+        for shape in [(3, 4, 5, 6), (2, 10), (15,)]:
+            a = mx.nd.random.uniform(shape = shape)
+            a_np = a.asnumpy()
+
+            pack = a.to_dlpack_for_read()
+            b = mx.nd.from_dlpack(pack)
+
+            a_copy = a.copy()
+            pack2 = a_copy.to_dlpack_for_write()
+            c = mx.nd.from_dlpack(pack2)
+
+            pack3 = mx.nd.to_dlpack_for_read(a)
+            d = mx.nd.from_dlpack(pack3)
+
+            a_copy = a.copy()
+            pack4 = mx.nd.to_dlpack_for_write(a_copy)
+            e = mx.nd.from_dlpack(pack4)
+
+            del a, pack, pack2, pack3, pack4
+
+            b_np = b.asnumpy()
+            c_np = c.asnumpy()
+            d_np = d.asnumpy()
+            e_np = e.asnumpy()
+            mx.test_utils.assert_almost_equal(a_np, b_np)
+            mx.test_utils.assert_almost_equal(a_np, c_np)
+            mx.test_utils.assert_almost_equal(a_np, d_np)
+            mx.test_utils.assert_almost_equal(a_np, e_np)
 
 if __name__ == '__main__':
     import nose


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services