You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/30 00:04:08 UTC

[GitHub] [incubator-tvm] imalsogreg commented on a change in pull request #6563: [Rust] Improve NDArray, GraphRt, and Relay bindings

imalsogreg commented on a change in pull request #6563:
URL: https://github.com/apache/incubator-tvm/pull/6563#discussion_r497162386



##########
File path: rust/tvm-rt/src/ndarray.rs
##########
@@ -47,73 +47,146 @@
 //! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
 //! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
 
-use std::convert::TryInto;
 use std::ffi::c_void;
+use std::{borrow::Cow, convert::TryInto};
 use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
 
-use crate::errors::NDArrayError;
-
+use mem::size_of;
+use tvm_macros::Object;
 use tvm_sys::ffi::DLTensor;
 use tvm_sys::{ffi, ByteArray, Context, DataType};
 
 use ndarray::{Array, ArrayD};
 use num_traits::Num;
 
+use crate::errors::NDArrayError;
+
+use crate::object::{Object, ObjectPtr};
+
 /// See the [`module-level documentation`](../ndarray/index.html) for more details.
-///
-/// Wrapper around TVM array handle.
-#[derive(Debug)]
-pub enum NDArray {
-    Borrowed { handle: ffi::TVMArrayHandle },
-    Owned { handle: *mut c_void },
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "NDArray"]
+#[type_key = "runtime.NDArray"]
+// `repr(C)` lays the fields out in declaration order with `base` at offset 0; the
+// pointer arithmetic in `from_raw`/`leak` is computed from `offset_of!(.., dl_tensor)`
+// against this layout.
+pub struct NDArrayContainer {
+    // Object header. `from_raw` hands a pointer to the start of the container to
+    // `ObjectPtr::from_raw`, so `base` presumably needs to remain the first field.
+    base: Object,
+    // Container Base
+    // A raw `TVMArrayHandle` is interpreted as a pointer to this field (see `from_raw`).
+    dl_tensor: DLTensor,
+    // NOTE(review): not referenced in this excerpt; presumably an opaque context owned by
+    // whatever manages the tensor's memory — confirm against the C++ runtime.
+    manager_ctx: *mut c_void,
+    // TODO: shape?
 }
 
-impl NDArray {
-    pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
-        NDArray::Borrowed { handle }
+impl NDArrayContainer {
+    /// Recovers the owning container from a raw `TVMArrayHandle`.
+    ///
+    /// The handle is treated as a pointer to the `dl_tensor` field of an
+    /// `NDArrayContainer`; stepping back by that field's offset yields the start of the
+    /// container, which is handed to `ObjectPtr::from_raw` and then downcast.
+    /// Returns `None` for a null handle.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the recovered object is not an `NDArrayContainer`.
+    pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option<ObjectPtr<Self>> {
+        // Offsetting a null pointer is undefined behaviour, and the result would be a
+        // non-null garbage pointer that `ObjectPtr::from_raw` could not reject.
+        if handle.is_null() {
+            return None;
+        }
+        let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
+        // SAFETY: a non-null handle is assumed to point at the `dl_tensor` field of a live
+        // `NDArrayContainer`, so stepping back by that field's offset stays in bounds.
+        let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) };
+        let object_ptr = ObjectPtr::from_raw(base_ptr.cast());
+        object_ptr.map(|ptr| {
+            ptr.downcast::<NDArrayContainer>()
+                .expect("we know this is an NDArray container")
+        })
+    }
+
+    /// Consumes `object_ptr` without running its destructor and returns a mutable
+    /// reference derived from its raw pointer.
+    ///
+    /// NOTE(review): the pointer is advanced by the offset of `dl_tensor` and then cast
+    /// back to `NDArrayContainer`. If `ObjectPtr::ptr` points at the start of the
+    /// container (as `from_raw` constructs it), the result addresses the `dl_tensor`
+    /// field rather than a container — confirm whether the intended result is the
+    /// `DLTensor` (the inverse of `from_raw`). The returned lifetime `'a` is
+    /// unconstrained, so this should probably be an `unsafe fn`.
+    pub fn leak<'a>(object_ptr: ObjectPtr<NDArrayContainer>) -> &'a mut NDArrayContainer
+    where
+        NDArrayContainer: 'a,
+    {
+        let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
+        // `ManuallyDrop` prevents the `ObjectPtr` destructor from running (presumably a
+        // reference-count release), which is what makes this a leak.
+        unsafe {
+            &mut *std::mem::ManuallyDrop::new(object_ptr)
+                .ptr
+                .as_ptr()
+                .cast::<u8>()
+                .offset(base_offset)
+                .cast::<NDArrayContainer>()
+        }
     }
+}
 
-    pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
-        NDArray::Owned { handle }
+/// Views a slice of `i64` dimensions as `usize`, borrowing when possible.
+///
+/// When `usize` has the same size as `i64` (and no stricter alignment) and every value
+/// is non-negative, the input is reinterpreted in place and borrowed for the input's
+/// lifetime `'a`; otherwise each value is converted into a freshly allocated vector.
+///
+/// # Panics
+///
+/// Panics if any value does not fit into `usize` (for example, a negative dimension).
+fn cow_usize<'a>(slice: &'a [i64]) -> Cow<'a, [usize]> {
+    // `size_of` counts bytes, so compare against `i64` itself rather than a bit width.
+    let same_layout = std::mem::size_of::<usize>() == std::mem::size_of::<i64>()
+        && std::mem::align_of::<usize>() <= std::mem::align_of::<i64>();
+    if same_layout && slice.iter().all(|&x| x >= 0) {
+        // SAFETY: `usize` and `i64` have the same size and `usize` needs no stricter
+        // alignment, every value is non-negative so it reads back unchanged, and the
+        // returned borrow is tied to the input lifetime `'a`.
+        let shape =
+            unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<usize>(), slice.len()) };
+        Cow::Borrowed(shape)
+    } else {
+        let shape: Vec<usize> = slice
+            .iter()
+            .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot fit into usize: {}", x)))
+            .collect();
+        Cow::Owned(shape)
     }
+}
 
-    pub fn as_dltensor(&self) -> &DLTensor {
-        let ptr: *mut DLTensor = match self {
-            NDArray::Borrowed { ref handle } => *handle,
-            NDArray::Owned { ref handle } => *handle as *mut DLTensor,
-        };
+impl NDArray {
+    pub(crate) fn _from_raw(handle: ffi::TVMArrayHandle) -> Self {
+        let ptr = NDArrayContainer::from_raw(handle);
+        NDArray(ptr)
+    }
 
-        unsafe { std::mem::transmute(ptr) }
+    // I think these should be marked as unsafe functions? projecting a reference is bad news.

Review comment:
       Just calling this out as a TODO




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org