You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/10/07 22:26:59 UTC

[GitHub] [incubator-mxnet] junrushao1994 commented on a change in pull request #15921: dynamic custom operator support

junrushao1994 commented on a change in pull request #15921: dynamic custom operator support
URL: https://github.com/apache/incubator-mxnet/pull/15921#discussion_r332258311
 
 

 ##########
 File path: include/mxnet/lib_api.h
 ##########
 @@ -18,33 +18,754 @@
  */
 
 /*!
- * Copyright (c) 2015 by Contributors
+ * Copyright (c) 2019 by Contributors
  * \file lib_api.h
  * \brief APIs to interact with libraries
+ * This API specifies function prototypes to
+ * register custom ops for library authors
  */
+
 #ifndef MXNET_LIB_API_H_
 #define MXNET_LIB_API_H_
 
+#include <stdint.h>
+#include <stdlib.h>
+#include <vector>
+#include <map>
+#include <string>
+#include <iostream>
+#include <utility>
+
+#define MX_LIBRARY_VERSION 1
+
+/*!
+ * \brief Tensor data type, consistent with mshadow data type
+ */
+enum MXDType {
+  kFloat32 = 0,
+  kFloat64 = 1,
+  kFloat16 = 2,
+  kUint8 = 3,
+  kInt32 = 4,
+  kInt8  = 5,
+  kInt64 = 6,
+};
+
+enum MXReturnValue {
+  MX_FAIL = 0,
+  MX_SUCCESS = 1,
+};
+
+/*!
+ * \brief Tensor data structure used by custom operator
+ */
+struct MXTensor {
+  MXTensor() : data_ptr(NULL) {}
+
+  MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype)
+  : data_ptr(data_ptr), shape(shape), dtype(dtype) {}
+
+  /*! \brief helper function to cast data pointer */
+  template<typename data_type>
+  inline data_type* data() {
+    return reinterpret_cast<data_type*>(data_ptr);
+  }
+
+  /*! \brief helper function to get data size */
+  inline int64_t size() {
+    int64_t size = 1;
+    for (unsigned int i = 0; i < shape.size(); i++) {
+      size *= shape[i];
+    }
+    return size;
+  }
+
+  // data is flatten 1D repr of tensor, elements are in continuous memory
+  // user can access each element using the shape of tensor
+  // it may also point to data allocated on gpu
+  void *data_ptr;
+
+  // shape is in [2,3,4] format to represent high-dim tensor
+  std::vector<int64_t> shape;
 
 Review comment:
   Do we assume that the custom-op library and MXNet share the same C++ runtime, so that the allocators behind std::vector are consistent?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services