You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/06/18 18:33:34 UTC
[incubator-tvm] branch master updated: `tvm` crate stage 3 of Rust
refactor (#5769)
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new d8c80c3 `tvm` crate stage 3 of Rust refactor (#5769)
d8c80c3 is described below
commit d8c80c382f02052b07da1235190a5b6c7acea994
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Thu Jun 18 11:33:25 2020 -0700
`tvm` crate stage 3 of Rust refactor (#5769)
* Adapt to new macro
* Add tvm crate
* Fix out of tree pass with new bindings
* Super slick API working
* Add examples
* Delay egg example and add ASF headers
* Move array.rs around
* Remove outdated tests will restore in CI PR
* Fix some memory issues
* Fix ref counting issue
* Formatting and cleanup
* Remove out-of-tree for now
* Remove out-of-tree
---
rust/Cargo.toml | 4 +-
rust/runtime/tests/test_wasm32/Cargo.toml | 4 +
rust/runtime/tests/test_wasm32/build.rs | 14 +-
rust/tvm-macros/src/external.rs | 6 +-
rust/tvm-macros/src/object.rs | 31 ++-
rust/tvm-rt/src/array.rs | 79 ++++++
rust/tvm-rt/src/errors.rs | 2 +
rust/tvm-rt/src/function.rs | 20 +-
rust/tvm-rt/src/lib.rs | 4 +-
rust/tvm-rt/src/ndarray.rs | 22 +-
rust/tvm-rt/src/object/mod.rs | 53 ++--
rust/tvm-rt/src/object/object_ptr.rs | 101 ++++++--
rust/tvm-rt/src/string.rs | 42 +--
rust/tvm-rt/src/to_function.rs | 56 ++--
rust/tvm-sys/src/lib.rs | 12 +
rust/tvm/.gitignore | 7 +
.../test_wasm32/Cargo.toml => tvm/.travis.yml} | 14 +-
rust/{runtime/tests/test_wasm32 => tvm}/Cargo.toml | 27 +-
rust/tvm/README.md | 235 +++++++++++++++++
rust/tvm/src/ir/mod.rs | 50 ++++
rust/tvm/src/ir/relay/mod.rs | 282 +++++++++++++++++++++
rust/tvm/src/lib.rs | 47 ++++
rust/tvm/src/runtime/mod.rs | 20 ++
rust/tvm/src/transform.rs | 93 +++++++
src/printer/relay_text_printer.cc | 2 -
25 files changed, 1075 insertions(+), 152 deletions(-)
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 6849c03..d9bb3ab 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -29,5 +29,7 @@ members = [
"frontend/tests/callback",
"frontend/examples/resnet",
"tvm-sys",
- "tvm-rt"
+ "tvm-macros",
+ "tvm-rt",
+ "tvm",
]
diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml
index 1d3373a..eeead45 100644
--- a/rust/runtime/tests/test_wasm32/Cargo.toml
+++ b/rust/runtime/tests/test_wasm32/Cargo.toml
@@ -20,7 +20,11 @@ name = "test-wasm32"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]
+edition = "2018"
[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }
+
+[build-dependencies]
+anyhow = "^1.0"
diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs
index 8b72be2..5c816c3 100644
--- a/rust/runtime/tests/test_wasm32/build.rs
+++ b/rust/runtime/tests/test_wasm32/build.rs
@@ -19,12 +19,14 @@
use std::{path::PathBuf, process::Command};
-fn main() {
+use anyhow::{Context, Result};
+
+fn main() -> Result<()> {
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");
if !out_dir.is_dir() {
- std::fs::create_dir(&out_dir).unwrap();
+ std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?;
}
let obj_file = out_dir.join("test.o");
@@ -36,7 +38,8 @@ fn main() {
))
.arg(&out_dir)
.output()
- .expect("Failed to execute command");
+ .context("failed to execute Python script for generating TVM library")?;
+
assert!(
obj_file.exists(),
"Could not build tvm lib: {}",
@@ -49,12 +52,14 @@ fn main() {
);
let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");
+
let output = Command::new(ar)
.arg("rcs")
.arg(&lib_file)
.arg(&obj_file)
.output()
- .expect("Failed to execute command");
+ .context("failed to run LLVM_AR command")?;
+
assert!(
lib_file.exists(),
"Could not create archive: {}",
@@ -68,4 +73,5 @@ fn main() {
println!("cargo:rustc-link-lib=static=test_wasm32");
println!("cargo:rustc-link-search=native={}", out_dir.display());
+ Ok(())
}
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 8833d60..2fcee49 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let tvm_rt_crate = crate::util::get_tvm_rt_crate();
- let err_type = quote! { #tvm_rt_crate::Error };
+ let result_type = quote! { #tvm_rt_crate::function::Result };
let mut items = Vec::new();
@@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);
let wrapper = quote! {
- pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> {
+ pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
- let func_ref: Box<dyn Fn(#(#tys),*) -> Result<#ret_type, #err_type>> = func_ref.to_boxed_fn();
+ let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.to_boxed_fn();
let res: #ret_type = func_ref(#(#args),*)?;
Ok(res)
}
diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs
index bee22c3..0170e1d 100644
--- a/rust/tvm-macros/src/object.rs
+++ b/rust/tvm-macros/src/object.rs
@@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate;
pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
let tvm_rt_crate = get_tvm_rt_crate();
+ let result = quote! { #tvm_rt_crate::function::Result };
+ let error = quote! { #tvm_rt_crate::errors::Error };
let derive_input = syn::parse_macro_input!(input as DeriveInput);
let payload_id = derive_input.ident;
@@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
#[derive(Clone)]
pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);
- impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
- fn to_object_ref(&self) -> ObjectRef {
- ObjectRef(self.0.as_ref().map(|o| o.upcast()))
+ impl #tvm_rt_crate::object::IsObjectRef for #ref_id {
+ type Object = #payload_id;
+
+ fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
+ self.0.as_ref()
+ }
+
+ fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
+ #ref_id(object_ptr)
}
}
@@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
- type Error = #tvm_rt_crate::Error;
+ type Error = #error;
- fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> {
+ fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> {
use std::convert::TryInto;
let oref: ObjectRef = ret_val.try_into()?;
let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
@@ -125,24 +133,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
}
impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
- type Error = #tvm_rt_crate::Error;
+ type Error = #error;
- fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
+ fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> {
use std::convert::TryInto;
let optr = arg_value.try_into()?;
Ok(#ref_id(Some(optr)))
}
}
- impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id {
- type Error = #tvm_rt_crate::Error;
-
- fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
- use std::convert::TryInto;
- let optr = arg_value.try_into()?;
- Ok(#ref_id(Some(optr)))
- }
- }
impl From<#ref_id> for #tvm_rt_crate::RetValue {
fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue {
diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs
new file mode 100644
index 0000000..128bb87
--- /dev/null
+++ b/rust/tvm-rt/src/array.rs
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::convert::{TryFrom, TryInto};
+use std::marker::PhantomData;
+
+use crate::errors::Error;
+use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
+use crate::{
+ external,
+ function::{Function, Result},
+ RetValue,
+};
+
+#[repr(C)]
+#[derive(Clone)]
+pub struct Array<T: IsObjectRef> {
+ object: ObjectRef,
+ _data: PhantomData<T>,
+}
+
+// TODO(@jroesch): convert to use generics instead of casting inside
+// the implementation.
+external! {
+ #[name("node.ArrayGetItem")]
+ fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef;
+}
+
+impl<T: IsObjectRef> Array<T> {
+ pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
+ let iter = data
+ .iter()
+ .map(|element| element.to_object_ref().into())
+ .collect();
+
+ let func = Function::get("node.Array").expect(
+ "node.Array function is not registered, this is most likely a build or linking error",
+ );
+
+ // let array_data = func.invoke(iter)?;
+ // let array_data: ObjectRef = func.invoke(iter)?.try_into()?;
+ let array_data: ObjectPtr<Object> = func.invoke(iter)?.try_into()?;
+
+ debug_assert!(
+ array_data.count() >= 1,
+ "array reference count is {}",
+ array_data.count()
+ );
+
+ Ok(Array {
+ object: ObjectRef(Some(array_data)),
+ _data: PhantomData,
+ })
+ }
+
+ pub fn get(&self, index: isize) -> Result<T>
+ where
+ T: TryFrom<RetValue, Error = Error>,
+ {
+ let oref: ObjectRef = array_get_item(self.object.clone(), index)?;
+ oref.downcast()
+ }
+}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 0b45ebf..779f04e 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -66,6 +66,8 @@ pub enum Error {
NDArray(#[from] NDArrayError),
#[error("{0}")]
CallFailed(String),
+ #[error("this case will never occur")]
+ Infallible(#[from] std::convert::Infallible),
}
impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index cb8777a..0772e96 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -32,12 +32,12 @@ use std::{
ptr, str,
};
-pub use tvm_sys::{ffi, ArgValue, RetValue};
-
use crate::errors::Error;
use super::to_boxed_fn::ToBoxedFn;
-use super::to_function::{ToFunction, Typed};
+
+pub use super::to_function::{ToFunction, Typed};
+pub use tvm_sys::{ffi, ArgValue, RetValue};
pub type Result<T> = std::result::Result<T, Error>;
@@ -65,6 +65,14 @@ impl Function {
}
}
+ pub unsafe fn null() -> Self {
+ Function {
+ handle: std::ptr::null_mut(),
+ is_global: false,
+ from_rust: false,
+ }
+ }
+
/// For a given function, it returns a function by name.
pub fn get<S: AsRef<str>>(name: S) -> Option<Function> {
let name = CString::new(name.as_ref()).unwrap();
@@ -171,7 +179,11 @@ impl TryFrom<RetValue> for Function {
impl<'a> From<Function> for ArgValue<'a> {
fn from(func: Function) -> ArgValue<'a> {
- ArgValue::FuncHandle(func.handle)
+ if func.handle.is_null() {
+ ArgValue::Null
+ } else {
+ ArgValue::FuncHandle(func.handle)
+ }
}
}
diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs
index 10f8317..a56a25b 100644
--- a/rust/tvm-rt/src/lib.rs
+++ b/rust/tvm-rt/src/lib.rs
@@ -91,10 +91,10 @@ pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
}
}
-#[macro_use]
-pub mod function;
+pub mod array;
pub mod context;
pub mod errors;
+pub mod function;
pub mod module;
pub mod ndarray;
pub mod to_boxed_fn;
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index b7ae462..24fa5e0 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -411,17 +411,17 @@ mod tests {
assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
}
- #[test]
- #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
- fn copy_wrong_dtype() {
- let shape = vec![4];
- let mut data = vec![1f32, 2., 3., 4.];
- let ctx = Context::cpu(0);
- let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
- nd_float.copy_from_buffer(&mut data);
- let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
- nd_float.copy_to_ndarray(empty_int).unwrap();
- }
+ // #[test]
+ // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
+ // fn copy_wrong_dtype() {
+ // let shape = vec![4];
+ // let mut data = vec![1f32, 2., 3., 4.];
+ // let ctx = Context::cpu(0);
+ // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
+ // nd_float.copy_from_buffer(&mut data);
+ // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
+ // nd_float.copy_to_ndarray(empty_int).unwrap();
+ // }
#[test]
fn rust_ndarray() {
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index c49f84e..e6375bf 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -39,13 +39,32 @@ impl ObjectRef {
}
}
-pub trait ToObjectRef {
- fn to_object_ref(&self) -> ObjectRef;
-}
+pub trait IsObjectRef: Sized {
+ type Object: IsObject;
+ fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>>;
+ fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self;
-impl ToObjectRef for ObjectRef {
fn to_object_ref(&self) -> ObjectRef {
- self.clone()
+ let object_ptr = self.as_object_ptr().cloned();
+ ObjectRef(object_ptr.map(|ptr| ptr.upcast()))
+ }
+
+ fn downcast<U: IsObjectRef>(&self) -> Result<U, Error> {
+ let ptr = self.as_object_ptr().map(|ptr| ptr.downcast::<U::Object>());
+ let ptr = ptr.transpose()?;
+ Ok(U::from_object_ptr(ptr))
+ }
+}
+
+impl IsObjectRef for ObjectRef {
+ type Object = Object;
+
+ fn as_object_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
+ self.0.as_ref()
+ }
+
+ fn from_object_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
+ ObjectRef(object_ptr)
}
}
@@ -73,39 +92,23 @@ impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
type Error = Error;
fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
- let optr = arg_value.try_into()?;
+ let optr: ObjectPtr<Object> = arg_value.try_into()?;
+ debug_assert!(optr.count() >= 1);
Ok(ObjectRef(Some(optr)))
}
}
-impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef {
- type Error = Error;
-
- fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
- // TODO(@jroesch): remove the clone
- let value: ArgValue<'a> = arg_value.clone();
- ObjectRef::try_from(value)
- }
-}
-
impl<'a> From<ObjectRef> for ArgValue<'a> {
fn from(object_ref: ObjectRef) -> ArgValue<'a> {
use std::ffi::c_void;
- let object_ptr = &object_ref.0;
+ let object_ptr = object_ref.0;
match object_ptr {
None => ArgValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void),
- Some(value) => value.clone().into(),
+ Some(value) => value.into(),
}
}
}
-impl<'a> From<&ObjectRef> for ArgValue<'a> {
- fn from(object_ref: &ObjectRef) -> ArgValue<'a> {
- let oref: ObjectRef = object_ref.clone();
- ArgValue::<'a>::from(oref)
- }
-}
-
external! {
#[name("ir.DebugPrint")]
fn debug_print(object: ObjectRef) -> CString;
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 40e2184..ddcbff9 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -29,16 +29,36 @@ use crate::errors::Error;
type Deleter = unsafe extern "C" fn(object: *mut Object) -> ();
+/// A TVM intrusive smart pointer header, in TVM all FFI compatible types
+/// start with an Object as their first field. The base object tracks
+/// a type_index which is an index into the runtime type information
+/// table, an atomic reference count, and a customized deleter which
+/// will be invoked when the reference count is zero.
+///
#[derive(Debug)]
#[repr(C)]
pub struct Object {
- pub type_index: u32,
+ /// The index into into TVM's runtime type information table.
+ pub(self) type_index: u32,
// TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure.
// NB: in general we should not touch this in Rust.
+ /// The reference count of the smart pointer.
pub(self) ref_count: AtomicI32,
- pub fdeleter: Deleter,
+ /// The deleter function which is used to deallocate the underlying data
+ /// when the reference count is zero. This field must always be set for
+ /// all objects.
+ ///
+ /// The common use case is ensuring that the allocator which allocated the
+ /// data is also the one that deletes it.
+ pub(self) fdeleter: Deleter,
}
+/// The default deleter for objects allocated in Rust, we use a bit of
+/// trait magic here to get a monomorphized deleter for each object
+/// "subtype".
+///
+/// This function just transmutes the pointer to the correct type
+/// and invokes the underlying typed delete function.
unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
let typed_object: *mut T = std::mem::transmute(object);
T::typed_delete(typed_object);
@@ -63,10 +83,12 @@ impl Object {
fn new(type_index: u32, deleter: Deleter) -> Object {
Object {
type_index,
- // Note: do not touch this field directly again, this is
- // a critical section, we write a 1 to the atomic which will now
- // be managed by the C++ atomics.
- // In the future we should probably use C-atomcis.
+ // NB(@jroesch): I believe it is sound to use Rust atomics
+ // in conjunction with C++ atomics given the memory model
+ // is nearly identical.
+ //
+ // Of course these are famous last words which I may later
+ // regret.
ref_count: AtomicI32::new(0),
fdeleter: deleter,
}
@@ -75,6 +97,7 @@ impl Object {
fn get_type_index<T: IsObject>() -> u32 {
let type_key = T::TYPE_KEY;
let cstring = CString::new(type_key).expect("type key must not contain null characters");
+
if type_key == "Object" {
return 0;
} else {
@@ -89,11 +112,22 @@ impl Object {
}
}
+ pub fn count(&self) -> i32 {
+ // need to do atomic read in C++
+ // ABI compatible atomics is funky/hard.
+ self.ref_count.load(std::sync::atomic::Ordering::SeqCst)
+ }
+
+ /// Allocates a base object value for an object subtype of type T.
+ /// By using associated constants and generics we can provide a
+ /// type indexed abstraction over allocating objects with the
+ /// correct index and deleter.
pub fn base_object<T: IsObject>() -> Object {
let index = Object::get_type_index::<T>();
Object::new(index, delete::<T>)
}
+ /// Increases the object's reference count by one.
pub(self) fn inc_ref(&self) {
unsafe {
let raw_ptr = std::mem::transmute(self);
@@ -101,6 +135,7 @@ impl Object {
}
}
+ /// Decreases the object's reference count by one.
pub(self) fn dec_ref(&self) {
unsafe {
let raw_ptr = std::mem::transmute(self);
@@ -109,6 +144,13 @@ impl Object {
}
}
+/// An unsafe trait which should be implemented for an object
+/// subtype.
+///
+/// The trait contains the type key needed to compute the type
+/// index, a method for accessing the base object given the
+/// subtype, and a typed delete method which is specialized
+/// to the subtype.
pub unsafe trait IsObject {
const TYPE_KEY: &'static str;
@@ -128,6 +170,10 @@ unsafe impl IsObject for Object {
}
}
+/// A smart pointer for types which implement IsObject.
+/// This type directly corresponds to TVM's C++ type ObjectPtr<T>.
+///
+/// See object.h for more details.
#[repr(C)]
pub struct ObjectPtr<T: IsObject> {
pub ptr: NonNull<T>,
@@ -144,7 +190,10 @@ fn dec_ref<T: IsObject>(ptr: NonNull<T>) {
impl ObjectPtr<Object> {
fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
let non_null = NonNull::new(object_ptr);
- non_null.map(|ptr| ObjectPtr { ptr })
+ non_null.map(|ptr| {
+ debug_assert!(unsafe { ptr.as_ref().count() } >= 0);
+ ObjectPtr { ptr }
+ })
}
}
@@ -207,9 +256,9 @@ impl<T: IsObject> ObjectPtr<T> {
};
if is_derived {
- Ok(ObjectPtr {
- ptr: self.ptr.cast(),
- })
+ let ptr = self.ptr.cast();
+ inc_ref(ptr);
+ Ok(ObjectPtr { ptr })
} else {
Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
}
@@ -240,6 +289,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
RetValue::ObjectHandle(handle) => {
let handle: *mut Object = unsafe { std::mem::transmute(handle) };
let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+ debug_assert!(optr.count() >= 1);
optr.downcast()
}
_ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")),
@@ -249,7 +299,9 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
+ debug_assert!(object_ptr.count() >= 1);
let raw_object_ptr = ObjectPtr::leak(object_ptr);
+
let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) };
ArgValue::ObjectHandle(void_ptr)
}
@@ -263,21 +315,7 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
ArgValue::ObjectHandle(handle) => {
let handle = unsafe { std::mem::transmute(handle) };
let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
- optr.downcast()
- }
- _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
- }
- }
-}
-
-impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T> {
- type Error = Error;
-
- fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
- match arg_value {
- ArgValue::ObjectHandle(handle) => {
- let handle = unsafe { std::mem::transmute(handle) };
- let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+ debug_assert!(optr.count() >= 1);
optr.downcast()
}
_ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
@@ -305,6 +343,8 @@ mod tests {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ret_value: RetValue = ptr.clone().into();
let ptr2: ObjectPtr<Object> = ret_value.try_into()?;
+ assert_eq!(ptr.count(), ptr2.count());
+ assert_eq!(ptr.count(), 2);
ensure!(
ptr.type_index == ptr2.type_index,
"type indices do not match"
@@ -321,6 +361,8 @@ mod tests {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let arg_value: ArgValue = ptr.clone().into();
let ptr2: ObjectPtr<Object> = arg_value.try_into()?;
+ assert_eq!(ptr.count(), ptr2.count());
+ assert_eq!(ptr.count(), 2);
ensure!(
ptr.type_index == ptr2.type_index,
"type indices do not match"
@@ -333,6 +375,7 @@ mod tests {
}
fn test_fn(o: ObjectPtr<Object>) -> ObjectPtr<Object> {
+ // The call machinery adds at least 1 extra count while inside the call.
assert_eq!(o.count(), 2);
return o;
}
@@ -341,13 +384,19 @@ mod tests {
fn test_ref_count_boundary() {
use super::*;
use crate::function::{register, Function, Result};
+ // 1
let ptr = ObjectPtr::new(Object::base_object::<Object>());
+ assert_eq!(ptr.count(), 1);
+ // 2
let stay = ptr.clone();
assert_eq!(ptr.count(), 2);
register(test_fn, "my_func").unwrap();
let func = Function::get("my_func").unwrap();
let func = func.to_boxed_fn::<dyn Fn(ObjectPtr<Object>) -> Result<ObjectPtr<Object>>>();
- func(ptr).unwrap();
+ let same = func(ptr).unwrap();
+ assert_eq!(stay.count(), 2);
+ assert_eq!(same.count(), 2);
+ drop(same);
assert_eq!(stay.count(), 1);
}
}
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
index 26758b1..7727e4b 100644
--- a/rust/tvm-rt/src/string.rs
+++ b/rust/tvm-rt/src/string.rs
@@ -36,7 +36,7 @@ pub struct StringObj {
}
impl String {
- pub fn new(string: std::string::String) -> Result<String, NulError> {
+ pub fn new(string: std::string::String) -> Result<String, Error> {
let cstring = CString::new(string)?;
// The string is being corrupted.
@@ -69,24 +69,24 @@ impl String {
}
}
-// #[cfg(test)]
-// mod tests {
-// use super::String;
-// use crate::object::debug_print;
-// use crate::ToObjectRef;
-// use anyhow::{ensure, Result};
+#[cfg(test)]
+mod tests {
+ use super::String;
+ use crate::object::debug_print;
+ use crate::IsObjectRef;
+ use anyhow::{ensure, Result};
-// #[test]
-// fn test_string_debug() -> Result<()> {
-// let s = String::new("foo".to_string()).unwrap();
-// let object_ref = s.to_object_ref();
-// println!("about to call");
-// let string = debug_print(object_ref)?;
-// println!("after call");
-// ensure!(
-// string.into_string().expect("is cstring").contains("foo"),
-// "string content is invalid"
-// );
-// Ok(())
-// }
-// }
+ #[test]
+ fn test_string_debug() -> Result<()> {
+ let s = String::new("foo".to_string()).unwrap();
+ let object_ref = s.to_object_ref();
+ println!("about to call");
+ let string = debug_print(object_ref)?;
+ println!("after call");
+ ensure!(
+ string.into_string().expect("is cstring").contains("foo"),
+ "string content is invalid"
+ );
+ Ok(())
+ }
+}
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index 4814d09..4fc021a 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -46,28 +46,32 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
/// And the implementation of it to `ToFunction`.
pub trait Typed<I, O> {
fn args(i: &[ArgValue<'static>]) -> Result<I>;
- fn ret(o: O) -> RetValue;
+ fn ret(o: O) -> Result<RetValue>;
}
-impl<F, O: Into<RetValue>> Typed<(), O> for F
+impl<F, O, E> Typed<(), O> for F
where
F: Fn() -> O,
+ Error: From<E>,
+ O: TryInto<RetValue, Error = E>,
{
fn args(_args: &[ArgValue<'static>]) -> Result<()> {
debug_assert!(_args.len() == 0);
Ok(())
}
- fn ret(o: O) -> RetValue {
- o.into()
+ fn ret(o: O) -> Result<RetValue> {
+ o.try_into().map_err(|e| e.into())
}
}
-impl<F, A, O: Into<RetValue>, E> Typed<(A,), O> for F
+impl<F, A, O, E1, E2> Typed<(A,), O> for F
where
F: Fn(A) -> O,
- Error: From<E>,
- A: TryFrom<ArgValue<'static>, Error = E>,
+ Error: From<E1>,
+ Error: From<E2>,
+ A: TryFrom<ArgValue<'static>, Error = E1>,
+ O: TryInto<RetValue, Error = E2>,
{
fn args(args: &[ArgValue<'static>]) -> Result<(A,)> {
debug_assert!(args.len() == 1);
@@ -75,17 +79,19 @@ where
Ok((a,))
}
- fn ret(o: O) -> RetValue {
- o.into()
+ fn ret(o: O) -> Result<RetValue> {
+ o.try_into().map_err(|e| e.into())
}
}
-impl<F, A, B, O: Into<RetValue>, E> Typed<(A, B), O> for F
+impl<F, A, B, O, E1, E2> Typed<(A, B), O> for F
where
F: Fn(A, B) -> O,
- Error: From<E>,
- A: TryFrom<ArgValue<'static>, Error = E>,
- B: TryFrom<ArgValue<'static>, Error = E>,
+ Error: From<E1>,
+ Error: From<E2>,
+ A: TryFrom<ArgValue<'static>, Error = E1>,
+ B: TryFrom<ArgValue<'static>, Error = E1>,
+ O: TryInto<RetValue, Error = E2>,
{
fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> {
debug_assert!(args.len() == 2);
@@ -94,18 +100,20 @@ where
Ok((a, b))
}
- fn ret(o: O) -> RetValue {
- o.into()
+ fn ret(o: O) -> Result<RetValue> {
+ o.try_into().map_err(|e| e.into())
}
}
-impl<F, A, B, C, O: Into<RetValue>, E> Typed<(A, B, C), O> for F
+impl<F, A, B, C, O, E1, E2> Typed<(A, B, C), O> for F
where
F: Fn(A, B, C) -> O,
- Error: From<E>,
- A: TryFrom<ArgValue<'static>, Error = E>,
- B: TryFrom<ArgValue<'static>, Error = E>,
- C: TryFrom<ArgValue<'static>, Error = E>,
+ Error: From<E1>,
+ Error: From<E2>,
+ A: TryFrom<ArgValue<'static>, Error = E1>,
+ B: TryFrom<ArgValue<'static>, Error = E1>,
+ C: TryFrom<ArgValue<'static>, Error = E1>,
+ O: TryInto<RetValue, Error = E2>,
{
fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> {
debug_assert!(args.len() == 3);
@@ -115,8 +123,8 @@ where
Ok((a, b, c))
}
- fn ret(o: O) -> RetValue {
- o.into()
+ fn ret(o: O) -> Result<RetValue> {
+ o.try_into().map_err(|e| e.into())
}
}
@@ -230,7 +238,7 @@ where
{
// Ideally we shouldn't need to clone, probably doesn't really matter.
let out = unsafe { (*handle)() };
- Ok(F::ret(out))
+ F::ret(out)
}
fn drop(_: *mut Self::Handle) {}
@@ -253,7 +261,7 @@ macro_rules! to_function_instance {
let out = unsafe {
(*handle)($(args.$index),+)
};
- Ok(F::ret(out))
+ F::ret(out)
}
fn drop(_: *mut Self::Handle) {}
diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs
index 0f455e7..231569b 100644
--- a/rust/tvm-sys/src/lib.rs
+++ b/rust/tvm-sys/src/lib.rs
@@ -57,3 +57,15 @@ pub use context::{Context, DeviceType};
pub use datatype::DataType;
pub use errors::*;
pub use packed_func::{ArgValue, RetValue};
+
+impl<T, E> std::convert::TryFrom<Result<T, E>> for RetValue
+where
+ RetValue: std::convert::TryFrom<T>,
+ E: From<<RetValue as std::convert::TryFrom<T>>::Error>,
+{
+ type Error = E;
+
+ fn try_from(val: Result<T, E>) -> Result<RetValue, Self::Error> {
+ val.and_then(|t| RetValue::try_from(t).map_err(|e| e.into()))
+ }
+}
diff --git a/rust/tvm/.gitignore b/rust/tvm/.gitignore
new file mode 100644
index 0000000..2430329
--- /dev/null
+++ b/rust/tvm/.gitignore
@@ -0,0 +1,7 @@
+target
+**/*.rs.bk
+Cargo.lock
+/tests/basics/add_*
+/examples/resnet/deploy_*
+/examples/resnet/*.png
+/examples/resnet/synset.*
diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/tvm/.travis.yml
similarity index 82%
copy from rust/runtime/tests/test_wasm32/Cargo.toml
copy to rust/tvm/.travis.yml
index 1d3373a..e963b7c 100644
--- a/rust/runtime/tests/test_wasm32/Cargo.toml
+++ b/rust/tvm/.travis.yml
@@ -15,12 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-[package]
-name = "test-wasm32"
-version = "0.0.0"
-license = "Apache-2.0"
-authors = ["TVM Contributors"]
-
-[dependencies]
-ndarray="0.12"
-tvm-runtime = { path = "../../" }
+language: rust
+rust:
+ - nightly
+matrix:
+ fast_finish: true
diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/tvm/Cargo.toml
similarity index 58%
copy from rust/runtime/tests/test_wasm32/Cargo.toml
copy to rust/tvm/Cargo.toml
index 1d3373a..ebfb5e6 100644
--- a/rust/runtime/tests/test_wasm32/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -16,11 +16,30 @@
# under the License.
[package]
-name = "test-wasm32"
-version = "0.0.0"
+name = "tvm"
+version = "0.1.0"
license = "Apache-2.0"
+description = "Rust frontend support for TVM"
+repository = "https://github.com/apache/incubator-tvm"
+homepage = "https://github.com/apache/incubator-tvm"
+readme = "README.md"
+keywords = ["rust", "tvm"]
+categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
+edition = "2018"
[dependencies]
-ndarray="0.12"
-tvm-runtime = { path = "../../" }
+thiserror = "^1.0"
+anyhow = "^1.0"
+lazy_static = "1.1"
+ndarray = "0.12"
+num-traits = "0.2"
+tvm-rt = { version = "0.1", path = "../tvm-rt/" }
+tvm-sys = { version = "0.1", path = "../tvm-sys/" }
+tvm-macros = { version = "*", path = "../tvm-macros/" }
+paste = "0.1"
+mashup = "0.1"
+once_cell = "^1.3.1"
+
+[features]
+blas = ["ndarray/blas"]
diff --git a/rust/tvm/README.md b/rust/tvm/README.md
new file mode 100644
index 0000000..01e088f
--- /dev/null
+++ b/rust/tvm/README.md
@@ -0,0 +1,235 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# TVM Runtime Frontend Support
+
+This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly`
+
+## What Does This Crate Offer?
+
+Here is a major workflow
+
+1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/)
+2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators.
+3. Deploy your models using **Rust** :heart:
+
+### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k
+
+Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example.
+
+Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM
+
+```python
+block = get_model('resnet18_v1', pretrained=True)
+
+sym, params = relay.frontend.from_mxnet(block, shape_dict)
+# compile the model
+with relay.build_config(opt_level=opt_level):
+ graph, lib, params = relay.build(
+ net, target, params=params)
+# same the model artifacts
+lib.save(os.path.join(target_dir, "deploy_lib.o"))
+cc.create_shared(os.path.join(target_dir, "deploy_lib.so"),
+ [os.path.join(target_dir, "deploy_lib.o")])
+
+with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo:
+ fo.write(graph.json())
+with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo:
+ fo.write(relay.save_param_dict(params))
+```
+
+Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image
+
+![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true)
+
+as demostrated in the following Rust snippet
+
+```rust
+ let graph = fs::read_to_string("deploy_graph.json")?;
+ // load the built module
+ let lib = Module::load(&Path::new("deploy_lib.so"))?;
+ // get the global TVM graph runtime function
+ let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap();
+ let runtime_create_fn_ret = call_packed!(
+ runtime_create_fn,
+ &graph,
+ &lib,
+ &ctx.device_type,
+ &ctx.device_id
+ )?;
+ // get graph runtime module
+ let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?;
+ // get the registered `load_params` from runtime module
+ let ref load_param_fn = graph_runtime_module
+ .get_function("load_params", false)
+ .unwrap();
+ // parse parameters and convert to TVMByteArray
+ let params: Vec<u8> = fs::read("deploy_param.params")?;
+ let barr = TVMByteArray::from(¶ms);
+ // load the parameters
+ call_packed!(load_param_fn, &barr)?;
+ // get the set_input function
+ let ref set_input_fn = graph_runtime_module
+ .get_function("set_input", false)
+ .unwrap();
+
+ call_packed!(set_input_fn, "data", &input)?;
+ // get `run` function from runtime module
+ let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
+ // execute the run function. Note that it has no argument
+ call_packed!(run_fn,)?;
+ // prepare to get the output
+ let output_shape = &mut [1, 1000];
+ let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32"));
+ // get the `get_output` function from runtime module
+ let ref get_output_fn = graph_runtime_module
+ .get_function("get_output", false)
+ .unwrap();
+ // execute the get output function
+ call_packed!(get_output_fn, &0, &output)?;
+ // flatten the output as Vec<f32>
+ let output = output.to_vec::<f32>()?;
+```
+
+and the model correctly predicts the input image as **tiger cat**.
+
+## Installations
+
+Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`.
+
+*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually.
+
+## Supported TVM Functionalities
+
+### Use TVM to Generate Shared Library
+
+One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU.
+
+```python
+import os
+import tvm
+from tvm import te
+from tvm.contrib import cc
+
+def test_add(target_dir):
+ if not tvm.runtime.enabled("cuda"):
+ print("skip {__file__} because cuda is not enabled...".format(__file__=__file__))
+ return
+ n = te.var("n")
+ A = te.placeholder((n,), name='A')
+ B = te.placeholder((n,), name='B')
+ C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+ s = te.create_schedule(C.op)
+ bx, tx = s[C].split(C.op.axis[0], factor=64)
+ s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+ s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+ fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd")
+
+ fadd_cuda.save(os.path.join(target_dir, "add_gpu.o"))
+ fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx"))
+ cc.create_shared(os.path.join(target_dir, "add_gpu.so"),
+ [os.path.join(target_dir, "add_gpu.o")])
+
+
+if __name__ == "__main__":
+ import sys
+ if len(sys.argv) != 2:
+ sys.exit(-1)
+ test_add(sys.argv[1])
+```
+
+### Run the Generated Shared Library
+
+The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust.
+
+```rust
+extern crate tvm_frontend as tvm;
+
+use tvm::*;
+
+fn main() {
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+ let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+ arr.copy_from_buffer(data.as_mut_slice());
+ let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32"));
+ let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap();
+ let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap();
+ assert!(fadd.enabled("gpu"));
+ fadd.import_module(fadd_dep);
+ fadd.entry();
+ function::Builder::from(&mut fadd)
+ .arg(&arr)
+ .arg(&arr)
+ .set_output(&mut ret)?
+ .invoke()
+ .unwrap();
+
+ assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]);
+}
+```
+
+**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by
+`cargo:rustc-link-search=native=add_gpu`.
+
+See the tests and examples custom `build.rs` for more details.
+
+### Convert and Register a Rust Function as a TVM Packed Function
+
+One can use `register_global_func!` macro to convert and register a Rust
+function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows
+
+```rust
+#[macro_use]
+extern crate tvm_frontend as tvm;
+use std::convert::TryInto;
+use tvm::*;
+
+fn main() {
+ register_global_func! {
+ fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> {
+ let mut ret = 0f32;
+ let shape = &mut [2];
+ for arg in args.iter() {
+ let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ let arg: NDArray = arg.try_into()?;
+ let arr = arg.copy_to_ndarray(e).unwrap();
+ let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap();
+ ret += rnd.scalar_sum();
+ }
+ let ret_val = TVMRetValue::from(&ret);
+ Ok(ret_val)
+ }
+ }
+
+ let shape = &mut [2];
+ let mut data = vec![3f32, 4.0];
+ let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32"));
+ arr.copy_from_buffer(data.as_mut_slice());
+ let mut registered = function::Builder::default();
+ let ret: f64 = registered
+ .get_function("sum", true)
+ .arg(&arr)
+ .arg(&arr)
+ .invoke()
+ .unwrap()
+ .try_into()
+ .unwrap();
+
+ assert_eq!(ret, 14f64);
+}
+```
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
new file mode 100644
index 0000000..4fe13a3
--- /dev/null
+++ b/rust/tvm/src/ir/mod.rs
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::runtime::String as TString;
+use crate::runtime::{self, external, IsObjectRef, Object, ObjectRef};
+use crate::DataType;
+
+pub mod relay;
+
+// TODO: figure out how to type the last argument runtime::TypedPackedFunc<String(ObjectRef)> annotate)
+external! {
+ #[name("ir.AsText")]
+ fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString;
+}
+
+pub fn as_text<T: IsObjectRef>(object: T) -> String {
+ let no_func = unsafe { runtime::Function::null() };
+ _as_text(object.to_object_ref(), 0, no_func)
+ .unwrap()
+ .to_string()
+ .unwrap()
+}
+
+#[repr(C)]
+pub struct PrimExprNode {
+ pub base: Object,
+ pub dtype: DataType,
+}
+
+#[repr(C)]
+pub struct IntImmNode {
+ pub base: PrimExprNode,
+ pub value: i64,
+}
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
new file mode 100644
index 0000000..cad41ac
--- /dev/null
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::runtime::array::Array;
+use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString};
+use crate::DataType;
+use tvm_macros::Object;
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Id"]
+#[type_key = "relay.Id"]
+pub struct IdNode {
+ pub base: Object,
+ pub name_hint: TString,
+}
+
+impl Id {
+ fn new(name_hint: TString) -> Id {
+ let node = IdNode {
+ base: Object::base_object::<IdNode>(),
+ name_hint: name_hint,
+ };
+ Id(Some(ObjectPtr::new(node)))
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "BaseExpr"]
+#[type_key = "Expr"]
+pub struct BaseExprNode {
+ pub base: Object,
+}
+
+#[repr(C)]
+pub struct PrimExprNode {
+ pub base: BaseExprNode,
+ pub datatype: DataType,
+}
+
+impl BaseExprNode {
+ fn base<T: IsObject>() -> BaseExprNode {
+ BaseExprNode {
+ base: Object::base_object::<T>(),
+ }
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Expr"]
+#[type_key = "relay.Expr"]
+pub struct RelayExpr {
+ pub base: BaseExprNode,
+ pub span: ObjectRef,
+ pub checked_type: ObjectRef,
+}
+
+impl RelayExpr {
+ fn base<T: IsObject>() -> RelayExpr {
+ RelayExpr {
+ base: BaseExprNode::base::<T>(),
+ span: ObjectRef::null(),
+ checked_type: ObjectRef::null(),
+ }
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "GlobalVar"]
+#[type_key = "GlobalVar"]
+pub struct GlobalVarNode {
+ pub base: RelayExpr,
+ pub name_hint: TString,
+}
+
+impl GlobalVar {
+ pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar {
+ let node = GlobalVarNode {
+ base: RelayExpr::base::<GlobalVarNode>(),
+ name_hint: TString::new(name_hint).unwrap(),
+ };
+ GlobalVar(Some(ObjectPtr::new(node)))
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Constant"]
+#[type_key = "relay.Constant"]
+pub struct ConstantNode {
+ pub base: RelayExpr,
+ pub data: ObjectRef, // make this NDArray.
+}
+
+impl Constant {
+ pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant {
+ let node = ConstantNode {
+ base: RelayExpr::base::<ConstantNode>(),
+ data: data,
+ };
+ Constant(Some(ObjectPtr::new(node)))
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Var"]
+#[type_key = "relay.Var"]
+pub struct VarNode {
+ pub base: RelayExpr,
+ pub vid: Id,
+ pub type_annotation: ObjectRef,
+}
+
+impl Var {
+ pub fn new(name_hint: String, _span: ObjectRef) -> Var {
+ let node = VarNode {
+ base: RelayExpr::base::<VarNode>(),
+ vid: Id::new(TString::new(name_hint.to_string()).unwrap()),
+ type_annotation: ObjectRef::null(),
+ };
+ Var(Some(ObjectPtr::new(node)))
+ }
+
+ pub fn name_hint(&self) -> &TString {
+ &self.vid.0.as_ref().unwrap().name_hint
+ }
+
+ pub fn to_expr(self) -> Expr {
+ unsafe { Expr(std::mem::transmute(self.0)) }
+ }
+}
+
+pub type Type = ObjectRef;
+pub type Attrs = ObjectRef;
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Call"]
+#[type_key = "relay.Call"]
+pub struct CallNode {
+ pub base: RelayExpr,
+ pub op: Expr,
+ pub args: Array<Expr>,
+ pub attrs: ObjectRef,
+ pub type_args: Array<ObjectRef>,
+}
+
+impl Call {
+ pub fn new(
+ op: Expr,
+ args: Array<Expr>,
+ attrs: Attrs,
+ type_args: Array<ObjectRef>,
+ _span: ObjectRef,
+ ) -> Call {
+ let node = CallNode {
+ base: RelayExpr::base::<VarNode>(),
+ op: op,
+ args: args,
+ attrs: attrs,
+ type_args: type_args,
+ };
+ Call(Some(ObjectPtr::new(node)))
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "BaseFunc"]
+#[type_key = "BaseFunc"]
+pub struct BaseFuncNode {
+ pub base: RelayExpr,
+ pub attrs: ObjectRef,
+}
+
+impl BaseFuncNode {
+ fn base<T: IsObject>() -> BaseFuncNode {
+ BaseFuncNode {
+ base: RelayExpr::base::<T>(),
+ attrs: ObjectRef::null(),
+ }
+ }
+}
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Function"]
+#[type_key = "relay.Function"]
+pub struct FunctionNode {
+ pub base: BaseFuncNode,
+ pub params: Array<Var>,
+ pub body: Expr,
+ pub ret_type: Type,
+ pub type_params: Array<Type>,
+}
+
+impl Function {
+ pub fn new(
+ params: Array<Var>,
+ body: Expr,
+ ret_type: Type,
+ type_params: Array<Type>,
+ ) -> Function {
+ let node = FunctionNode {
+ base: BaseFuncNode::base::<FunctionNode>(),
+ params: params,
+ body: body,
+ ret_type: ret_type,
+ type_params: type_params,
+ };
+ Function(Some(ObjectPtr::new(node)))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::ir::as_text;
+ use crate::runtime::String as TString;
+ use anyhow::Result;
+
+ #[test]
+ fn test_id() -> Result<()> {
+ let string = TString::new("foo".to_string()).expect("bar");
+ let id = Id::new(string);
+ let text = as_text(id.clone());
+ assert!(text.contains("relay.Id"));
+ Ok(())
+ }
+
+ #[test]
+ fn test_global() -> Result<()> {
+ let gv = GlobalVar::new("main".to_string(), ObjectRef::null());
+ let text = as_text(gv.clone());
+ assert!(text.contains("@main"));
+ Ok(())
+ }
+
+ #[test]
+ fn test_var() -> Result<()> {
+ let var = Var::new("local".to_string(), ObjectRef::null());
+ let text = as_text(var.clone());
+ assert!(text.contains("%local"));
+ Ok(())
+ }
+
+ use super::Array;
+ use crate::ir::relay::Var;
+ use crate::runtime::object::ObjectRef;
+
+ #[test]
+ fn create_array_and_get() -> Result<()> {
+ let vec = vec![
+ Var::new("foo".into(), ObjectRef::null()),
+ Var::new("bar".into(), ObjectRef::null()),
+ ];
+ let array = Array::from_vec(vec)?;
+ assert_eq!(array.get(0)?.name_hint().to_string()?, "foo");
+ assert_eq!(array.get(1)?.name_hint().to_string()?, "bar");
+ Ok(())
+ }
+}
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
new file mode 100644
index 0000000..64252a4
--- /dev/null
+++ b/rust/tvm/src/lib.rs
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems.
+//!
+//! This crate provides an idiomatic Rust API for TVM runtime frontend.
+//!
+//! One particular use case is that given optimized deep learning model artifacts,
+//! (compiled with TVM) which include a shared library
+//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
+//! in Rust idomatically to create a TVM Graph Runtime and
+//! run the model for some inputs and get the
+//! desired predictions *all in Rust*.
+//!
+//! Checkout the `examples` repository for more details.
+
+pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray};
+
+pub use tvm_rt::{Context, DataType, DeviceType};
+
+pub use tvm_rt::context;
+pub use tvm_rt::errors;
+pub use tvm_rt::function;
+pub use tvm_rt::module;
+pub use tvm_rt::ndarray;
+pub use tvm_rt::value;
+pub mod ir;
+pub mod runtime;
+pub mod transform;
+
+pub use runtime::version;
diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs
new file mode 100644
index 0000000..69fbb37
--- /dev/null
+++ b/rust/tvm/src/runtime/mod.rs
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+pub use tvm_rt::*;
diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs
new file mode 100644
index 0000000..ab84202
--- /dev/null
+++ b/rust/tvm/src/transform.rs
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::ir::relay::Function;
+use crate::runtime::array::Array;
+use crate::runtime::{
+ external,
+ function::{self, Result, ToFunction},
+ String as TString,
+};
+use crate::runtime::{Object, ObjectPtr, ObjectRef};
+
+use tvm_macros::Object;
+
+pub type Pass = ObjectRef;
+pub type IRModule = ObjectRef;
+pub type PassContext = ObjectRef;
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "PassInfo"]
+#[type_key = "transform.PassInfo"]
+pub struct PassInfoNode {
+ pub base: Object,
+ pub opt_level: i32,
+ pub name: TString,
+ pub required: Array<TString>,
+}
+
+impl PassInfo {
+ pub fn new(opt_level: i32, name: String, required: Vec<String>) -> Result<PassInfo> {
+ let required: Result<_> = required
+ .into_iter()
+ .map(|name| TString::new(name))
+ .collect();
+
+ let required = Array::from_vec(required?)?;
+
+ let node = PassInfoNode {
+ base: Object::base_object::<PassInfoNode>(),
+ opt_level,
+ name: TString::new(name).unwrap(),
+ required,
+ };
+
+ Ok(PassInfo(Some(ObjectPtr::new(node))))
+ }
+}
+
+external! {
+ #[name("relay._transform.MakeFunctionPass")]
+ fn create_func_pass(func: function::Function, pass_info: PassInfo) -> Pass;
+}
+
+pub fn function_pass<F: Fn(Function, IRModule, PassContext) -> Function + 'static>(
+ pass_fn: F,
+ pass_info: PassInfo,
+) -> Result<Pass> {
+ let func = pass_fn.to_function();
+ create_func_pass(func, pass_info)
+}
+
+#[macro_export]
+macro_rules! export_pass {
+ ($name:literal,$func:expr) => {
+ #[no_mangle]
+ pub unsafe extern "C" fn initialize(
+ args: *mut tvm_sys::ffi::TVMValue,
+ type_codes: *mut c_int,
+ num_args: c_int,
+ ret: tvm_sys::ffi::TVMRetValueHandle,
+ ) -> c_int {
+ register($func, $name).unwrap();
+ return 0;
+ }
+ };
+}
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index bf40f4b..ee11548 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -831,9 +831,7 @@ std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
}
TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
- std::cout << "The program: " << node << std::endl;
auto text = AsText(node, false, nullptr);
- std::cout << "The text " << text;
return text;
});