You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/06/10 09:07:30 UTC
[incubator-tvm] branch master updated: [Rust] Second stage of Rust
Refactor (#5527)
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 60cfb79 [Rust] Second stage of Rust Refactor (#5527)
60cfb79 is described below
commit 60cfb79ce79eb47e4cebc04321fee7f67bfd404b
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Wed Jun 10 02:07:19 2020 -0700
[Rust] Second stage of Rust Refactor (#5527)
* Add tvm-rt crate
* Backport changes from frontend branch
* Format
* Add ASF headers
* Address self-code review
* Replace with helper
* Fix lint
* Fix
* Clean up repro debugging
* WIP
* Remove global resgistry to fix one memory issue
* Fix
* Format
* Format
* Update rust/tvm-rt/README.md
Co-authored-by: Jason Knight <bi...@gmail.com>
* Format
* Duplicate TVM macros
* Split macros
* Restore old macro for old crates
* Repair macros
* Fix format
* Format
Co-authored-by: Jason Knight <bi...@gmail.com>
---
include/tvm/ir/expr.h | 2 +
include/tvm/runtime/c_runtime_api.h | 19 +
include/tvm/runtime/container.h | 17 +-
python/tvm/runtime/object_generic.py | 2 +-
rust/Cargo.toml | 3 +-
rust/macros/Cargo.toml | 6 +-
rust/runtime/Cargo.toml | 2 +-
rust/runtime/src/lib.rs | 1 +
rust/{macros => tvm-macros}/Cargo.toml | 4 +-
rust/tvm-macros/src/external.rs | 160 ++++++++
rust/tvm-macros/src/import_module.rs | 133 +++++++
.../src/errors.rs => tvm-macros/src/lib.rs} | 37 +-
rust/tvm-macros/src/object.rs | 163 ++++++++
.../src/errors.rs => tvm-macros/src/util.rs} | 32 +-
rust/tvm-rt/.gitignore | 7 +
rust/{runtime => tvm-rt}/Cargo.toml | 34 +-
rust/tvm-rt/README.md | 60 +++
rust/tvm-rt/src/context.rs | 97 +++++
rust/tvm-rt/src/errors.rs | 78 ++++
rust/tvm-rt/src/function.rs | 303 ++++++++++++++
rust/tvm-rt/src/lib.rs | 130 ++++++
rust/tvm-rt/src/module.rs | 129 ++++++
rust/tvm-rt/src/ndarray.rs | 438 +++++++++++++++++++++
rust/tvm-rt/src/object/mod.rs | 117 ++++++
rust/tvm-rt/src/object/object_ptr.rs | 353 +++++++++++++++++
rust/tvm-rt/src/string.rs | 92 +++++
rust/tvm-rt/src/to_boxed_fn.rs | 227 +++++++++++
rust/tvm-rt/src/to_function.rs | 307 +++++++++++++++
rust/tvm-rt/src/value.rs | 161 ++++++++
rust/tvm-sys/src/byte_array.rs | 36 ++
rust/tvm-sys/src/datatype.rs | 10 +
rust/tvm-sys/src/errors.rs | 2 +-
rust/tvm-sys/src/lib.rs | 9 +-
src/ir/expr.cc | 7 +
src/printer/relay_text_printer.cc | 7 +
src/relay/transforms/to_cps.cc | 2 +-
src/runtime/object.cc | 13 +
src/runtime/object_internal.h | 9 +
38 files changed, 3131 insertions(+), 78 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 6797f16..b2ce50d 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -37,6 +37,8 @@
namespace tvm {
+// Bring the runtime String type into namespace tvm so it can be referred to
+// as tvm::String.
+using tvm::runtime::String;
+
/*!
* \brief Base type of all the expressions.
* \sa Expr
diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h
index bf24f99..213c705 100644
--- a/include/tvm/runtime/c_runtime_api.h
+++ b/include/tvm/runtime/c_runtime_api.h
@@ -515,6 +515,15 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
/*!
+ * \brief Increase the reference count of an object.
+ *
+ * \param obj The object handle.
+ * \note Internally we increase the reference counter of the object.
+ *       Presumably each successful call should eventually be balanced by a
+ *       call to TVMObjectFree -- TODO confirm against the implementation.
+ * \return 0 when success, -1 when failure happens
+ * \sa TVMObjectFree
+ */
+TVM_DLL int TVMObjectRetain(TVMObjectHandle obj);
+
+/*!
* \brief Free the object.
*
* \param obj The object handle.
@@ -564,6 +573,16 @@ TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void*
TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream);
+/*!
+ * \brief Check whether the type identified by child_type_index is derived
+ *        from the type identified by parent_type_index.
+ * \param child_type_index The type index of the derived type.
+ * \param parent_type_index The type index of the parent type.
+ * \param is_derived Output parameter: receives a boolean (stored as an int)
+ *        representing whether this predicate holds.
+ * \return 0 when success, -1 when failure happens.
+ */
+TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index,
+ int* is_derived);
+
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
index 6bc6fbf..2b3eb92 100644
--- a/include/tvm/runtime/container.h
+++ b/include/tvm/runtime/container.h
@@ -511,11 +511,20 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};
/*!
- * \brief Array container of ObjectRef in DSL graph.
- * Array implements copy-on-write semantics, which means array is mutable
- * but copy will happen when array is referenced in more than two places.
+ * \brief Array, a container representing a contiguous sequence of ObjectRefs.
*
- * operator[] only provide const access, use Set to mutate the content.
+ * Array implements in-place copy-on-write semantics.
+ *
+ * As in typical copy-on-write, a method which would typically mutate the array
+ * instead opaquely copies the underlying container, and then acts on its copy.
+ *
+ * If the array has reference count equal to one, we directly update the
+ * container in place without copying. This optimization is sound because
+ * when the reference count is equal to one this reference is guaranteed to be
+ * the sole pointer to the container.
+ *
+ *
+ * operator[] only provides const access, use Set to mutate the content.
* \tparam T The content ObjectRef type.
*/
template <typename T,
diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py
index cc21450..8f559ae 100644
--- a/python/tvm/runtime/object_generic.py
+++ b/python/tvm/runtime/object_generic.py
@@ -38,7 +38,7 @@ ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
def convert_to_object(value):
- """Convert a python value to corresponding object type.
+ """Convert a Python value to corresponding object type.
Parameters
----------
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index b4a159c..6849c03 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -28,5 +28,6 @@ members = [
"frontend/tests/basics",
"frontend/tests/callback",
"frontend/examples/resnet",
- "tvm-sys"
+ "tvm-sys",
+ "tvm-rt"
]
diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml
index 784b35e..97ebeca 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/macros/Cargo.toml
@@ -16,7 +16,7 @@
# under the License.
[package]
-name = "tvm-macros"
+name = "old-tvm-macros"
version = "0.1.1"
license = "Apache-2.0"
description = "Procedural macros of the TVM crate."
@@ -32,5 +32,5 @@ proc-macro = true
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+quote = "^1.0"
+syn = { version = "1.0.17", features = ["full", "extra-traits"] }
diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml
index eb531f9..cc149d4 100644
--- a/rust/runtime/Cargo.toml
+++ b/rust/runtime/Cargo.toml
@@ -39,7 +39,7 @@ serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
tvm-common = { version = "0.1", path = "../common" }
-tvm-macros = { version = "0.1", path = "../macros" }
+old-tvm-macros = { version = "0.1", path = "../macros" }
[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
libloading = "0.5"
diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs
index de1b79d..07aaaae 100644
--- a/rust/runtime/src/lib.rs
+++ b/rust/runtime/src/lib.rs
@@ -41,6 +41,7 @@ extern crate num_cpus;
extern crate serde;
#[macro_use]
extern crate serde_derive;
+extern crate old_tvm_macros as tvm_macros;
extern crate serde_json;
extern crate tvm_common;
diff --git a/rust/macros/Cargo.toml b/rust/tvm-macros/Cargo.toml
similarity index 93%
copy from rust/macros/Cargo.toml
copy to rust/tvm-macros/Cargo.toml
index 784b35e..7abc9ae 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/tvm-macros/Cargo.toml
@@ -32,5 +32,5 @@ proc-macro = true
[dependencies]
goblin = "0.0.24"
proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
+quote = "^1.0"
+syn = { version = "1.0.17", features = ["full", "extra-traits"] }
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
new file mode 100644
index 0000000..8833d60
--- /dev/null
+++ b/rust/tvm-macros/src/external.rs
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use proc_macro2::Span;
+use quote::quote;
+use syn::parse::{Parse, ParseStream, Result};
+
+use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+
+/// One `#[name("...")] fn ...;` item inside an `external!` invocation.
+struct External {
+    /// Name under which the function is looked up in the TVM registry.
+    tvm_name: String,
+    /// Name of the generated Rust wrapper function.
+    ident: Ident,
+    generics: Generics,
+    inputs: Vec<FnArg>,
+    ret_type: ReturnType,
+}
+
+impl Parse for External {
+    /// Parses a bodiless function signature carrying exactly one
+    /// `#[name("registry.name")]` attribute.
+    ///
+    /// Malformed input is reported as a `syn::Error` spanning the offending
+    /// tokens, so users get a compile error at the right location rather than
+    /// an opaque proc-macro panic.
+    fn parse(input: ParseStream) -> Result<Self> {
+        let method: TraitItemMethod = input.parse()?;
+
+        if method.attrs.len() != 1 {
+            return Err(syn::Error::new_spanned(
+                &method.sig,
+                "expected exactly one `#[name(\"...\")]` attribute on an external function",
+            ));
+        }
+
+        let tvm_name = match method.attrs[0].parse_meta()? {
+            Meta::List(meta_list) => {
+                if !meta_list.path.is_ident("name") {
+                    return Err(syn::Error::new_spanned(
+                        &meta_list.path,
+                        "expected `#[name(\"...\")]`",
+                    ));
+                }
+                match meta_list.nested.first() {
+                    Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(),
+                    _ => {
+                        return Err(syn::Error::new_spanned(
+                            &meta_list,
+                            "expected a string literal, e.g. `#[name(\"runtime.Foo\")]`",
+                        ))
+                    }
+                }
+            }
+            other => {
+                return Err(syn::Error::new_spanned(
+                    other,
+                    "expected an attribute of the form `#[name(\"...\")]`",
+                ))
+            }
+        };
+
+        // A `TraitItemMethod` is either `fn ...;` or `fn ... { body }`. Only the
+        // bodiless form makes sense for a function implemented inside TVM.
+        if let Some(body) = &method.default {
+            return Err(syn::Error::new_spanned(
+                body,
+                "external functions must not have a body",
+            ));
+        }
+
+        let sig = method.sig;
+        Ok(External {
+            tvm_name,
+            ident: sig.ident,
+            generics: sig.generics,
+            inputs: sig.inputs.into_iter().collect(),
+            ret_type: sig.output,
+        })
+    }
+}
+
+/// Entire input of `external!`: zero or more `External` items.
+struct ExternalInput {
+    externs: Vec<External>,
+}
+
+impl Parse for ExternalInput {
+    /// Parses `External` items until the token stream is exhausted.
+    fn parse(input: ParseStream) -> Result<Self> {
+        let mut parsed = Vec::new();
+        while !input.is_empty() {
+            let item: External = input.parse()?;
+            parsed.push(item);
+        }
+        Ok(ExternalInput { externs: parsed })
+    }
+}
+
+/// Implementation of the `external!` macro.
+///
+/// For each declared function this emits:
+///  * a `static global_<name>` that lazily looks the function up in the TVM
+///    registry under its `#[name("...")]`, panicking if it is not found; and
+///  * a `pub fn <name>(...) -> Result<Ret, Error>` wrapper that converts the
+///    looked-up `Function` into a typed boxed closure and calls it.
+pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let ext_input = syn::parse_macro_input!(input as ExternalInput);
+
+    let tvm_rt_crate = crate::util::get_tvm_rt_crate();
+
+    let err_type = quote! { #tvm_rt_crate::Error };
+
+    let mut items = Vec::new();
+
+    for external in &ext_input.externs {
+        let name = &external.ident;
+        let global_name = Ident::new(&format!("global_{}", external.ident), Span::call_site());
+        let ext_name = &external.tvm_name;
+
+        // Build the lookup-failure message here. Using
+        // `concat!(.., stringify!(#ext_name), ..)` in the expansion yields a
+        // message with no spaces around the (quoted) registry name.
+        let load_error = format!(
+            "unable to load external function `{}` from TVM registry.",
+            ext_name
+        );
+
+        let ty_params: Vec<syn::TypeParam> = external
+            .generics
+            .params
+            .iter()
+            .map(|ty_param| match ty_param {
+                syn::GenericParam::Type(param) => param.clone(),
+                _ => panic!(
+                    "external function `{}`: only type parameters are supported",
+                    name
+                ),
+            })
+            .collect();
+
+        let (args, tys): (Vec<Ident>, Vec<Type>) = external
+            .inputs
+            .iter()
+            .map(|arg| match arg {
+                FnArg::Typed(pat_type) => match &*pat_type.pat {
+                    Pat::Ident(pat_ident) => (pat_ident.ident.clone(), *pat_type.ty.clone()),
+                    _ => panic!(
+                        "external function `{}`: arguments must be plain identifiers",
+                        name
+                    ),
+                },
+                FnArg::Receiver(_) => panic!(
+                    "external function `{}`: `self` arguments are not supported",
+                    name
+                ),
+            })
+            .unzip();
+
+        let ret_type = match &external.ret_type {
+            ReturnType::Type(_, rtype) => *rtype.clone(),
+            ReturnType::Default => panic!(
+                "external function `{}`: an explicit return type is required",
+                name
+            ),
+        };
+
+        let global = quote! {
+            #[allow(non_upper_case_globals)]
+            static #global_name: ::once_cell::sync::Lazy<#tvm_rt_crate::Function> =
+                ::once_cell::sync::Lazy::new(|| {
+                    #tvm_rt_crate::Function::get(#ext_name)
+                        .expect(#load_error)
+                });
+        };
+
+        items.push(global);
+
+        let wrapper = quote! {
+            pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> {
+                let func_ref: #tvm_rt_crate::Function = #global_name.clone();
+                let func_ref: Box<dyn Fn(#(#tys),*) -> Result<#ret_type, #err_type>> = func_ref.to_boxed_fn();
+                let res: #ret_type = func_ref(#(#args),*)?;
+                Ok(res)
+            }
+        };
+
+        items.push(wrapper);
+    }
+
+    proc_macro::TokenStream::from(quote! {
+        #(#items
+        )*
+    })
+}
diff --git a/rust/tvm-macros/src/import_module.rs b/rust/tvm-macros/src/import_module.rs
new file mode 100644
index 0000000..6b059ae
--- /dev/null
+++ b/rust/tvm-macros/src/import_module.rs
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+use quote::quote;
+use std::{fs::File, io::Read};
+use syn::parse::{Parse, ParseStream, Result};
+use syn::LitStr;
+
+use std::path::PathBuf;
+
+/// Argument of `import_module!`: path of the compiled TVM object file,
+/// resolved relative to the crate's `CARGO_MANIFEST_DIR`.
+struct ImportModule {
+    importing_file: LitStr,
+}
+
+impl Parse for ImportModule {
+    fn parse(input: ParseStream) -> Result<Self> {
+        input
+            .parse::<LitStr>()
+            .map(|importing_file| ImportModule { importing_file })
+    }
+}
+
+/// Implementation of the `import_module!` macro.
+///
+/// This reads the object file named by the macro argument, resolved against
+/// `CARGO_MANIFEST_DIR`, and collects the symbols selected below. For each
+/// symbol it emits:
+///  * an `extern "C"` declaration in a private `ext` module; and
+///  * a `pub fn name(args: &[ArgValue]) -> Result<RetValue, FuncCallError>`
+///    wrapper that marshals the arguments and calls it.
+///
+/// Panics, failing the build, if the file cannot be opened, read or parsed,
+/// or if its object format is unsupported.
+pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let import_module_args = syn::parse_macro_input!(input as ImportModule);
+
+    let manifest =
+        std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo.");
+
+    let path = PathBuf::from(manifest).join(import_module_args.importing_file.value());
+
+    let mut fd = File::open(&path)
+        .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
+    let mut buffer = Vec::new();
+    fd.read_to_end(&mut buffer).unwrap_or_else(|err| {
+        panic!("Unable to read TVM object file at `{}`: {}", path.display(), err)
+    });
+
+    let object = goblin::Object::parse(&buffer).unwrap_or_else(|err| {
+        panic!("Unable to parse TVM object file at `{}`: {}", path.display(), err)
+    });
+
+    let fn_names = match object {
+        goblin::Object::Elf(elf) => elf
+            .syms
+            .iter()
+            .filter_map(|s| {
+                // Skip symbols of type 0 (STT_NOTYPE) and FILE symbols.
+                if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
+                    return None;
+                }
+                match elf.strtab.get(s.st_name) {
+                    // Exclude the `*tvm_module_ctx` symbol, as the Mach-O branch does.
+                    Some(Ok(name)) if !name.is_empty() && !name.ends_with("tvm_module_ctx") => {
+                        Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
+                    }
+                    _ => None,
+                }
+            })
+            .collect::<Vec<_>>(),
+        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => obj
+            .symbols()
+            .filter_map(|s| match s {
+                Ok((name, ref nlist))
+                    if nlist.is_global()
+                        && nlist.n_sect != 0
+                        && !name.ends_with("tvm_module_ctx") =>
+                {
+                    // Mach objects prepend a _ to globals.
+                    let name = if name.starts_with('_') { &name[1..] } else { name };
+                    Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
+                }
+                _ => None,
+            })
+            .collect::<Vec<_>>(),
+        _ => panic!("Unsupported object format."),
+    };
+
+    let extern_fns = quote! {
+        mod ext {
+            extern "C" {
+                #(
+                    pub(super) fn #fn_names(
+                        args: *const tvm_runtime::ffi::TVMValue,
+                        type_codes: *const std::os::raw::c_int,
+                        num_args: std::os::raw::c_int
+                    ) -> std::os::raw::c_int;
+                )*
+            }
+        }
+    };
+
+    let fns = quote! {
+        use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError};
+        #extern_fns
+
+        #(
+            pub fn #fn_names(args: &[ArgValue]) -> Result<RetValue, FuncCallError> {
+                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+                    .iter()
+                    .map(|arg| {
+                        let (val, code) = arg.to_tvm_value();
+                        (val, code as i32)
+                    })
+                    .unzip();
+                let exit_code = unsafe {
+                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
+                };
+                if exit_code == 0 {
+                    Ok(RetValue::default())
+                } else {
+                    Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
+                }
+            }
+        )*
+    };
+
+    proc_macro::TokenStream::from(fns)
+}
diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-macros/src/lib.rs
similarity index 54%
copy from rust/tvm-sys/src/errors.rs
copy to rust/tvm-macros/src/lib.rs
index 8479ec6..603e1ce 100644
--- a/rust/tvm-sys/src/errors.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -17,30 +17,25 @@
* under the License.
*/
-use thiserror::Error;
+use proc_macro::TokenStream;
-#[derive(Error, Debug)]
-#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")]
-pub struct ValueDowncastError {
- pub actual_type: String,
- pub expected_type: &'static str,
+mod external;
+mod import_module;
+mod object;
+mod util;
+
+/// Generates Rust wrappers for the function symbols found in a compiled TVM
+/// object file whose path is given relative to `CARGO_MANIFEST_DIR`; see
+/// `import_module::macro_impl`.
+#[proc_macro]
+pub fn import_module(input: TokenStream) -> TokenStream {
+ import_module::macro_impl(input)
}
-#[derive(Error, Debug)]
-#[error("Function call `{context:?}` returned error: {message:?}")]
-pub struct FuncCallError {
- context: String,
- message: String,
+/// `#[derive(Object)]` for TVM object payload structs; see `object::macro_impl`.
+///
+/// Requires `#[type_key = "..."]` and `#[ref_name = "..."]` on the struct.
+/// The `base` helper attribute is registered so it may be written without
+/// error, but the parent object is always read from a field named `base`.
+#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
+pub fn macro_impl(input: TokenStream) -> TokenStream {
+    object::macro_impl(input)
}
-impl FuncCallError {
- pub fn get_with_context(context: String) -> Self {
- Self {
- context,
- message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
- .to_str()
- .expect("double fault")
- .to_owned(),
- }
- }
+/// Declares typed Rust wrappers around functions registered in TVM's global
+/// function registry; see `external::macro_impl`.
+#[proc_macro]
+pub fn external(input: TokenStream) -> TokenStream {
+ external::macro_impl(input)
}
diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs
new file mode 100644
index 0000000..bee22c3
--- /dev/null
+++ b/rust/tvm-macros/src/object.rs
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use proc_macro::TokenStream;
+use proc_macro2::Span;
+use quote::quote;
+use syn::DeriveInput;
+use syn::Ident;
+
+use crate::util::get_tvm_rt_crate;
+
+/// Implementation of `#[derive(Object)]`.
+///
+/// The payload struct must carry `#[type_key = "..."]` and
+/// `#[ref_name = "..."]` attributes and a field named `base` holding its
+/// parent object. The expansion contains:
+///  * an `IsObject` impl whose `TYPE_KEY` is the type key and whose
+///    `as_object` delegates to the `base` field; and
+///  * a reference type named after `ref_name`, wrapping
+///    `Option<ObjectPtr<Payload>>`, with `Clone`, `Deref`, `ToObjectRef`,
+///    and conversions to and from `ArgValue` / `RetValue`.
+///
+/// NOTE(review): the expansion names `Object` and `ObjectRef` unqualified,
+/// so both must be in scope where the derive is used.
+pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
+    let tvm_rt_crate = get_tvm_rt_crate();
+    let derive_input = syn::parse_macro_input!(input as DeriveInput);
+    let payload_id = derive_input.ident;
+
+    let mut type_key = None;
+    let mut ref_name = None;
+    // The parent object is always expected in a field named `base`.
+    let base = Ident::new("base", Span::call_site());
+
+    for attr in derive_input.attrs {
+        if attr.path.is_ident("type_key") {
+            type_key = Some(
+                attr.parse_meta()
+                    .expect("#[derive(Object)]: malformed `type_key` attribute"),
+            );
+        }
+
+        if attr.path.is_ident("ref_name") {
+            ref_name = Some(
+                attr.parse_meta()
+                    .expect("#[derive(Object)]: malformed `ref_name` attribute"),
+            );
+        }
+    }
+
+    let type_key = expect_str_attr(type_key, "type_key");
+    let ref_name = expect_str_attr(ref_name, "ref_name");
+
+    let ref_id = Ident::new(&ref_name.value(), Span::call_site());
+
+    let expanded = quote! {
+        unsafe impl #tvm_rt_crate::object::IsObject for #payload_id {
+            const TYPE_KEY: &'static str = #type_key;
+
+            fn as_object<'s>(&'s self) -> &'s Object {
+                self.#base.as_object()
+            }
+        }
+
+        #[derive(Clone)]
+        pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>);
+
+        impl #tvm_rt_crate::object::ToObjectRef for #ref_id {
+            fn to_object_ref(&self) -> ObjectRef {
+                ObjectRef(self.0.as_ref().map(|o| o.upcast()))
+            }
+        }
+
+        impl std::ops::Deref for #ref_id {
+            type Target = #payload_id;
+
+            fn deref(&self) -> &Self::Target {
+                self.0.as_ref().unwrap()
+            }
+        }
+
+        impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id {
+            type Error = #tvm_rt_crate::Error;
+
+            fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let oref: ObjectRef = ret_val.try_into()?;
+                let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?;
+                let ptr = ptr.downcast::<#payload_id>()?;
+                Ok(#ref_id(Some(ptr)))
+            }
+        }
+
+        impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
+            fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
+                use std::ffi::c_void;
+                let object_ptr = &object_ref.0;
+                match object_ptr {
+                    None => {
+                        #tvm_rt_crate::ArgValue::
+                            ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
+                    }
+                    Some(value) => value.clone().into()
+                }
+            }
+        }
+
+        impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
+            fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
+                let oref: #ref_id = object_ref.clone();
+                #tvm_rt_crate::ArgValue::<'a>::from(oref)
+            }
+        }
+
+        impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
+            type Error = #tvm_rt_crate::Error;
+
+            fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let optr = arg_value.try_into()?;
+                Ok(#ref_id(Some(optr)))
+            }
+        }
+
+        impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id {
+            type Error = #tvm_rt_crate::Error;
+
+            fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> {
+                use std::convert::TryInto;
+                let optr = arg_value.try_into()?;
+                Ok(#ref_id(Some(optr)))
+            }
+        }
+
+        impl From<#ref_id> for #tvm_rt_crate::RetValue {
+            fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue {
+                use std::ffi::c_void;
+                let object_ptr = &object_ref.0;
+                match object_ptr {
+                    None => {
+                        #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
+                    }
+                    Some(value) => value.clone().into()
+                }
+            }
+        }
+
+    };
+
+    TokenStream::from(expanded)
+}
+
+/// Extracts the string literal from a parsed `#[attr = "..."]` attribute.
+///
+/// Panics with a message naming `attr_name` if the attribute was absent or
+/// is not a name-value pair with a string literal.
+fn expect_str_attr(meta: Option<syn::Meta>, attr_name: &str) -> syn::LitStr {
+    match meta {
+        Some(syn::Meta::NameValue(syn::MetaNameValue {
+            lit: syn::Lit::Str(lit),
+            ..
+        })) => lit,
+        Some(_) => panic!(
+            "#[derive(Object)]: expected `#[{} = \"...\"]` with a string literal",
+            attr_name
+        ),
+        None => panic!(
+            "#[derive(Object)]: missing required `#[{} = \"...\"]` attribute",
+            attr_name
+        ),
+    }
+}
diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-macros/src/util.rs
similarity index 54%
copy from rust/tvm-sys/src/errors.rs
copy to rust/tvm-macros/src/util.rs
index 8479ec6..1e720f0 100644
--- a/rust/tvm-sys/src/errors.rs
+++ b/rust/tvm-macros/src/util.rs
@@ -17,30 +17,14 @@
* under the License.
*/
-use thiserror::Error;
+use proc_macro2::TokenStream;
+use quote::quote;
+use std::env;
-#[derive(Error, Debug)]
-#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")]
-pub struct ValueDowncastError {
- pub actual_type: String,
- pub expected_type: &'static str,
-}
-
-#[derive(Error, Debug)]
-#[error("Function call `{context:?}` returned error: {message:?}")]
-pub struct FuncCallError {
- context: String,
- message: String,
-}
-
-impl FuncCallError {
- pub fn get_with_context(context: String) -> Self {
- Self {
- context,
- message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
- .to_str()
- .expect("double fault")
- .to_owned(),
- }
+/// Returns the path tokens generated code should use to name the tvm-rt crate.
+///
+/// This is `crate` when the macro is being expanded inside tvm-rt itself
+/// (detected via the `CARGO_PKG_NAME` environment variable), and `tvm_rt`
+/// otherwise. Panics if `CARGO_PKG_NAME` is not set.
+pub fn get_tvm_rt_crate() -> TokenStream {
+ if env::var("CARGO_PKG_NAME").unwrap() == "tvm-rt" {
+ quote!(crate)
+ } else {
+ quote!(tvm_rt)
}
}
diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore
new file mode 100644
index 0000000..2430329
--- /dev/null
+++ b/rust/tvm-rt/.gitignore
@@ -0,0 +1,7 @@
+target
+**/*.rs.bk
+Cargo.lock
+/tests/basics/add_*
+/examples/resnet/deploy_*
+/examples/resnet/*.png
+/examples/resnet/synset.*
diff --git a/rust/runtime/Cargo.toml b/rust/tvm-rt/Cargo.toml
similarity index 68%
copy from rust/runtime/Cargo.toml
copy to rust/tvm-rt/Cargo.toml
index eb531f9..465ae58 100644
--- a/rust/runtime/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -16,30 +16,30 @@
# under the License.
[package]
-name = "tvm-runtime"
+name = "tvm-rt"
version = "0.1.0"
license = "Apache-2.0"
-description = "A static TVM runtime"
+description = "Rust bindings for the TVM runtime API."
repository = "https://github.com/apache/incubator-tvm"
+homepage = "https://github.com/apache/incubator-tvm"
readme = "README.md"
-keywords = ["tvm"]
+keywords = ["rust", "tvm"]
categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
[dependencies]
-crossbeam = "0.7.3"
-failure = "0.1"
-itertools = "0.8"
-lazy_static = "1.4"
-ndarray="0.12"
-nom = "5.0"
-num_cpus = "1.10"
-serde = "1.0"
-serde_derive = "1.0"
-serde_json = "1.0"
-tvm-common = { version = "0.1", path = "../common" }
-tvm-macros = { version = "0.1", path = "../macros" }
+thiserror = "^1.0"
+ndarray = "0.12"
+num-traits = "0.2"
+tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
+tvm-macros = { version = "0.1", path = "../tvm-macros" }
+paste = "0.1"
+mashup = "0.1"
+once_cell = "^1.3.1"
-[target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies]
-libloading = "0.5"
+[dev-dependencies]
+anyhow = "^1.0"
+
+[features]
+blas = ["ndarray/blas"]
diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md
new file mode 100644
index 0000000..7c87939
--- /dev/null
+++ b/rust/tvm-rt/README.md
@@ -0,0 +1,60 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements. See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership. The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License. You may obtain a copy of the License at -->
+
+<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied. See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# TVM Runtime Support
+
+This crate provides an idiomatic Rust API for the [TVM](https://github.com/apache/incubator-tvm) runtime.
+Currently this is tested on Rust `1.42.0` and above.
+
+## What Does This Crate Offer?
+
+TVM is an end-to-end deep learning compiler which takes high-level machine learning
+models or tensor computations and lowers them into executable code for a variety
+of heterogeneous devices (e.g., CPU, GPU).
+
+This crate provides access to the APIs for manipulating runtime data structures,
+as well as TVM's cross-language Object system which functions similarly to systems
+such as COM, enabling cross-language interoperability.
+
+## Installation
+
+Please follow the TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions,
+then `export TVM_HOME=/path/to/tvm` and add the directory containing `libtvm_runtime` to your `LD_LIBRARY_PATH`.
+
+### Example of registering a cross-language closure.
+
+One can use the `register` function to expose a Rust function or closure whose arguments implement
+`TryFrom<ArgValue>` and whose return type implements `Into<RetValue>`. Once registered with TVM, these
+functions can be accessed via Python, C++, or any other language which implements the TVM packed function
+convention; see `docs.tvm.ai` for more information.
+
+```rust
+use tvm_rt::{ArgValue, RetValue};
+use tvm_rt::function::{Function, Result, register};
+
+// A plain Rust function: its arguments must implement `TryFrom<ArgValue>`
+// and its return type `Into<RetValue>` so it can be registered with TVM.
+fn sum(x: i64, y: i64, z: i64) -> i64 {
+ x + y + z
+}
+
+fn main() {
+ // Register `sum` with TVM under the name "mysum".
+ register(sum, "mysum".to_owned()).unwrap();
+ // Look the function back up by name, as any other TVM frontend could.
+ let func = Function::get("mysum").unwrap();
+ // Convert the untyped function handle into a typed Rust closure.
+ let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
+ let ret = boxed_fn(10, 20, 30).unwrap();
+ assert_eq!(ret, 60);
+}
+```
diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs
new file mode 100644
index 0000000..b0fea33
--- /dev/null
+++ b/rust/tvm-rt/src/context.rs
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::os::raw::c_void;
+use std::ptr;
+
+use crate::errors::Error;
+
+use tvm_sys::ffi;
+
+pub use tvm_sys::context::*;
+
+/// Extension methods on `Context` backed by calls into the TVM runtime.
+trait ContextExt {
+ /// Checks whether the context exists or not.
+ fn exist(&self) -> bool;
+ /// Synchronizes the context stream.
+ fn sync(&self) -> Result<(), Error>;
+ // The remaining methods query device attributes through
+ // `runtime.GetDeviceAttr`; the attribute kind for each is listed in the
+ // `impl_device_attrs!` invocation of the `Context` impl.
+ fn max_threads_per_block(&self) -> isize;
+ fn warp_size(&self) -> isize;
+ fn max_shared_memory_per_block(&self) -> isize;
+ fn compute_version(&self) -> isize;
+ fn device_name(&self) -> isize;
+ fn max_clock_rate(&self) -> isize;
+ fn multi_processor_count(&self) -> isize;
+ fn max_thread_dimensions(&self) -> isize;
+}
+
+/// Expands to one `ContextExt` method per `(name, kind)` pair.
+///
+/// Each generated method queries `runtime.GetDeviceAttr` with that attribute
+/// kind for this context and panics if the call fails.
+///
+/// NOTE(review): some attributes (e.g. the device name) may not be integers in
+/// the TVM runtime. In that case the i32 conversion fails and the method
+/// panics; confirm.
+macro_rules! impl_device_attrs {
+    ($(($attr_name:ident, $attr_kind:expr));+) => {
+        $(
+            fn $attr_name(&self) -> isize {
+                // Query this attribute's own kind; kind 0 is what `exist` queries.
+                get_device_attr(self.device_type as i32, self.device_id as i32, $attr_kind)
+                    .expect("should not fail") as isize
+            }
+        )+
+    };
+}
+
+// Generates `get_device_attr(device_type, device_id, device_kind)`, which
+// calls the registered `runtime.GetDeviceAttr` function and returns
+// `Result<i32, Error>`.
+crate::external! {
+ #[name("runtime.GetDeviceAttr")]
+ fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32;
+}
+
+impl ContextExt for Context {
+    fn exist(&self) -> bool {
+        // Attribute kind 0 asks whether the device exists; non-zero means it does.
+        0 != get_device_attr(self.device_type as i32, self.device_id as i32, 0)
+            .expect("should not fail")
+    }
+
+    /// Synchronize the context stream.
+    ///
+    /// A null stream handle is passed (presumably the device's default
+    /// stream -- confirm against `TVMSynchronize`).
+    fn sync(&self) -> Result<(), Error> {
+        check_call!(ffi::TVMSynchronize(
+            self.device_type as i32,
+            self.device_id as i32,
+            ptr::null_mut() as *mut c_void
+        ));
+        Ok(())
+    }
+
+    // (method name, `runtime.GetDeviceAttr` attribute kind)
+    impl_device_attrs!(
+        (max_threads_per_block, 1);
+        (warp_size, 2);
+        (max_shared_memory_per_block, 3);
+        (compute_version, 4);
+        (device_name, 5);
+        (max_clock_rate, 6);
+        (multi_processor_count, 7);
+        (max_thread_dimensions, 8)
+    );
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Synchronizing the first CPU context must succeed.
+    #[test]
+    fn sync() {
+        let outcome = Context::cpu(0).sync();
+        assert!(outcome.is_ok())
+    }
+}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
new file mode 100644
index 0000000..0b45ebf
--- /dev/null
+++ b/rust/tvm-rt/src/errors.rs
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::DataType;
+use thiserror::Error;
+
+/// Unit error whose message reports that no function was set on a `function::Builder`.
+/// NOTE(review): `function::Builder` does not appear in the visible code of this crate;
+/// the message may be stale — confirm before relying on it.
+#[derive(Debug, Error)]
+#[error("Function was not set in `function::Builder`")]
+pub struct FunctionNotFoundError;
+
+/// Error describing a type mismatch by the names of the expected and the actual type.
+#[derive(Debug, Error)]
+#[error("Expected type `{expected}` but found `{actual}`")]
+pub struct TypeMismatchError {
+ /// Name of the type that was expected.
+ pub expected: String,
+ /// Name of the type that was found instead.
+ pub actual: String,
+}
+
+/// Errors produced by `NDArray` operations and conversions.
+#[derive(Debug, Error)]
+pub enum NDArrayError {
+ /// The array's shape was not available.
+ #[error("Missing NDArray shape.")]
+ MissingShape,
+ /// The array was empty, so there was nothing to convert.
+ #[error("Cannot convert from an empty array.")]
+ EmptyArray,
+ /// A data-type string failed to parse (wraps `ParseDataTypeError`).
+ #[error("Invalid datatype when attempting to convert ndarray.")]
+ InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError),
+ /// A shape error raised by the Rust `ndarray` crate.
+ #[error("a shape error occurred in the Rust ndarray library")]
+ ShapeError(#[from] ndarray::ShapeError),
+ /// Two data types that were required to match did not.
+ #[error("Expected type `{expected}` but found `{actual}`")]
+ DataTypeMismatch {
+ expected: DataType,
+ actual: DataType,
+ },
+}
+
+/// Top-level error type of the TVM runtime bindings.
+#[derive(Debug, Error)]
+pub enum Error {
+ /// A value could not be downcast to the requested type.
+ #[error("{0}")]
+ Downcast(#[from] tvm_sys::errors::ValueDowncastError),
+ /// A raw pointer received across the FFI boundary was null.
+ #[error("raw pointer passed across boundary was null")]
+ Null,
+ /// A module path could not be turned into a string for loading.
+ #[error("failed to load module due to invalid path {0}")]
+ ModuleLoadPath(String),
+ /// A Rust string contained an interior NUL byte and could not become a C string.
+ #[error("failed to convert String into CString due to embedded nul character")]
+ ToCString(#[from] std::ffi::NulError),
+ /// A C string could not be converted back into a Rust `String`.
+ #[error("failed to convert CString into String")]
+ FromCString(#[from] std::ffi::IntoStringError),
+ /// A handle (named by the payload) was null.
+ #[error("Handle `{0}` is null.")]
+ NullHandle(String),
+ /// An error from an `NDArray` operation.
+ #[error("{0}")]
+ NDArray(#[from] NDArrayError),
+ /// A call failed with the given message.
+ #[error("{0}")]
+ CallFailed(String),
+}
+
+impl Error {
+    /// Builds an [`Error::Downcast`] for a value whose actual type (`actual_type`)
+    /// did not match `expected_type`.
+    pub fn downcast(actual_type: String, expected_type: &'static str) -> Error {
+        let cause = tvm_sys::errors::ValueDowncastError {
+            actual_type,
+            expected_type,
+        };
+        Error::Downcast(cause)
+    }
+}
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
new file mode 100644
index 0000000..cb8777a
--- /dev/null
+++ b/rust/tvm-rt/src/function.rs
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module provides an idiomatic Rust API for creating and working with TVM functions.
+//!
+//! To call an already registered TVM function, look it up with [`Function::get`]
+//! (or [`Function::get_boxed`] to obtain a typed Rust closure).
+//! To register a Rust function as a TVM packed function, use [`register`] or [`register_override`].
+//!
+//! See the tests and examples repository for more examples.
+
+use std::convert::TryFrom;
+use std::{
+ ffi::CString,
+ os::raw::{c_char, c_int},
+ ptr, str,
+};
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+use crate::errors::Error;
+
+use super::to_boxed_fn::ToBoxedFn;
+use super::to_function::{ToFunction, Typed};
+
+/// Result type used by this module, with the crate's [`Error`] as the error type.
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Wrapper around a TVM function handle.
+///
+/// `is_global` records whether the handle was obtained from the global registry
+/// (see [`Function::get`]); it is readable through [`Function::is_global`].
+#[derive(Debug, Hash)]
+pub struct Function {
+ pub(crate) handle: ffi::TVMFunctionHandle,
+ // whether the registered function is global or not.
+ is_global: bool,
+ // NOTE(review): set to `true` only by `Clone`; it is not read in the visible code
+ // (the `Drop` impl that would consult it is commented out). Older docs called this
+ // flag `is_cloned` — confirm the intended meaning.
+ from_rust: bool,
+}
+
+// SAFETY(review): `Function` holds only a raw TVM function handle and two flags. These
+// impls assume the TVM runtime permits moving and sharing that handle across threads —
+// confirm against the runtime's thread-safety guarantees.
+unsafe impl Send for Function {}
+unsafe impl Sync for Function {}
+
+impl Function {
+    /// Wraps a raw function handle as a non-global function.
+    pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self {
+        Function {
+            handle,
+            is_global: false,
+            from_rust: false,
+        }
+    }
+
+    /// Looks up a globally registered TVM function by name.
+    ///
+    /// Returns `None` when nothing is registered under `name`. This includes names that
+    /// contain an interior NUL byte, which previously caused a panic; such a name can never
+    /// be passed through the C API.
+    ///
+    /// Panics (via `check_call!`) if `TVMFuncGetGlobal` itself reports an error.
+    pub fn get<S: AsRef<str>>(name: S) -> Option<Function> {
+        let name = CString::new(name.as_ref()).ok()?;
+        let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle;
+
+        check_call!(ffi::TVMFuncGetGlobal(
+            name.as_ptr() as *const c_char,
+            &mut handle as *mut _
+        ));
+
+        if handle.is_null() {
+            None
+        } else {
+            Some(Function {
+                handle,
+                is_global: true,
+                from_rust: false,
+            })
+        }
+    }
+
+    /// Looks up a global function by name and converts it into a boxed Rust closure of type `F`.
+    ///
+    /// Returns `None` under the same conditions as [`Function::get`].
+    pub fn get_boxed<F: ?Sized, S: AsRef<str>>(name: S) -> Option<Box<F>>
+    where
+        F: ToBoxedFn,
+    {
+        Self::get(name).map(|f| f.to_boxed_fn::<F>())
+    }
+
+    /// Returns the underlying TVM function handle.
+    pub fn handle(&self) -> ffi::TVMFunctionHandle {
+        self.handle
+    }
+
+    /// Returns `true` if the underlying TVM function is global and `false` otherwise.
+    pub fn is_global(&self) -> bool {
+        self.is_global
+    }
+
+    /// Calls the function with the given arguments and returns its result.
+    ///
+    /// NOTE(review): a failing `TVMFuncCall` currently panics via `check_call!` instead of
+    /// returning `Err` — confirm whether callers expect that.
+    pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
+        let num_args = arg_buf.len();
+        // Split the arguments into the parallel value / type-code arrays the C API expects.
+        let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMArgTypeCode>) =
+            arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
+        let mut ret_val = ffi::TVMValue { v_int64: 0 };
+        let mut ret_type_code = 0i32;
+
+        check_call!(ffi::TVMFuncCall(
+            self.handle,
+            values.as_mut_ptr() as *mut ffi::TVMValue,
+            type_codes.as_mut_ptr() as *mut c_int,
+            num_args as c_int,
+            &mut ret_val as *mut _,
+            &mut ret_type_code as *mut _
+        ));
+
+        Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32))
+    }
+
+    /// Consumes this function and converts it into a boxed Rust closure of type `F`.
+    pub fn to_boxed_fn<F: ?Sized>(self) -> Box<F>
+    where
+        F: ToBoxedFn,
+    {
+        F::to_boxed_fn(self)
+    }
+}
+
+impl Clone for Function {
+    /// Copies the raw handle and the `is_global` flag; the copy is marked `from_rust = true`.
+    fn clone(&self) -> Function {
+        let (handle, is_global) = (self.handle, self.is_global);
+        Function {
+            handle,
+            is_global,
+            from_rust: true,
+        }
+    }
+}
+
+// impl Drop for Function {
+// fn drop(&mut self) {
+// if !self.is_global && !self.is_cloned {
+// check_call!(ffi::TVMFuncFree(self.handle));
+// }
+// }
+// }
+
+impl From<Function> for RetValue {
+    /// Wraps the function's raw handle as a `RetValue::FuncHandle`.
+    fn from(func: Function) -> RetValue {
+        let raw = func.handle();
+        RetValue::FuncHandle(raw)
+    }
+}
+
+impl TryFrom<RetValue> for Function {
+    type Error = Error;
+
+    /// Accepts only `RetValue::FuncHandle`; any other variant becomes a downcast error.
+    fn try_from(ret_value: RetValue) -> Result<Function> {
+        if let RetValue::FuncHandle(handle) = ret_value {
+            return Ok(Function::new(handle));
+        }
+        let actual = format!("{:?}", ret_value);
+        Err(Error::downcast(actual, "FunctionHandle"))
+    }
+}
+
+impl<'a> From<Function> for ArgValue<'a> {
+    /// Wraps the function's raw handle as an `ArgValue::FuncHandle`.
+    fn from(func: Function) -> ArgValue<'a> {
+        let raw = func.handle();
+        ArgValue::FuncHandle(raw)
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for Function {
+    type Error = Error;
+
+    /// Accepts only `ArgValue::FuncHandle`; any other variant becomes a downcast error.
+    fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
+        if let ArgValue::FuncHandle(handle) = arg_value {
+            return Ok(Function::new(handle));
+        }
+        let actual = format!("{:?}", arg_value);
+        Err(Error::downcast(actual, "FunctionHandle"))
+    }
+}
+
+impl<'a> TryFrom<&ArgValue<'a>> for Function {
+    type Error = Error;
+
+    /// Borrowing variant: copies the handle out of `ArgValue::FuncHandle`, otherwise
+    /// reports a downcast error.
+    fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
+        if let &ArgValue::FuncHandle(handle) = arg_value {
+            return Ok(Function::new(handle));
+        }
+        let actual = format!("{:?}", arg_value);
+        Err(Error::downcast(actual, "FunctionHandle"))
+    }
+}
+
+/// Registers a Rust function with an arbitrary type signature in the TVM global registry.
+///
+/// A function is convertible if and only if its argument and return types can be
+/// converted between Rust values and TVM values.
+///
+/// This never overrides an existing registration: it panics if a function is already
+/// registered under `name`. Use [`register_override`] when you need control over overriding.
+///
+/// ## Example
+///
+/// ```
+/// # use tvm_rt::{ArgValue, RetValue};
+/// # use tvm_rt::function::{Function, Result, register};
+///
+/// fn sum(x: i64, y: i64, z: i64) -> i64 {
+///     x + y + z
+/// }
+///
+/// register(sum, "mysum".to_owned()).unwrap();
+/// let func = Function::get("mysum").unwrap();
+/// let boxed_fn = func.to_boxed_fn::<dyn Fn(i64, i64, i64) -> Result<i64>>();
+/// let ret = boxed_fn(10, 20, 30).unwrap();
+/// assert_eq!(ret, 60);
+/// ```
+pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
+where
+    F: ToFunction<I, O>,
+    F: Typed<I, O>,
+{
+    let allow_override = false;
+    register_override(f, name, allow_override)
+}
+
+/// Register a function with explicit control over whether to override an existing registration or not.
+///
+/// When `override_` is `false` and `name` is already registered, the underlying
+/// `TVMFuncRegisterGlobal` call fails and this function panics (via `check_call!`).
+/// Returns `Err` if `name` contains an interior NUL byte.
+///
+/// See `register` for more details on how to use the registration API.
+pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
+where
+    F: ToFunction<I, O>,
+    F: Typed<I, O>,
+{
+    let func = f.to_function();
+    let name = CString::new(name.into())?;
+    // TVM copies the name while registering, so lend it a borrowed pointer for the
+    // duration of the call. `CString::into_raw` leaked the allocation on every call.
+    check_call!(ffi::TVMFuncRegisterGlobal(
+        name.as_ptr(),
+        func.handle(),
+        override_ as c_int
+    ));
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::function::Function;
+
+    static CANARY: &str = "runtime.ModuleLoadFromFile";
+
+    #[test]
+    fn get_fn() {
+        // A built-in global must be found, and an unknown name must not be.
+        let known = Function::get(CANARY);
+        assert!(known.is_some());
+        let unknown = Function::get("does not exists!");
+        assert!(unknown.is_none());
+    }
+
+    #[test]
+    fn register_and_call_closure0() {
+        use crate::function;
+        use function::Result;
+
+        fn constfn() -> i64 {
+            10
+        }
+
+        function::register_override(constfn, "constfn".to_owned(), true).unwrap();
+
+        let boxed = Function::get_boxed::<dyn Fn() -> Result<i32>, _>("constfn").unwrap();
+        assert_eq!(boxed().unwrap(), 10);
+    }
+
+    #[test]
+    fn register_and_call_closure1() {
+        use crate::function;
+
+        fn ident(x: i64) -> i64 {
+            x
+        }
+
+        function::register_override(ident, "ident".to_owned(), true).unwrap();
+        let boxed = Function::get_boxed::<dyn Fn(i32) -> Result<i32>, _>("ident").unwrap();
+        assert_eq!(boxed(60).unwrap(), 60);
+    }
+}
diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs
new file mode 100644
index 0000000..10f8317
--- /dev/null
+++ b/rust/tvm-rt/src/lib.rs
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems.
+//!
+//! This crate provides an idiomatic Rust API for TVM runtime.
+//!
+//! The TVM runtime API contains the data structures used by higher-level TVM executors.
+//! Specifically it exposes the basic types such as NDArray, as well as the more general object system.
+//! The TVM object system enables cross-language interoperability including that of closures for all
+//! supported languages including C++, and Python.
+
+pub mod object;
+pub mod string;
+
+pub use object::*;
+pub use string::*;
+
+use std::{
+ ffi::{CStr, CString},
+ str,
+};
+
+pub use crate::{
+ context::{Context, DeviceType},
+ errors::*,
+ function::Function,
+ module::Module,
+ ndarray::NDArray,
+};
+
+pub use function::{ArgValue, RetValue};
+pub use tvm_sys::byte_array::ByteArray;
+pub use tvm_sys::datatype::DataType;
+use tvm_sys::ffi;
+
+pub use tvm_macros::external;
+
+/// Evaluates an FFI call (inside `unsafe`) that returns a TVM status code and turns it into
+/// a `Result`: `Ok(())` when the code is `0`, otherwise `Err` built from
+/// [`get_last_error`] via `Into`.
+#[macro_export]
+macro_rules! tvm_call {
+    ($e:expr) => {{
+        let status = unsafe { $e };
+        if status == 0 {
+            Ok(())
+        } else {
+            Err($crate::get_last_error().into())
+        }
+    }};
+}
+
+/// Evaluates an FFI call (inside `unsafe`) that returns a TVM status code and panics with
+/// TVM's last error message if the code is nonzero.
+#[macro_export]
+macro_rules! check_call {
+    ($e:expr) => {{
+        let status = unsafe { $e };
+        if status != 0 {
+            panic!("{}", $crate::get_last_error());
+        }
+    }};
+}
+
+/// Gets the last error message.
+pub fn get_last_error() -> &'static str {
+ unsafe {
+ match CStr::from_ptr(ffi::TVMGetLastError()).to_str() {
+ Ok(s) => s,
+ Err(_) => "Invalid UTF-8 message",
+ }
+ }
+}
+
+pub(crate) fn set_last_error<E: std::error::Error>(err: &E) {
+ let c_string = CString::new(err.to_string()).unwrap();
+ unsafe {
+ ffi::TVMAPISetLastError(c_string.as_ptr());
+ }
+}
+
+#[macro_use]
+pub mod function;
+pub mod context;
+pub mod errors;
+pub mod module;
+pub mod ndarray;
+pub mod to_boxed_fn;
+mod to_function;
+pub mod value;
+
+/// Returns the TVM version string this crate was built against.
+///
+/// Returns `"Invalid UTF-8 string"` if the version bytes are not valid UTF-8.
+pub fn version() -> &'static str {
+    match str::from_utf8(ffi::TVM_VERSION) {
+        // String-macro constants generated by bindgen include the C terminator, so strip
+        // any trailing NUL rather than leaking it into the returned `&str`.
+        Ok(s) => s.trim_end_matches('\0'),
+        Err(_) => "Invalid UTF-8 string",
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn print_version() {
+        // Smoke test: reading the version string must not panic.
+        let v = version();
+        println!("TVM version: {}", v);
+    }
+
+    #[test]
+    fn set_error() {
+        // Round-trip an error message through TVM's last-error slot.
+        let expected = errors::NDArrayError::EmptyArray.to_string();
+        set_last_error(&errors::NDArrayError::EmptyArray);
+        let actual = get_last_error().trim();
+        assert_eq!(actual, expected);
+    }
+}
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
new file mode 100644
index 0000000..b540c1b
--- /dev/null
+++ b/rust/tvm-rt/src/module.rs
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Provides the [`Module`] type and methods for working with runtime TVM modules.
+
+use std::{
+ ffi::CString,
+ os::raw::{c_char, c_int},
+ path::Path,
+ ptr,
+};
+
+use tvm_sys::ffi;
+
+use crate::errors::Error;
+use crate::{errors, function::Function};
+
+// Name of the function `Module::entry` looks up (without querying imports) and caches.
+const ENTRY_FUNC: &str = "__tvm_main__";
+
+/// Wrapper around a TVM module handle which also caches the module's entry function.
+/// The entry function can be retrieved with [`entry`].
+///
+/// [`entry`]:struct.Module.html#method.entry
+#[derive(Debug)]
+pub struct Module {
+    pub(crate) handle: ffi::TVMModuleHandle,
+    entry_func: Option<Function>,
+}
+
+impl Clone for Module {
+    /// Clones the wrapper and takes an extra reference on the underlying module.
+    ///
+    /// `Drop` calls `TVMModFree` on the handle. With the former `#[derive(Clone)]`, a
+    /// bitwise copy led both values to free the same handle, which is a double free.
+    fn clone(&self) -> Self {
+        check_call!(ffi::TVMObjectRetain(self.handle));
+        Module {
+            handle: self.handle,
+            entry_func: self.entry_func.clone(),
+        }
+    }
+}
+
+// Bindings to global TVM functions used by `Module`:
+// - `runtime.RuntimeEnabled`: takes a target name; `Module::enabled` treats a nonzero
+//   result as "enabled".
+// - `runtime.ModuleLoadFromFile`: takes a file path and a format string and returns a module.
+crate::external! {
+ #[name("runtime.RuntimeEnabled")]
+ fn runtime_enabled(target: CString) -> i32;
+
+ #[name("runtime.ModuleLoadFromFile")]
+ fn load_from_file(file_name: CString, format: CString) -> Module;
+}
+
+impl Module {
+    /// Wraps a raw module handle; the entry function is looked up lazily by [`Module::entry`].
+    pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
+        Self {
+            handle,
+            entry_func: None,
+        }
+    }
+
+    /// Returns the module's entry function (`__tvm_main__`).
+    ///
+    /// The lookup happens on first use and its result is cached. Returns `None` if the lookup fails.
+    pub fn entry(&mut self) -> Option<Function> {
+        if self.entry_func.is_none() {
+            self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
+        }
+        self.entry_func.clone()
+    }
+
+    /// Gets a function by name from this module.
+    ///
+    /// `query_import` is forwarded to `TVMModGetFunction` and controls whether imported
+    /// modules are searched as well.
+    ///
+    /// Returns `Error::NullHandle` when no such function exists, and `Error::ToCString`
+    /// when `name` contains an interior NUL byte.
+    pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
+        let name = CString::new(name)?;
+        let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+        check_call!(ffi::TVMModGetFunction(
+            self.handle,
+            name.as_ptr() as *const c_char,
+            query_import as c_int,
+            &mut fhandle as *mut _
+        ));
+
+        // A null handle means "not found". The check used to be inverted, which rejected
+        // every successful lookup and wrapped the null handle on failure.
+        if fhandle.is_null() {
+            return Err(errors::Error::NullHandle(name.into_string()?));
+        }
+
+        Ok(Function::new(fhandle))
+    }
+
+    /// Imports a dependent module such as `.ptx` for gpu.
+    pub fn import_module(&self, dependent_module: Module) {
+        check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
+    }
+
+    /// Loads a module shared library from path.
+    ///
+    /// The file extension (empty if there is none) is passed as the format string.
+    /// Returns `Error::ModuleLoadPath` if the path or its extension is not valid UTF-8.
+    pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> {
+        let ext = CString::new(
+            path.as_ref()
+                .extension()
+                .unwrap_or_else(|| std::ffi::OsStr::new(""))
+                .to_str()
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?,
+        )?;
+
+        let cpath = CString::new(
+            path.as_ref()
+                .to_str()
+                .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?,
+        )?;
+
+        let module = load_from_file(cpath, ext)?;
+        Ok(module)
+    }
+
+    /// Checks if a target device is enabled in the TVM runtime.
+    ///
+    /// The check calls the global `runtime.RuntimeEnabled` and does not use `self`.
+    /// Panics if `target` contains a NUL byte or if the call fails.
+    pub fn enabled(&self, target: &str) -> bool {
+        let target = CString::new(target).unwrap();
+        let enabled = runtime_enabled(target).unwrap();
+        enabled != 0
+    }
+
+    /// Returns the underlying module handle.
+    pub fn handle(&self) -> ffi::TVMModuleHandle {
+        self.handle
+    }
+}
+
+// Releases the module handle through `TVMModFree`; `check_call!` panics if TVM reports an error.
+impl Drop for Module {
+ fn drop(&mut self) {
+ check_call!(ffi::TVMModFree(self.handle));
+ }
+}
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
new file mode 100644
index 0000000..b7ae462
--- /dev/null
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module implements the [`NDArray`] type for working with *TVM tensors* and
+//! for converting a Rust `ndarray` array into a TVM `NDArray`.
+//!
+//! One can create an empty NDArray given the shape, context and dtype using [`empty`].
+//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`].
+//! To copy an NDArray to different context use [`copy_to_ctx`].
+//!
+//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows:
+//!
+//! # Example
+//!
+//! ```
+//! # use tvm_rt::{NDArray, Context, DataType};
+//! # use ndarray::{Array, ArrayD};
+//! # use std::str::FromStr;
+//! use std::convert::TryFrom;
+//!
+//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+//! .unwrap()
+//! .into_dyn(); // Rust's ndarray
+//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap();
+//! assert_eq!(nd.shape(), Some(&mut [2, 2][..]));
+//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+//! assert!(rnd.all_close(&a, 1e-8f32));
+//! ```
+//!
+//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/
+//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
+//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
+
+use std::convert::TryInto;
+use std::ffi::c_void;
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
+
+use crate::errors::NDArrayError;
+
+use tvm_sys::ffi::DLTensor;
+use tvm_sys::{ffi, ByteArray, Context, DataType};
+
+use ndarray::{Array, ArrayD};
+use num_traits::Num;
+
+/// See the [`module-level documentation`](../ndarray/index.html) for more details.
+///
+/// Wrapper around TVM array handle.
+///
+/// - `Borrowed`: wraps a `TVMArrayHandle` that `Drop` does not free.
+/// - `Owned`: wraps a raw handle that `Drop` frees with `TVMArrayFree`.
+#[derive(Debug)]
+pub enum NDArray {
+ Borrowed { handle: ffi::TVMArrayHandle },
+ Owned { handle: *mut c_void },
+}
+
+impl NDArray {
+    /// Wraps a `TVMArrayHandle` without taking ownership; dropping the result does not free it.
+    pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self {
+        NDArray::Borrowed { handle }
+    }
+
+    /// Wraps a raw handle as an owned array; dropping the result calls `TVMArrayFree` on it.
+    pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self {
+        NDArray::Owned { handle }
+    }
+
+    /// Returns a reference to the underlying `DLTensor`.
+    pub fn as_dltensor(&self) -> &DLTensor {
+        // SAFETY(review): assumes the handle is non-null and points to a live `DLTensor`
+        // for at least as long as `self` — confirm for every constructor.
+        unsafe { &*self.as_raw_dltensor() }
+    }
+
+    /// Returns the raw `DLTensor` pointer for passing to the TVM C API.
+    pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor {
+        match self {
+            NDArray::Borrowed { handle } => *handle,
+            NDArray::Owned { handle } => *handle as *mut DLTensor,
+        }
+    }
+
+    /// Returns `true` if this array is a non-owning (`Borrowed`) wrapper.
+    pub fn is_view(&self) -> bool {
+        match self {
+            NDArray::Borrowed { .. } => true,
+            NDArray::Owned { .. } => false,
+        }
+    }
+
+    /// Returns the shape of the NDArray, or `None` if the shape or data pointer is null.
+    ///
+    /// NOTE(review): the shape is reinterpreted from the tensor's shape pointer as `usize`,
+    /// which assumes 64-bit `usize`. The slice is `&mut` although `self` is only shared.
+    pub fn shape(&self) -> Option<&mut [usize]> {
+        let arr = self.as_dltensor();
+        if arr.shape.is_null() || arr.data.is_null() {
+            return None;
+        }
+        let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) };
+        Some(slc)
+    }
+
+    /// Returns the total number of entries of the NDArray.
+    pub fn size(&self) -> Option<usize> {
+        self.shape().map(|v| v.iter().product())
+    }
+
+    /// Returns the context which the NDArray was defined.
+    pub fn ctx(&self) -> Context {
+        self.as_dltensor().ctx.into()
+    }
+
+    /// Returns the type of the entries of the NDArray.
+    pub fn dtype(&self) -> DataType {
+        self.as_dltensor().dtype.into()
+    }
+
+    /// Returns the number of dimensions of the NDArray.
+    pub fn ndim(&self) -> usize {
+        self.as_dltensor()
+            .ndim
+            .try_into()
+            .expect("number of dimensions must always be positive")
+    }
+
+    /// Returns the strides (in elements) of the underlying NDArray.
+    ///
+    /// Returns `None` when the tensor has no explicit strides (a null `strides` pointer,
+    /// which DLPack uses for compact row-major tensors).
+    pub fn strides(&self) -> Option<&[usize]> {
+        let arr = self.as_dltensor();
+        if arr.strides.is_null() {
+            return None;
+        }
+        // The strides array holds exactly `ndim` entries. The old length,
+        // `ndim * size_of::<usize>()`, read past its end.
+        let slc = unsafe { slice::from_raw_parts(arr.strides as *const usize, self.ndim()) };
+        Some(slc)
+    }
+
+    /// Shows whether the underlying ndarray is contiguous in memory or not.
+    ///
+    /// An array without explicit strides is contiguous. Otherwise each stride must equal
+    /// the product of the extents of the dimensions to its right.
+    pub fn is_contiguous(&self) -> Result<bool, crate::errors::Error> {
+        Ok(match self.strides() {
+            None => true,
+            Some(strides) => {
+                // NDArrayError::MissingShape in case shape is not determined
+                self.shape()
+                    .ok_or(NDArrayError::MissingShape)?
+                    .iter()
+                    .zip(strides)
+                    .rfold(
+                        (true, 1),
+                        |(is_contig, expected_stride), (shape, stride)| {
+                            (
+                                is_contig && *stride == expected_stride,
+                                expected_stride * (*shape as usize),
+                            )
+                        },
+                    )
+                    .0
+            }
+        })
+    }
+
+    /// Returns the tensor's `byte_offset` field.
+    pub fn byte_offset(&self) -> isize {
+        self.as_dltensor().byte_offset as isize
+    }
+
+    /// Copies this array into a freshly allocated CPU array of the same shape and dtype.
+    /// Returns `EmptyArray` when the shape is unavailable.
+    fn to_cpu_copy(&self) -> Result<NDArray, NDArrayError> {
+        let shape = self.shape().ok_or(NDArrayError::EmptyArray)?;
+        let cpu_copy = NDArray::empty(shape, Context::cpu(0), self.dtype());
+        self.copy_to_ndarray(cpu_copy)
+    }
+
+    /// Flattens the NDArray to a `Vec` of the same type in cpu.
+    ///
+    /// NOTE(review): `T` is not checked against the array's dtype; the caller must pick a
+    /// `T` whose size matches one element.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// # use tvm_rt::{Context, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let mut shape = [4];
+    /// let mut data = vec![1i32, 2, 3, 4];
+    /// let ctx = Context::cpu(0);
+    /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap());
+    /// ndarray.copy_from_buffer(&mut data);
+    /// assert_eq!(ndarray.shape(), Some(&mut shape[..]));
+    /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+    /// ```
+    pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> {
+        let target = self.to_cpu_copy()?;
+        let arr = target.as_dltensor();
+        let sz = self.size().ok_or(NDArrayError::MissingShape)?;
+        // `with_capacity` counts elements, not bytes.
+        let mut v: Vec<T> = Vec::with_capacity(sz);
+        unsafe {
+            v.as_mut_ptr()
+                .copy_from_nonoverlapping(arr.data as *const T, sz);
+            v.set_len(sz);
+        }
+        Ok(v)
+    }
+
+    /// Converts the NDArray to a [`ByteArray`] holding every byte of its elements.
+    pub fn to_bytearray(&self) -> Result<ByteArray, NDArrayError> {
+        // `to_vec::<u8>()` would copy one byte per element, truncating any dtype wider
+        // than 8 bits. Scale the element count by the element width instead.
+        let dtype: ffi::DLDataType = self.dtype().into();
+        let elem_bytes = (dtype.bits as usize * dtype.lanes as usize + 7) / 8;
+        let target = self.to_cpu_copy()?;
+        let n_bytes = self.size().ok_or(NDArrayError::MissingShape)? * elem_bytes;
+        let bytes = unsafe {
+            slice::from_raw_parts(target.as_dltensor().data as *const u8, n_bytes).to_vec()
+        };
+        Ok(ByteArray::from(bytes))
+    }
+
+    /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu.
+    ///
+    /// ## Example
+    ///
+    /// ```
+    /// # use tvm_rt::{Context, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let shape = &mut [2];
+    /// let mut data = vec![1f32, 2.0];
+    /// let ctx = Context::cpu(0);
+    /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("float32").unwrap());
+    /// ndarray.copy_from_buffer(&mut data);
+    /// ```
+    ///
+    /// *Note*: if something goes wrong during the copy, it will panic
+    /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`.
+    pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) {
+        check_call!(ffi::TVMArrayCopyFromBytes(
+            self.as_raw_dltensor(),
+            data.as_ptr() as *mut _,
+            data.len() * mem::size_of::<T>()
+        ));
+    }
+
+    /// Copies the NDArray into `target` and returns it.
+    ///
+    /// Returns `DataTypeMismatch` if the two dtypes differ.
+    pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> {
+        if self.dtype() != target.dtype() {
+            return Err(NDArrayError::DataTypeMismatch {
+                expected: self.dtype(),
+                actual: target.dtype(),
+            });
+        }
+
+        check_call!(ffi::TVMArrayCopyFromTo(
+            self.as_raw_dltensor(),
+            target.as_raw_dltensor(),
+            ptr::null_mut() as ffi::TVMStreamHandle
+        ));
+
+        Ok(target)
+    }
+
+    /// Copies the NDArray to a target context.
+    pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray, NDArrayError> {
+        let tmp = NDArray::empty(
+            self.shape().ok_or(NDArrayError::MissingShape)?,
+            *target,
+            self.dtype(),
+        );
+        let copy = self.copy_to_ndarray(tmp)?;
+        Ok(copy)
+    }
+
+    /// Converts a Rust's ndarray to TVM NDArray.
+    ///
+    /// NOTE(review): `dtype` is not checked against `T`; the caller must pass a matching dtype.
+    pub fn from_rust_ndarray<T: Num32 + Copy>(
+        rnd: &ArrayD<T>,
+        ctx: Context,
+        dtype: DataType,
+    ) -> Result<Self, NDArrayError> {
+        let mut nd = NDArray::empty(rnd.shape(), ctx, dtype);
+        // Collect into a fresh (hence contiguous) buffer in logical iteration order.
+        let mut buf = Array::from_iter(rnd.iter().cloned());
+        nd.copy_from_buffer(
+            buf.as_slice_mut()
+                .expect("Array from iter must be contiguous."),
+        );
+        Ok(nd)
+    }
+
+    /// Allocates and creates an empty NDArray given the shape, context and dtype.
+    ///
+    /// NOTE(review): the result is tagged `Borrowed`, so `Drop` never calls `TVMArrayFree`
+    /// on memory allocated here. This looks like a leak; confirm that nothing else frees it
+    /// before switching the tag to `Owned`.
+    pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray {
+        let mut handle = ptr::null_mut() as ffi::TVMArrayHandle;
+        let dtype: tvm_sys::ffi::DLDataType = dtype.into();
+        check_call!(ffi::TVMArrayAlloc(
+            shape.as_ptr() as *const i64,
+            shape.len() as c_int,
+            i32::from(dtype.code) as c_int,
+            i32::from(dtype.bits) as c_int,
+            i32::from(dtype.lanes) as c_int,
+            ctx.device_type as c_int,
+            ctx.device_id as c_int,
+            &mut handle as *mut _,
+        ));
+        NDArray::Borrowed { handle }
+    }
+}
+
+/// Implements `TryFrom<&NDArray>` and `TryFrom<&mut NDArray>` for `ArrayD<$type>`.
+/// `$type_name` is the TVM dtype string the array must match.
+macro_rules! impl_from_ndarray_rustndarray {
+    ($type:ty, $type_name:tt) => {
+        impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> {
+            type Error = NDArrayError;
+
+            /// Copies the array to the CPU and rebuilds it as a Rust `ArrayD`.
+            ///
+            /// Returns `MissingShape` when the shape is unavailable, and `DataTypeMismatch`
+            /// (rather than panicking, as before) when the dtype is not `$type_name`.
+            fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
+                let shape = nd.shape().ok_or(NDArrayError::MissingShape)?;
+                let expected = DataType::from_str($type_name)?;
+                let actual = nd.dtype();
+                if actual != expected {
+                    return Err(NDArrayError::DataTypeMismatch { expected, actual });
+                }
+                Ok(Array::from_shape_vec(&*shape, nd.to_vec::<$type>()?)?)
+            }
+        }
+
+        impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> {
+            type Error = NDArrayError;
+
+            /// Same as the shared-reference conversion, to which it delegates; `nd` is only read.
+            fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
+                ArrayD::<$type>::try_from(&*nd)
+            }
+        }
+    };
+}
+
+impl_from_ndarray_rustndarray!(i32, "int");
+impl_from_ndarray_rustndarray!(u32, "uint");
+impl_from_ndarray_rustndarray!(f32, "float");
+
+impl Drop for NDArray {
+    /// Frees owned arrays with `TVMArrayFree`; borrowed arrays are left untouched.
+    fn drop(&mut self) {
+        match *self {
+            NDArray::Owned { .. } => {
+                check_call!(ffi::TVMArrayFree(self.as_raw_dltensor()));
+            }
+            NDArray::Borrowed { .. } => {}
+        }
+    }
+}
+
+mod sealed {
+ /// Private supertrait of `Num32`: this module is private, so code outside the crate
+ /// cannot name `Sealed` and therefore cannot implement `Num32`.
+ pub trait Sealed {}
+}
+
+/// A trait for the supported 32-bit numerical types in the frontend.
+///
+/// It is implemented for `i32`, `u32` and `f32`. The `sealed::Sealed` supertrait prevents
+/// other crates from adding implementations.
+pub trait Num32: Num + sealed::Sealed {
+ // Bit width shared by every implementing type.
+ const BITS: u8 = 32;
+}
+
+/// Marks each listed primitive type as `Num32`, along with the `sealed::Sealed`
+/// supertrait that `Num32` requires.
+macro_rules! impl_num32 {
+    ($($num:ty),+) => {
+        $(
+            impl Num32 for $num {}
+            impl sealed::Sealed for $num {}
+        )+
+    };
+}
+
+impl_num32!(i32, u32, f32);
+
+// Unit tests for `NDArray` allocation, copying and conversion on the CPU context.
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ // Checks shape, size, ndim, strides and byte offset of a freshly allocated array.
+ // NOTE(review): the `strides().is_none()` assertion presumes freshly allocated arrays
+ // have a null strides pointer — confirm.
+ #[test]
+ fn basics() {
+ let shape = &mut [1, 2, 3];
+ let ctx = Context::cpu(0);
+ let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
+ assert_eq!(ndarray.shape().unwrap(), shape);
+ assert_eq!(
+ ndarray.size().unwrap(),
+ shape.to_vec().into_iter().product()
+ );
+ assert_eq!(ndarray.ndim(), 3);
+ assert!(ndarray.strides().is_none());
+ assert_eq!(ndarray.byte_offset(), 0);
+ }
+
+ // Round-trips a buffer through an array and copies it into a second array.
+ #[test]
+ fn copy() {
+ let shape = &mut [4];
+ let mut data = vec![1i32, 2, 3, 4];
+ let ctx = Context::cpu(0);
+ let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap());
+ assert!(ndarray.to_vec::<i32>().is_ok());
+ ndarray.copy_from_buffer(&mut data);
+ assert_eq!(ndarray.shape().unwrap(), shape);
+ assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
+ assert_eq!(ndarray.ndim(), 1);
+ assert!(ndarray.is_contiguous().is_ok());
+ assert_eq!(ndarray.byte_offset(), 0);
+ let shape = vec![4];
+ let e = NDArray::empty(
+ &shape,
+ Context::cpu(0),
+ DataType::from_str("int32").unwrap(),
+ );
+ let nd = ndarray.copy_to_ndarray(e);
+ assert!(nd.is_ok());
+ assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data);
+ }
+
+ // Copying between arrays of different dtypes must yield `Err`, so `unwrap` panics.
+ #[test]
+ #[should_panic(expected = "called `Result::unwrap()` on an `Err`")]
+ fn copy_wrong_dtype() {
+ let shape = vec![4];
+ let mut data = vec![1f32, 2., 3., 4.];
+ let ctx = Context::cpu(0);
+ let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap());
+ nd_float.copy_from_buffer(&mut data);
+ let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap());
+ nd_float.copy_to_ndarray(empty_int).unwrap();
+ }
+
+ // Converts a Rust `ndarray` to a TVM array and back, then compares the values.
+ #[test]
+ fn rust_ndarray() {
+ let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.])
+ .unwrap()
+ .into_dyn();
+ let nd =
+ NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap())
+ .unwrap();
+ assert_eq!(nd.shape().unwrap(), &mut [2, 2]);
+ let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap();
+ assert!(rnd.all_close(&a, 1e-8f32));
+ }
+}
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
new file mode 100644
index 0000000..c49f84e
--- /dev/null
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::convert::TryFrom;
+use std::convert::TryInto;
+use std::ffi::CString;
+
+use crate::errors::Error;
+use crate::external;
+
+use tvm_sys::{ArgValue, RetValue};
+
+mod object_ptr;
+
+pub use object_ptr::{IsObject, Object, ObjectPtr};
+
+/// A nullable handle to an arbitrary TVM object.
+///
+/// `None` is the null reference. `Some` holds an owning `ObjectPtr`, so
+/// cloning or dropping an `ObjectRef` adjusts the object's reference count.
+#[derive(Clone)]
+pub struct ObjectRef(pub Option<ObjectPtr<Object>>);
+
+impl ObjectRef {
+ /// Returns the null reference, which points at no object.
+ pub fn null() -> ObjectRef {
+ Self(None)
+ }
+}
+
+/// Types that can produce an `ObjectRef` for themselves.
+pub trait ToObjectRef {
+ /// Returns an `ObjectRef` for `self` without consuming it.
+ fn to_object_ref(&self) -> ObjectRef;
+}
+
+impl ToObjectRef for ObjectRef {
+ /// Already an object reference: hand out a clone, which retains the
+ /// underlying object when non-null.
+ fn to_object_ref(&self) -> ObjectRef {
+ ObjectRef::clone(self)
+ }
+}
+
+impl TryFrom<RetValue> for ObjectRef {
+ type Error = Error;
+
+ /// Wraps the object handle carried by `ret_val` in a non-null `ObjectRef`.
+ ///
+ /// Fails for non-object values and for a null handle.
+ /// NOTE(review): `From<ObjectRef> for RetValue` maps the null `ObjectRef`
+ /// to a null handle, but such a handle does not convert back. Confirm
+ /// that this asymmetry is intended.
+ fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> {
+ let object_ptr: ObjectPtr<Object> = ret_val.try_into()?;
+ Ok(ObjectRef(Some(object_ptr)))
+ }
+}
+
+impl From<ObjectRef> for RetValue {
+ /// Moves the reference owned by `object_ref` into a `RetValue`.
+ ///
+ /// The null reference becomes a null `ObjectHandle`. Otherwise the
+ /// existing reference is transferred as-is. `object_ref` is consumed, so
+ /// there is no need for the clone-then-drop (an extra retain/release
+ /// round trip through the C API) that the previous version performed.
+ fn from(object_ref: ObjectRef) -> RetValue {
+ match object_ref.0 {
+ None => RetValue::ObjectHandle(std::ptr::null_mut()),
+ Some(value) => value.into(),
+ }
+ }
+}
+
+impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef {
+ type Error = Error;
+
+ /// Wraps the object handle carried by `arg_value` in a non-null
+ /// `ObjectRef`. Fails for non-object values and for a null handle.
+ fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
+ arg_value
+ .try_into()
+ .map(|object_ptr: ObjectPtr<Object>| ObjectRef(Some(object_ptr)))
+ }
+}
+
+impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef {
+ type Error = Error;
+
+ /// Borrowing variant: converts a copy of `arg_value` with the by-value
+ /// impl.
+ ///
+ /// NOTE(review): assuming `ArgValue::clone` only copies the handle, no
+ /// extra retain happens here. The result then adopts the handle exactly
+ /// as the by-value conversion would; confirm that callers account for
+ /// this. TODO(@jroesch): remove the clone.
+ fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error> {
+ ObjectRef::try_from(ArgValue::clone(arg_value))
+ }
+}
+
+impl<'a> From<ObjectRef> for ArgValue<'a> {
+ /// Moves the reference owned by `object_ref` into an `ArgValue`.
+ ///
+ /// The null reference becomes a null `ObjectHandle`. Otherwise the
+ /// existing reference is transferred as-is, instead of being cloned
+ /// (retain) and then released again when `object_ref` is dropped.
+ fn from(object_ref: ObjectRef) -> ArgValue<'a> {
+ match object_ref.0 {
+ None => ArgValue::ObjectHandle(std::ptr::null_mut()),
+ Some(value) => value.into(),
+ }
+ }
+}
+
+impl<'a> From<&ObjectRef> for ArgValue<'a> {
+ /// Clones the reference (retaining the object when non-null) and hands
+ /// the clone to the by-value conversion.
+ fn from(object_ref: &ObjectRef) -> ArgValue<'a> {
+ ArgValue::<'a>::from(ObjectRef::clone(object_ref))
+ }
+}
+
+// Typed wrapper around the global packed function "ir.DebugPrint". It takes
+// an `ObjectRef` and (judging by the name) returns a debug printout of it as
+// a `CString`. The wrapper itself is generated by `external!`. The
+// commented-out test in string.rs applies `?` to its result, so it is
+// presumably fallible.
+external! {
+ #[name("ir.DebugPrint")]
+ fn debug_print(object: ObjectRef) -> CString;
+}
+
+// external! {
+// #[name("ir.TextPrinter")]
+// fn as_text(object: ObjectRef) -> CString;
+// }
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
new file mode 100644
index 0000000..40e2184
--- /dev/null
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::convert::TryFrom;
+use std::ffi::CString;
+use std::ptr::NonNull;
+use std::sync::atomic::AtomicI32;
+
+use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index};
+use tvm_sys::{ArgValue, RetValue};
+
+use crate::errors::Error;
+
+/// Signature of the per-object destructor stored in `Object::fdeleter`.
+type Deleter = unsafe extern "C" fn(object: *mut Object);
+
+/// Header embedded in every TVM object.
+///
+/// `#[repr(C)]`: presumably mirrors the C++ runtime's object header, so the
+/// field order must not change.
+#[derive(Debug)]
+#[repr(C)]
+pub struct Object {
+ /// Runtime type index (see `Object::get_type_index`).
+ pub type_index: u32,
+ // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure.
+ // NB: in general we should not touch this in Rust; the count is changed
+ // through TVMObjectRetain / TVMObjectFree.
+ ref_count: AtomicI32,
+ /// Destructor for this object, presumably invoked by the runtime once the
+ /// reference count drops to zero.
+ pub fdeleter: Deleter,
+}
+
+/// Type-erased deleter installed by `Object::base_object::<T>()`: recovers
+/// the concrete `T` from the header pointer and hands it to
+/// `T::typed_delete`.
+unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
+ let typed = object as *mut T;
+ T::typed_delete(typed)
+}
+
+/// Asks the TVM runtime (`TVMObjectDerivedFrom`) whether the type with index
+/// `child_type_index` derives from the type with index `parent_type_index`.
+fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool {
+ let mut flag = 0;
+ crate::check_call!(ffi::TVMObjectDerivedFrom(
+ child_type_index,
+ parent_type_index,
+ &mut flag
+ ));
+ flag != 0
+}
+
+impl Object {
+ /// Builds an object header with the given runtime type index and deleter.
+ ///
+ /// The reference count starts at 0; the first reference is taken by
+ /// `ObjectPtr::new`. From then on the count is changed only through the
+ /// C API (`TVMObjectRetain` / `TVMObjectFree`), never directly from Rust.
+ fn new(type_index: u32, deleter: Deleter) -> Object {
+ Object {
+ type_index,
+ // Starts at 0, not 1 (see above). In the future this should
+ // probably use C atomics to guarantee ABI compatibility with C++.
+ ref_count: AtomicI32::new(0),
+ fdeleter: deleter,
+ }
+ }
+
+ /// Resolves the runtime type index registered for `T::TYPE_KEY`.
+ ///
+ /// The root key `"Object"` always maps to index 0 without consulting the
+ /// runtime. Panics if the key contains a NUL byte or if
+ /// `TVMObjectTypeKey2Index` reports an error.
+ fn get_type_index<T: IsObject>() -> u32 {
+ let type_key = T::TYPE_KEY;
+ if type_key == "Object" {
+ return 0;
+ }
+
+ let cstring = CString::new(type_key).expect("type key must not contain null characters");
+ let mut index: u32 = 0;
+ // `&mut index` coerces to the `*mut u32` out-parameter; no transmute needed.
+ let status = unsafe { TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) };
+ if status != 0 {
+ panic!("{}", crate::get_last_error());
+ }
+ index
+ }
+
+ /// Creates the header for a Rust-allocated object of type `T`: `T`'s
+ /// registered type index plus a deleter that frees the object as a `T`.
+ pub fn base_object<T: IsObject>() -> Object {
+ let index = Object::get_type_index::<T>();
+ Object::new(index, delete::<T>)
+ }
+
+ /// Increments the reference count via `TVMObjectRetain`; panics if the
+ /// call reports an error.
+ pub(self) fn inc_ref(&self) {
+ let handle = self as *const Object as *mut std::ffi::c_void;
+ unsafe {
+ assert_eq!(TVMObjectRetain(handle), 0);
+ }
+ }
+
+ /// Decrements the reference count via `TVMObjectFree`; panics if the call
+ /// reports an error. The runtime presumably destroys the object when the
+ /// count reaches zero, so `self` must not be used afterwards.
+ pub(self) fn dec_ref(&self) {
+ let handle = self as *const Object as *mut std::ffi::c_void;
+ unsafe {
+ assert_eq!(TVMObjectFree(handle), 0);
+ }
+ }
+}
+
+/// Implemented by every Rust type that can live behind an `ObjectPtr`.
+///
+/// # Safety
+///
+/// Pointers to implementors are cast to and from `*mut Object` (see
+/// `delete`, `ObjectPtr::upcast` and `ObjectPtr::downcast`). An implementor
+/// therefore presumably must be `#[repr(C)]` and begin with its `Object`
+/// header. TODO: confirm the full contract.
+pub unsafe trait IsObject {
+ /// Key under which the type is registered with the TVM runtime.
+ const TYPE_KEY: &'static str;
+
+ /// Borrows the embedded object header.
+ fn as_object<'s>(&'s self) -> &'s Object;
+
+ /// Frees an object that was allocated as a `Box<Self>` (the way
+ /// `ObjectPtr::new` allocates).
+ unsafe extern "C" fn typed_delete(object: *mut Self) {
+ drop(Box::from_raw(object));
+ }
+}
+
+// The root of the object hierarchy: registered under the key "Object"
+// (which `Object::get_type_index` maps to index 0).
+unsafe impl IsObject for Object {
+ const TYPE_KEY: &'static str = "Object";
+
+ // The base object is its own header.
+ fn as_object<'s>(&'s self) -> &'s Object {
+ self
+ }
+}
+
+/// An owning, reference-counted pointer to a TVM object of type `T`.
+///
+/// Cloning retains the object and dropping releases it (see the `Clone` and
+/// `Drop` impls). The struct is `#[repr(C)]` around a single non-null
+/// pointer.
+#[repr(C)]
+pub struct ObjectPtr<T: IsObject> {
+ /// Non-null pointer to the object.
+ pub ptr: NonNull<T>,
+}
+
+/// Retains the object behind `ptr` (reference count + 1).
+fn inc_ref<T: IsObject>(ptr: NonNull<T>) {
+ let object: &T = unsafe { ptr.as_ref() };
+ object.as_object().inc_ref()
+}
+
+/// Releases the object behind `ptr` (reference count - 1); the runtime
+/// presumably frees the object when the count reaches zero.
+fn dec_ref<T: IsObject>(ptr: NonNull<T>) {
+ let object: &T = unsafe { ptr.as_ref() };
+ object.as_object().dec_ref()
+}
+
+impl ObjectPtr<Object> {
+ /// Wraps a raw, possibly-null object pointer *without* retaining it.
+ ///
+ /// Returns `None` for null. The returned pointer releases one reference
+ /// when dropped, so the caller must already own that reference.
+ fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
+ match NonNull::new(object_ptr) {
+ Some(ptr) => Some(ObjectPtr { ptr }),
+ None => None,
+ }
+ }
+}
+
+impl<T: IsObject> Clone for ObjectPtr<T> {
+ /// Shares ownership: retains the object, then returns a second pointer to it.
+ fn clone(&self) -> Self {
+ let ptr = self.ptr;
+ inc_ref(ptr);
+ Self { ptr }
+ }
+}
+
+impl<T: IsObject> Drop for ObjectPtr<T> {
+ /// Gives up this pointer's reference to the object.
+ fn drop(&mut self) {
+ let ptr = self.ptr;
+ dec_ref(ptr);
+ }
+}
+
+impl<T: IsObject> ObjectPtr<T> {
+ /// Consumes the pointer *without* releasing its reference and returns a
+ /// mutable borrow of the object. The leaked reference stays with whoever
+ /// keeps the raw address (e.g. an `ObjectHandle` value).
+ pub fn leak<'a>(object_ptr: ObjectPtr<T>) -> &'a mut T
+ where
+ T: 'a,
+ {
+ unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() }
+ }
+
+ /// Moves `object` into a fresh heap allocation and takes a reference to
+ /// it. For a header from `Object::base_object`, the count goes 0 -> 1.
+ pub fn new(object: T) -> ObjectPtr<T> {
+ let object_ptr = Box::new(object);
+ let object_ptr = Box::leak(object_ptr);
+ let ptr = NonNull::from(object_ptr);
+ inc_ref(ptr);
+ ObjectPtr { ptr }
+ }
+
+ /// Reads the current reference count with a sequentially consistent load.
+ pub fn count(&self) -> i32 {
+ // need to do atomic read in C++
+ // ABI compatible atomics is funky/hard.
+ self.as_object()
+ .ref_count
+ .load(std::sync::atomic::Ordering::SeqCst)
+ }
+
+ /// Borrows the object's header.
+ fn as_object<'s>(&'s self) -> &'s Object {
+ unsafe { self.ptr.as_ref().as_object() }
+ }
+
+ /// Reinterprets this pointer as a pointer to the base `Object`.
+ ///
+ /// NOTE(review): the returned `ObjectPtr` shares `self.ptr`, but no
+ /// retain is performed, even though both values release a reference on
+ /// drop. The conversions in this file (and the expected counts in the
+ /// tests) currently depend on this, so they must be fixed together.
+ pub fn upcast(&self) -> ObjectPtr<Object> {
+ ObjectPtr {
+ ptr: self.ptr.cast(),
+ }
+ }
+
+ /// Checks whether the object's runtime type is `U` or derives from it. If
+ /// so, returns a pointer typed as `U`.
+ ///
+ /// NOTE(review): like `upcast`, this builds a second owning pointer from
+ /// `&self` without a retain. The error also reports a placeholder
+ /// ("TODOget_type_key") instead of the object's actual type key.
+ pub fn downcast<U: IsObject>(&self) -> Result<ObjectPtr<U>, Error> {
+ let child_index = Object::get_type_index::<U>();
+ let object_index = self.as_object().type_index;
+
+ let is_derived = if child_index == object_index {
+ true
+ } else {
+ // TODO(@jroesch): write tests
+ derived_from(object_index, child_index)
+ };
+
+ if is_derived {
+ Ok(ObjectPtr {
+ ptr: self.ptr.cast(),
+ })
+ } else {
+ Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
+ }
+ }
+}
+
+impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
+ type Target = T;
+
+ /// Borrows the pointee; this pointer's reference keeps it alive meanwhile.
+ fn deref(&self) -> &T {
+ let raw: *const T = self.ptr.as_ptr();
+ unsafe { &*raw }
+ }
+}
+
+impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue {
+ /// Transfers the pointer's reference into a `RetValue::ObjectHandle`.
+ ///
+ /// The reference is not released here: `ObjectPtr::leak` skips the drop,
+ /// so the handle carries it onward.
+ fn from(object_ptr: ObjectPtr<T>) -> RetValue {
+ let object: *mut T = ObjectPtr::leak(object_ptr);
+ RetValue::ObjectHandle(object as *mut std::ffi::c_void)
+ }
+}
+
+impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
+ type Error = Error;
+
+ /// Adopts the object handle held by `ret_value` (without retaining it)
+ /// and checks that the object's runtime type is, or derives from, `T`.
+ ///
+ /// Fails with `Error::Null` for a null handle, and with a downcast error
+ /// for a non-object value or an incompatible type.
+ fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> {
+ if let RetValue::ObjectHandle(handle) = ret_value {
+ let raw = handle as *mut Object;
+ let untyped = ObjectPtr::from_raw(raw).ok_or(Error::Null)?;
+ return untyped.downcast();
+ }
+ Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle"))
+ }
+}
+
+impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
+ /// Transfers the pointer's reference into an `ArgValue::ObjectHandle`.
+ ///
+ /// The reference is not released here: `ObjectPtr::leak` skips the drop,
+ /// so the handle carries it onward.
+ fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
+ let object: *mut T = ObjectPtr::leak(object_ptr);
+ ArgValue::ObjectHandle(object as *mut std::ffi::c_void)
+ }
+}
+
+impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
+ type Error = Error;
+
+ /// Adopts the object handle held by `arg_value` (without retaining it)
+ /// and checks that the object's runtime type is, or derives from, `T`.
+ ///
+ /// Fails with `Error::Null` for a null handle, and with a downcast error
+ /// for a non-object value or an incompatible type.
+ fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
+ if let ArgValue::ObjectHandle(handle) = arg_value {
+ let raw = handle as *mut Object;
+ let untyped = ObjectPtr::from_raw(raw).ok_or(Error::Null)?;
+ return untyped.downcast();
+ }
+ Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle"))
+ }
+}
+
+impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T> {
+ type Error = Error;
+
+ /// Converts a borrowed `ArgValue::ObjectHandle` into a typed pointer.
+ ///
+ /// `arg_value` is matched by reference, so `handle` is bound as
+ /// `&*mut c_void` and must be dereferenced to obtain the object address.
+ /// The previous `transmute(handle)` reinterpreted the address of the
+ /// handle *field* itself as an object header.
+ ///
+ /// NOTE(review): like the by-value conversion, this adopts the handle via
+ /// `ObjectPtr::from_raw` without a retain, even though the argument is
+ /// only borrowed. Confirm the reference-count accounting with callers.
+ fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
+ match arg_value {
+ ArgValue::ObjectHandle(handle) => {
+ let handle = *handle as *mut Object;
+ let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
+ optr.downcast()
+ }
+ _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
+ }
+ }
+}
+
+// Unit tests for `Object`/`ObjectPtr` creation, RetValue/ArgValue round
+// trips, and reference counts across a packed-function call.
+#[cfg(test)]
+mod tests {
+ use super::{Object, ObjectPtr};
+ use anyhow::{ensure, Result};
+ use std::convert::TryInto;
+ use tvm_sys::{ArgValue, RetValue};
+
+ // A freshly created object holds exactly one reference (taken by
+ // `ObjectPtr::new`).
+ #[test]
+ fn test_new_object() -> anyhow::Result<()> {
+ let object = Object::base_object::<Object>();
+ let ptr = ObjectPtr::new(object);
+ assert_eq!(ptr.count(), 1);
+ Ok(())
+ }
+
+ // Passes a clone through a `RetValue` and back, then checks that the
+ // header fields survive.
+ #[test]
+ fn roundtrip_retvalue() -> Result<()> {
+ let ptr = ObjectPtr::new(Object::base_object::<Object>());
+ let ret_value: RetValue = ptr.clone().into();
+ let ptr2: ObjectPtr<Object> = ret_value.try_into()?;
+ ensure!(
+ ptr.type_index == ptr2.type_index,
+ "type indices do not match"
+ );
+ ensure!(
+ ptr.fdeleter == ptr2.fdeleter,
+ "objects have different deleters"
+ );
+ Ok(())
+ }
+
+ // Same round trip, through an `ArgValue`.
+ #[test]
+ fn roundtrip_argvalue() -> Result<()> {
+ let ptr = ObjectPtr::new(Object::base_object::<Object>());
+ let arg_value: ArgValue = ptr.clone().into();
+ let ptr2: ObjectPtr<Object> = arg_value.try_into()?;
+ ensure!(
+ ptr.type_index == ptr2.type_index,
+ "type indices do not match"
+ );
+ ensure!(
+ ptr.fdeleter == ptr2.fdeleter,
+ "objects have different deleters"
+ );
+ Ok(())
+ }
+
+ // Registered as the packed function "my_func" below; asserts the
+ // reference count it observes inside the call.
+ fn test_fn(o: ObjectPtr<Object>) -> ObjectPtr<Object> {
+ assert_eq!(o.count(), 2);
+ return o;
+ }
+
+ #[test]
+ fn test_ref_count_boundary() {
+ use super::*;
+ use crate::function::{register, Function, Result};
+ // NOTE(review): the expected counts (2 inside `test_fn`, 1 after the
+ // call) encode the current accounting of the conversion path, in which
+ // `ObjectPtr::downcast` does not retain. Revisit them together with
+ // any change to that logic.
+ let ptr = ObjectPtr::new(Object::base_object::<Object>());
+ let stay = ptr.clone();
+ assert_eq!(ptr.count(), 2);
+ register(test_fn, "my_func").unwrap();
+ let func = Function::get("my_func").unwrap();
+ let func = func.to_boxed_fn::<dyn Fn(ObjectPtr<Object>) -> Result<ObjectPtr<Object>>>();
+ func(ptr).unwrap();
+ assert_eq!(stay.count(), 1);
+ }
+}
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
new file mode 100644
index 0000000..26758b1
--- /dev/null
+++ b/rust/tvm-rt/src/string.rs
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::ffi::{CString, NulError};
+use std::os::raw::c_char;
+
+use super::errors::Error;
+use super::{Object, ObjectPtr, ObjectRef};
+
+use tvm_macros::Object;
+
+// Rust-side layout of the TVM runtime string object ("runtime.String").
+// `#[derive(Object)]` appears to generate the `String` reference type named
+// by `ref_name`. The field order presumably mirrors the C++ string object
+// and must not change.
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "String"]
+#[type_key = "runtime.String"]
+pub struct StringObj {
+ // Common object header.
+ base: Object,
+ // Pointer to the character data.
+ data: *const c_char,
+ // Length of `data` in bytes; for Rust-created strings this excludes the
+ // NUL terminator.
+ size: u64,
+}
+
+impl String {
+ /// Creates a Rust-owned TVM `runtime.String` holding a copy of `string`.
+ ///
+ /// The bytes are stored NUL-terminated, and `size` records the length
+ /// without the terminator. Fails with `NulError` if `string` contains an
+ /// interior NUL byte.
+ pub fn new(string: std::string::String) -> Result<String, NulError> {
+ let cstring = CString::new(string)?;
+ let length = cstring.as_bytes().len();
+
+ let string_obj = StringObj {
+ base: Object::base_object::<StringObj>(),
+ // The buffer is handed to the object here. It is reclaimed by the
+ // `Drop` impl for `StringObj` below when the object is deleted.
+ data: cstring.into_raw(),
+ size: length as u64,
+ };
+
+ let object_ptr = ObjectPtr::new(string_obj);
+ Ok(String(Some(object_ptr)))
+ }
+
+ /// Copies the string's `size` bytes into a new `CString`.
+ ///
+ /// Panics if this `String` is null. Fails with `NulError` if the bytes
+ /// contain a NUL.
+ pub fn to_cstring(&self) -> Result<std::ffi::CString, NulError> {
+ use std::slice;
+ let ptr = self.0.as_ref().unwrap().data;
+ let size = self.0.as_ref().unwrap().size;
+ unsafe {
+ let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize);
+ CString::new(slice)
+ }
+ }
+
+ /// Copies the string into a Rust `String`. Fails on interior NULs or
+ /// invalid UTF-8.
+ pub fn to_string(&self) -> Result<std::string::String, Error> {
+ let string = self.to_cstring()?.into_string()?;
+ Ok(string)
+ }
+}
+
+// A `StringObj` constructed in Rust (only in `String::new`) owns a buffer
+// obtained from `CString::into_raw`; release it when the object is freed.
+// Handles to objects created elsewhere are released through `ObjectPtr`'s
+// `Drop` (`TVMObjectFree`), which never drops a `StringObj` value on the
+// Rust side, so such objects never reach this impl through that path.
+impl Drop for StringObj {
+ fn drop(&mut self) {
+ if !self.data.is_null() {
+ // SAFETY: `data` came from `CString::into_raw` and is released
+ // exactly once, here.
+ unsafe { drop(CString::from_raw(self.data as *mut c_char)) };
+ }
+ }
+}
+
+// #[cfg(test)]
+// mod tests {
+// use super::String;
+// use crate::object::debug_print;
+// use crate::ToObjectRef;
+// use anyhow::{ensure, Result};
+
+// #[test]
+// fn test_string_debug() -> Result<()> {
+// let s = String::new("foo".to_string()).unwrap();
+// let object_ref = s.to_object_ref();
+// println!("about to call");
+// let string = debug_print(object_ref)?;
+// println!("after call");
+// ensure!(
+// string.into_string().expect("is cstring").contains("foo"),
+// "string content is invalid"
+// );
+// Ok(())
+// }
+// }
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
new file mode 100644
index 0000000..f0e5e80
--- /dev/null
+++ b/rust/tvm-rt/src/to_boxed_fn.rs
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module provides a method for converting type erased TVM functions
+//! into a boxed Rust closure.
+//!
+//! To call a registered function check the [`ToBoxedFn::to_boxed_fn`] method.
+//!
+//! See the tests and examples repository for more examples.
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+use crate::{errors, Module};
+
+use super::function::{Function, Result};
+
+/// Conversion of a type-erased TVM [`Function`] into a typed boxed closure.
+///
+/// This file implements it for `dyn Fn(..) -> Result<O>` with zero to four
+/// arguments. Each argument must convert into an `ArgValue`, and the output
+/// must be convertible from the returned `RetValue`.
+pub trait ToBoxedFn {
+ /// Wraps `func` so that calling the returned closure invokes it.
+ fn to_boxed_fn(func: Function) -> Box<Self>;
+}
+
+use std::convert::{TryFrom, TryInto};
+
+/// Generates a `ToBoxedFn` impl for `dyn Fn(T1, .., Tn) -> Result<O>`.
+///
+/// Each `name: Type` pair introduces one closure parameter. Each call builds
+/// a fresh `Builder`, pushes the arguments in order, invokes the function
+/// and converts the `RetValue` into `O`.
+macro_rules! impl_to_boxed_fn {
+ ($($arg:ident : $ty:ident),*) => {
+ impl<E, $($ty,)* O> ToBoxedFn for dyn Fn($($ty),*) -> Result<O>
+ where
+ errors::Error: From<E>,
+ $($ty: Into<ArgValue<'static>>,)*
+ O: TryFrom<RetValue, Error = E>,
+ {
+ fn to_boxed_fn(func: Function) -> Box<Self> {
+ Box::new(move |$($arg: $ty),*| {
+ let mut builder = Builder::default();
+ builder.func = Some(func.clone());
+ $(builder.arg($arg.into());)*
+ let res = builder.invoke()?.try_into()?;
+ Ok(res)
+ })
+ }
+ }
+ };
+}
+
+impl_to_boxed_fn!();
+impl_to_boxed_fn!(a: A);
+impl_to_boxed_fn!(a: A, b: B);
+impl_to_boxed_fn!(a: A, b: B, c: C);
+impl_to_boxed_fn!(a: A, b: B, c: C, d: D);
+
+/// Builder that collects arguments for a TVM function and then calls it.
+///
+/// *Note:* Currently TVM functions accept *at most* one return value.
+#[derive(Default)]
+pub struct Builder<'a> {
+ /// Function to call; `invoke` panics if this is `None`.
+ pub func: Option<Function>,
+ /// Arguments passed to the function, in push order.
+ pub arg_buf: Vec<ArgValue<'a>>,
+ /// Output slot set by `set_output`.
+ /// NOTE(review): `invoke` does not read this field.
+ pub ret_buf: Option<RetValue>,
+}
+
+impl<'a, 'm> Builder<'a> {
+ /// Creates a builder from an optional target function, an argument buffer
+ /// and an optional output slot.
+ pub fn new(
+ func: Option<Function>,
+ arg_buf: Vec<ArgValue<'a>>,
+ ret_buf: Option<RetValue>,
+ ) -> Self {
+ Builder {
+ func,
+ arg_buf,
+ ret_buf,
+ }
+ }
+
+ /// Sets the call target to the result of `Function::get(name)`. This
+ /// replaces any previously set function, possibly with `None`.
+ pub fn get_function(&mut self, name: &'m str) -> &mut Self {
+ let found = Function::get(name);
+ self.func = found;
+ self
+ }
+
+ /// Pushes a [`ArgValue`] into the function argument buffer.
+ pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self
+ where
+ ArgValue<'a>: From<T>,
+ {
+ self.arg_buf.push(ArgValue::from(arg));
+ self
+ }
+
+ /// Pushes multiple [`ArgValue`]s into the function argument buffer, in
+ /// iteration order.
+ pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self
+ where
+ I: IntoIterator<Item = T>,
+ ArgValue<'a>: From<T>,
+ {
+ for arg in args {
+ self.arg(arg);
+ }
+ self
+ }
+
+ /// Sets an output for a function that requires a mutable output to be provided.
+ /// See the `basics` in tests for an example.
+ ///
+ /// NOTE(review): `invoke` does not currently read `ret_buf`, so the value
+ /// stored here is not passed to the function.
+ pub fn set_output<T>(&mut self, ret: T) -> &mut Self
+ where
+ RetValue: From<T>,
+ {
+ self.ret_buf = Some(RetValue::from(ret));
+ self
+ }
+
+ /// Calls the configured function with the buffered arguments.
+ ///
+ /// Panics if no function has been set.
+ pub fn invoke(self) -> Result<RetValue> {
+ let Builder { func, arg_buf, .. } = self;
+ func.unwrap().invoke(arg_buf)
+ }
+}
+
+/// Converts a [`Function`] into a [`Builder`] targeting it, with no
+/// arguments and no output. Currently, this is the best way to work with
+/// TVM functions.
+impl<'a, 'm> From<Function> for Builder<'a> {
+ fn from(func: Function) -> Self {
+ Self::new(Some(func), vec![], None)
+ }
+}
+
+/// Converts a mutable reference to a [`Module`] into a [`Builder`] that
+/// targets the module's entry function (as returned by `Module::entry`).
+impl<'a, 'm> From<&'m mut Module> for Builder<'a> {
+ fn from(module: &'m mut Module) -> Self {
+ let entry = module.entry();
+ Self::new(entry, vec![], None)
+ }
+}
+#[cfg(test)]
+mod tests {
+ use crate::function::{self, Function, Result};
+
+ /// Registers a nullary Rust function, fetches it back from the global
+ /// registry and calls it through a typed boxed closure.
+ #[test]
+ fn to_boxed_fn0() {
+ fn boxed0() -> i64 {
+ 10
+ }
+
+ function::register_override(boxed0, "boxed0".to_owned(), true).unwrap();
+ let typed_func: Box<dyn Fn() -> Result<i64>> =
+ Function::get("boxed0").unwrap().to_boxed_fn();
+ let value = typed_func().unwrap();
+ assert_eq!(value, 10);
+ }
+}
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
new file mode 100644
index 0000000..4814d09
--- /dev/null
+++ b/rust/tvm-rt/src/to_function.rs
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module provides an idiomatic Rust API for creating and working with TVM functions.
+//!
+//! For calling an already registered TVM function use [`function::Builder`]
+//! To register a TVM packed function from Rust side either
+//! use [`function::register`] or the macro [`register_global_func`].
+//!
+//! See the tests and examples repository for more examples.
+
+use std::convert::{TryFrom, TryInto};
+use std::{
+ os::raw::{c_int, c_void},
+ ptr, slice,
+};
+
+use super::{function::Result, Function};
+use crate::errors::Error;
+
+pub use tvm_sys::{ffi, ArgValue, RetValue};
+
+/// A trait representing whether the function arguments
+/// and return type can be assigned to a TVM packed function.
+///
+/// Conversion to a packed function is split across two traits to improve
+/// error reporting. `Typed` covers converting the inputs and outputs, while
+/// `ToFunction` performs the actual wrapping.
+pub trait Typed<I, O> {
+ /// Converts the packed arguments into the tuple of Rust arguments `I`.
+ fn args(i: &[ArgValue<'static>]) -> Result<I>;
+ /// Converts the Rust return value into a `RetValue`.
+ fn ret(o: O) -> RetValue;
+}
+
+impl<F, O: Into<RetValue>> Typed<(), O> for F
+where
+ F: Fn() -> O,
+{
+ /// Nullary functions take no packed arguments; only the count is checked,
+ /// and only in debug builds.
+ fn args(args: &[ArgValue<'static>]) -> Result<()> {
+ debug_assert!(args.is_empty());
+ Ok(())
+ }
+
+ /// Converts the return value with its `Into<RetValue>` impl.
+ fn ret(o: O) -> RetValue {
+ Into::<RetValue>::into(o)
+ }
+}
+
+impl<F, A, O: Into<RetValue>, E> Typed<(A,), O> for F
+where
+ F: Fn(A) -> O,
+ Error: From<E>,
+ A: TryFrom<ArgValue<'static>, Error = E>,
+{
+ /// Converts the single packed argument with `A`'s `TryFrom` impl.
+ fn args(args: &[ArgValue<'static>]) -> Result<(A,)> {
+ debug_assert!(args.len() == 1);
+ let first = args[0].clone();
+ let converted: A = first.try_into()?;
+ Ok((converted,))
+ }
+
+ /// Converts the return value with its `Into<RetValue>` impl.
+ fn ret(o: O) -> RetValue {
+ Into::<RetValue>::into(o)
+ }
+}
+
+impl<F, A, B, O: Into<RetValue>> Typed<(A, B), O> for F
+where
+ F: Fn(A, B) -> O,
+ A: TryFrom<ArgValue<'static>>,
+ B: TryFrom<ArgValue<'static>>,
+ Error: From<<A as TryFrom<ArgValue<'static>>>::Error>
+ + From<<B as TryFrom<ArgValue<'static>>>::Error>,
+{
+ /// Converts the two packed arguments, each with its own `TryFrom` impl.
+ ///
+ /// Each argument type may have its own conversion error type (e.g. a
+ /// primitive next to an object pointer). Each error only needs to be
+ /// convertible into `Error`; a single shared error type is not required.
+ fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> {
+ debug_assert!(args.len() == 2);
+ let a: A = args[0].clone().try_into()?;
+ let b: B = args[1].clone().try_into()?;
+ Ok((a, b))
+ }
+
+ /// Converts the return value with its `Into<RetValue>` impl.
+ fn ret(o: O) -> RetValue {
+ o.into()
+ }
+}
+
+impl<F, A, B, C, O: Into<RetValue>> Typed<(A, B, C), O> for F
+where
+ F: Fn(A, B, C) -> O,
+ A: TryFrom<ArgValue<'static>>,
+ B: TryFrom<ArgValue<'static>>,
+ C: TryFrom<ArgValue<'static>>,
+ Error: From<<A as TryFrom<ArgValue<'static>>>::Error>
+ + From<<B as TryFrom<ArgValue<'static>>>::Error>
+ + From<<C as TryFrom<ArgValue<'static>>>::Error>,
+{
+ /// Converts the three packed arguments, each with its own `TryFrom` impl
+ /// and its own (independently convertible) error type.
+ fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> {
+ debug_assert!(args.len() == 3);
+ let a: A = args[0].clone().try_into()?;
+ let b: B = args[1].clone().try_into()?;
+ let c: C = args[2].clone().try_into()?;
+ Ok((a, b, c))
+ }
+
+ /// Converts the return value with its `Into<RetValue>` impl.
+ fn ret(o: O) -> RetValue {
+ o.into()
+ }
+}
+
+/// Four-argument counterpart. It makes the existing four-argument
+/// `ToFunction` instance usable through `to_function`, which requires
+/// `Typed`.
+impl<F, A, B, C, D, O: Into<RetValue>> Typed<(A, B, C, D), O> for F
+where
+ F: Fn(A, B, C, D) -> O,
+ A: TryFrom<ArgValue<'static>>,
+ B: TryFrom<ArgValue<'static>>,
+ C: TryFrom<ArgValue<'static>>,
+ D: TryFrom<ArgValue<'static>>,
+ Error: From<<A as TryFrom<ArgValue<'static>>>::Error>
+ + From<<B as TryFrom<ArgValue<'static>>>::Error>
+ + From<<C as TryFrom<ArgValue<'static>>>::Error>
+ + From<<D as TryFrom<ArgValue<'static>>>::Error>,
+{
+ /// Converts the four packed arguments, each with its own `TryFrom` impl.
+ fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C, D)> {
+ debug_assert!(args.len() == 4);
+ let a: A = args[0].clone().try_into()?;
+ let b: B = args[1].clone().try_into()?;
+ let c: C = args[2].clone().try_into()?;
+ let d: D = args[3].clone().try_into()?;
+ Ok((a, b, c, d))
+ }
+
+ /// Converts the return value with its `Into<RetValue>` impl.
+ fn ret(o: O) -> RetValue {
+ o.into()
+ }
+}
+
+/// Conversion of a Rust callable into a TVM packed [`Function`].
+///
+/// `Handle` is the heap representation of the callable. Its raw pointer is
+/// registered with TVM as the packed function's resource handle.
+pub trait ToFunction<I, O>: Sized {
+ type Handle;
+
+ /// Moves `self` to the heap and returns the raw resource handle.
+ fn into_raw(self) -> *mut Self::Handle;
+
+ /// Invokes the callable behind `handle` with the packed arguments.
+ fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue>
+ where
+ Self: Typed<I, O>;
+
+ /// Intended to release the resource behind `handle`.
+ fn drop(handle: *mut Self::Handle);
+
+ /// Registers `self` with the runtime via `TVMFuncCreateFromCFunc`, with
+ /// `tvm_callback` as the C entry point.
+ fn to_function(self) -> Function
+ where
+ Self: Typed<I, O>,
+ {
+ let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
+ let resource_handle = self.into_raw();
+
+ // NOTE(review): no finalizer is passed, so the resource created by
+ // `into_raw` is never released.
+ check_call!(ffi::TVMFuncCreateFromCFunc(
+ Some(Self::tvm_callback),
+ resource_handle as *mut _,
+ None, // Some(Self::tvm_finalizer),
+ &mut fhandle as *mut ffi::TVMFunctionHandle,
+ ));
+
+ Function::new(fhandle)
+ }
+
+ /// C entry point that TVM calls for the packed function created by
+ /// `to_function`. It converts the raw arguments, runs `Self::call`, and
+ /// stores the result through `TVMCFuncSetReturn`. Returns 0 on success,
+ /// or -1 after recording the error with `set_last_error`.
+ ///
+ /// NOTE(review): `check_call!` is used inside this `extern "C"` function.
+ /// If it panics on failure (presumably), the panic would unwind across the
+ /// FFI boundary.
+ unsafe extern "C" fn tvm_callback(
+ args: *mut ffi::TVMValue,
+ type_codes: *mut c_int,
+ num_args: c_int,
+ ret: ffi::TVMRetValueHandle,
+ resource_handle: *mut c_void,
+ ) -> c_int
+ where
+ Self: Typed<I, O>,
+ {
+ #![allow(unused_assignments, unused_unsafe)]
+ // turning off the incorrect linter complaints
+ let len = num_args as usize;
+ let args_list = slice::from_raw_parts_mut(args, len);
+ let type_codes_list = slice::from_raw_parts_mut(type_codes, len);
+ let mut local_args: Vec<ArgValue> = Vec::new();
+ let mut value = ffi::TVMValue { v_int64: 0 };
+ let mut tcode = 0;
+ let resource_handle = resource_handle as *mut Self::Handle;
+ for i in 0..len {
+ value = args_list[i];
+ tcode = type_codes_list[i];
+ // Object, packed-function and module handles are converted in
+ // place with `TVMCbArgToReturn` before being wrapped as an
+ // `ArgValue` (presumably so that Rust receives its own reference).
+ if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int
+ || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int
+ || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int
+ {
+ check_call!(ffi::TVMCbArgToReturn(
+ &mut value as *mut _,
+ &mut tcode as *mut _
+ ));
+ }
+ let arg_value = ArgValue::from_tvm_value(value, tcode as u32);
+ local_args.push(arg_value);
+ }
+
+ // Errors are reported to TVM through the last-error slot and a -1 status.
+ let rv = match Self::call(resource_handle, local_args.as_slice()) {
+ Ok(v) => v,
+ Err(msg) => {
+ crate::set_last_error(&msg);
+ return -1;
+ }
+ };
+
+ let (mut ret_val, ret_tcode) = rv.to_tvm_value();
+ let mut ret_type_code = ret_tcode as c_int;
+
+ check_call!(ffi::TVMCFuncSetReturn(
+ ret,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _,
+ 1 as c_int
+ ));
+ 0
+ }
+
+ /// The finalizer which is invoked when the packed function's
+ /// reference count is zero. It is currently not registered (see the
+ /// `None` passed in `to_function`).
+ unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) {
+ let handle = std::mem::transmute(fhandle);
+ Self::drop(handle)
+ }
+}
+
+impl<O, F> ToFunction<(), O> for F
+where
+ F: Fn() -> O + 'static,
+{
+ type Handle = Box<dyn Fn() -> O + 'static>;
+
+ /// Boxes the closure twice. The inner `Box<dyn Fn>` is a fat pointer, so
+ /// it is boxed again to obtain the thin raw pointer that travels through
+ /// TVM's `void*` resource handle.
+ fn into_raw(self) -> *mut Self::Handle {
+ let erased: Self::Handle = Box::new(self);
+ Box::into_raw(Box::new(erased))
+ }
+
+ /// Calls the stored nullary closure; the (empty) argument slice is ignored.
+ fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
+ where
+ F: Typed<(), O>,
+ {
+ let closure: &Self::Handle = unsafe { &*handle };
+ let out = closure();
+ Ok(F::ret(out))
+ }
+
+ /// No-op: the boxed closure is not freed here.
+ /// NOTE(review): `to_function` registers no finalizer either, so the
+ /// allocation made by `into_raw` is never released.
+ fn drop(_: *mut Self::Handle) {}
+}
+
+// Generates `ToFunction` impls for closures of one or more arguments. Each
+// `(Param, index)` pair names a type parameter and the field of the
+// converted argument tuple that is passed in that position.
+macro_rules! to_function_instance {
+ ($(($param:ident,$index:tt),)+) => {
+ impl<F, $($param,)+ O> ToFunction<($($param,)+), O> for
+ F where F: Fn($($param,)+) -> O + 'static {
+ type Handle = Box<dyn Fn($($param,)+) -> O + 'static>;
+
+ fn into_raw(self) -> *mut Self::Handle {
+ let ptr: Box<Self::Handle> = Box::new(Box::new(self));
+ Box::into_raw(ptr)
+ }
+
+ fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> {
+ // Convert the packed arguments into a typed tuple, then spread its
+ // fields into the call.
+ let args = F::args(args)?;
+ let out = unsafe {
+ (*handle)($(args.$index),+)
+ };
+ Ok(F::ret(out))
+ }
+
+ // NOTE(review): no-op, so the boxed closure is never freed.
+ fn drop(_: *mut Self::Handle) {}
+ }
+ }
+}
+
+// NOTE(review): each instance is only usable through `to_function` if a
+// matching `Typed` impl exists for the same arity.
+to_function_instance!((A, 0),);
+to_function_instance!((A, 0), (B, 1),);
+to_function_instance!((A, 0), (B, 1), (C, 2),);
+to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),);
+
+#[cfg(test)]
+mod tests {
+ use super::{Function, ToFunction, Typed};
+
+ /// Converts `f` into a TVM packed function. Used to check that each
+ /// arity satisfies both `ToFunction` and `Typed`.
+ fn helper<F, I, O>(f: F) -> Function
+ where
+ F: ToFunction<I, O> + Typed<I, O>,
+ {
+ f.to_function()
+ }
+
+ fn zero() -> i32 {
+ 10
+ }
+
+ fn one_arg(i: i32) -> i32 {
+ i
+ }
+
+ fn two_arg(i: i32, j: i32) -> i32 {
+ i + j
+ }
+
+ #[test]
+ fn test_to_function0() {
+ let _ = helper(zero);
+ }
+
+ #[test]
+ fn test_to_function1() {
+ let _ = helper(one_arg);
+ }
+
+ #[test]
+ fn test_to_function2() {
+ let _ = helper(two_arg);
+ }
+}
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
new file mode 100644
index 0000000..1812c0c
--- /dev/null
+++ b/rust/tvm-rt/src/value.rs
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! This module implements conversions between [`ArgValue`] / [`RetValue`]
+//! and the runtime types used by the frontend crate (`Module`, `NDArray`).
+//! `RetValue` is the owned version of `TVMPODValue`.
+
+use std::convert::TryFrom;
+// use std::ffi::c_void;
+
+use crate::{ArgValue, Module, NDArray, RetValue};
+use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
+
+// Generates the FFI conversions for a handle-backed type `$type`:
+//   * `&$type`, `&mut $type` -> `ArgValue::$variant` and `$type` -> `RetValue::$variant`,
+//     each built from `handle()` cast to `$inner_type`;
+//   * `ArgValue`, `&ArgValue` and `RetValue` -> `$type` via `$ctor`, failing with a
+//     `ValueDowncastError` for any other variant.
+macro_rules! impl_handle_val {
+    ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
+        impl<'a> From<&'a $type> for ArgValue<'a> {
+            fn from(value: &'a $type) -> Self {
+                ArgValue::$variant(value.handle() as $inner_type)
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for ArgValue<'a> {
+            fn from(value: &'a mut $type) -> Self {
+                ArgValue::$variant(value.handle() as $inner_type)
+            }
+        }
+
+        impl From<$type> for RetValue {
+            fn from(value: $type) -> RetValue {
+                RetValue::$variant(value.handle() as $inner_type)
+            }
+        }
+
+        impl<'a> TryFrom<ArgValue<'a>> for $type {
+            type Error = ValueDowncastError;
+
+            fn try_from(arg: ArgValue<'a>) -> Result<$type, Self::Error> {
+                try_downcast!(arg -> $type, |ArgValue::$variant(raw)| { $ctor(raw) })
+            }
+        }
+
+        impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+            type Error = ValueDowncastError;
+
+            fn try_from(arg: &'a ArgValue<'v>) -> Result<$type, Self::Error> {
+                try_downcast!(arg -> $type, |ArgValue::$variant(raw)| { $ctor(*raw) })
+            }
+        }
+
+        impl TryFrom<RetValue> for $type {
+            type Error = ValueDowncastError;
+
+            fn try_from(ret: RetValue) -> Result<$type, Self::Error> {
+                try_downcast!(ret -> $type, |RetValue::$variant(raw)| { $ctor(raw) })
+            }
+        }
+    };
+}
+
+impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
+
+impl<'a> From<&'a NDArray> for ArgValue<'a> {
+    /// Owned arrays are passed as `NDArrayHandle`, borrowed ones as `ArrayHandle`.
+    fn from(arg: &'a NDArray) -> Self {
+        match *arg {
+            NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle),
+            NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle),
+        }
+    }
+}
+
+impl<'a> From<&'a mut NDArray> for ArgValue<'a> {
+    /// Same mapping as for `&NDArray`; the mutable borrow is only read, so it is
+    /// reborrowed immutably and forwarded to that conversion.
+    fn from(arg: &'a mut NDArray) -> Self {
+        ArgValue::from(&*arg)
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for NDArray {
+    type Error = ValueDowncastError;
+
+    /// Wraps an `ArrayHandle` with `NDArray::new` or an `NDArrayHandle` with
+    /// `NDArray::from_ndarray_handle`; any other variant is a downcast error.
+    fn try_from(arg: ArgValue<'a>) -> Result<NDArray, Self::Error> {
+        try_downcast!(arg -> NDArray,
+            |ArgValue::ArrayHandle(raw)| { NDArray::new(raw) },
+            |ArgValue::NDArrayHandle(raw)| { NDArray::from_ndarray_handle(raw) })
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray {
+    type Error = ValueDowncastError;
+
+    /// By-reference conversion: copies the raw handle out of the referenced
+    /// `ArrayHandle` / `NDArrayHandle` and wraps it; other variants are errors.
+    fn try_from(arg: &'a ArgValue<'v>) -> Result<NDArray, Self::Error> {
+        try_downcast!(arg -> NDArray,
+            |ArgValue::ArrayHandle(raw)| { NDArray::new(*raw) },
+            |ArgValue::NDArrayHandle(raw)| { NDArray::from_ndarray_handle(*raw) })
+    }
+}
+
+impl From<NDArray> for RetValue {
+    /// Only owned arrays can be returned; a borrowed array panics with "NYI".
+    // NOTE(review): `val` is dropped when this function returns while its handle
+    // lives on inside the `RetValue`. If `NDArray` releases owned handles on drop,
+    // ownership should be forgotten here instead; confirm against its `Drop` impl.
+    fn from(val: NDArray) -> RetValue {
+        if let NDArray::Owned { handle } = val {
+            RetValue::NDArrayHandle(handle)
+        } else {
+            panic!("NYI")
+        }
+    }
+}
+
+impl TryFrom<RetValue> for NDArray {
+    type Error = ValueDowncastError;
+
+    /// Wraps a returned `ArrayHandle` with `NDArray::new` or an `NDArrayHandle`
+    /// with `NDArray::from_ndarray_handle`; any other variant is a downcast error.
+    fn try_from(ret: RetValue) -> Result<NDArray, Self::Error> {
+        try_downcast!(ret -> NDArray,
+            |RetValue::ArrayHandle(raw)| { NDArray::new(raw) },
+            |RetValue::NDArrayHandle(raw)| { NDArray::from_ndarray_handle(raw) })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{convert::TryInto, str::FromStr};
+
+    use crate::{ByteArray, Context, DataType};
+
+    use super::*;
+
+    // Each test round-trips a value through `RetValue` and checks it is unchanged.
+
+    #[test]
+    fn bytearray() {
+        let bytes: Vec<u8> = vec![1, 2, 3, 4, 5];
+        let original = ByteArray::from(bytes.as_slice());
+        let round_tripped: ByteArray = RetValue::from(original).try_into().unwrap();
+        assert_eq!(round_tripped.data(), bytes.as_slice());
+    }
+
+    #[test]
+    fn ty() {
+        let dtype = DataType::from_str("int32").unwrap();
+        let round_tripped: DataType = RetValue::from(dtype).try_into().unwrap();
+        assert_eq!(round_tripped, dtype);
+    }
+
+    #[test]
+    fn ctx() {
+        let context = Context::from_str("gpu").unwrap();
+        let round_tripped: Context = RetValue::from(context).try_into().unwrap();
+        assert_eq!(round_tripped, context);
+    }
+}
diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs
index 40f28f4..9bd9526 100644
--- a/rust/tvm-sys/src/byte_array.rs
+++ b/rust/tvm-sys/src/byte_array.rs
@@ -16,9 +16,12 @@
* specific language governing permissions and limitations
* under the License.
*/
+use std::convert::TryFrom;
use std::os::raw::c_char;
+use crate::errors::ValueDowncastError;
use crate::ffi::TVMByteArray;
+use crate::{ArgValue, RetValue};
/// A newtype wrapping a raw TVM byte-array.
///
@@ -69,6 +72,39 @@ impl<T: AsRef<[u8]>> From<T> for ByteArray {
}
}
+impl TryFrom<ArgValue<'static>> for ByteArray {
+    type Error = ValueDowncastError;
+
+    /// Copies the `TVMByteArray` value referenced by `ArgValue::Bytes`; every
+    /// other variant is rejected with a `ValueDowncastError`.
+    fn try_from(val: ArgValue<'static>) -> Result<ByteArray, Self::Error> {
+        if let ArgValue::Bytes(bytes) = val {
+            return Ok(ByteArray { array: *bytes });
+        }
+        Err(ValueDowncastError {
+            expected_type: "ByteArray",
+            actual_type: format!("{:?}", val),
+        })
+    }
+}
+
+impl From<ByteArray> for RetValue {
+    /// Moves the wrapped `TVMByteArray` into `RetValue::Bytes`.
+    fn from(val: ByteArray) -> RetValue {
+        let ByteArray { array } = val;
+        RetValue::Bytes(array)
+    }
+}
+
+impl TryFrom<RetValue> for ByteArray {
+    type Error = ValueDowncastError;
+
+    /// Moves the `TVMByteArray` out of `RetValue::Bytes`; every other variant is
+    /// rejected with a `ValueDowncastError`.
+    fn try_from(val: RetValue) -> Result<ByteArray, Self::Error> {
+        if let RetValue::Bytes(array) = val {
+            Ok(ByteArray { array })
+        } else {
+            Err(ValueDowncastError {
+                expected_type: "ByteArray",
+                actual_type: format!("{:?}", val),
+            })
+        }
+    }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
index 5dd414c..ccdee3f 100644
--- a/rust/tvm-sys/src/datatype.rs
+++ b/rust/tvm-sys/src/datatype.rs
@@ -95,6 +95,16 @@ impl From<DLDataType> for DataType {
}
}
+impl From<DataType> for DLDataType {
+    /// Field-by-field copy of `code`, `bits` and `lanes` into the C struct.
+    fn from(dtype: DataType) -> Self {
+        let DataType { code, bits, lanes, .. } = dtype;
+        DLDataType { code, bits, lanes }
+    }
+}
+
#[derive(Debug, Error)]
pub enum ParseDataTypeError {
#[error("invalid number: {0}")]
diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs
index 8479ec6..54fe261 100644
--- a/rust/tvm-sys/src/errors.rs
+++ b/rust/tvm-sys/src/errors.rs
@@ -39,7 +39,7 @@ impl FuncCallError {
context,
message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) }
.to_str()
- .expect("double fault")
+ .expect("failed while attempting to retrieve the TVM error message")
.to_owned(),
}
}
diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs
index dd28e36..0f455e7 100644
--- a/rust/tvm-sys/src/lib.rs
+++ b/rust/tvm-sys/src/lib.rs
@@ -34,8 +34,13 @@ pub mod ffi {
include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));
- pub type BackendPackedCFunc =
- extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+ /// Signature of a backend packed C function: `num_args` values in `args`, with
+ /// their type codes in `type_codes`, plus two out-pointers (presumably for the
+ /// return value and its type code). Returns a `c_int`, assumed to be a status code.
+ ///
+ /// NOTE(review): TVM's C typedef declares the return type-code pointer as `int*`;
+ /// `*mut u32` has the same size on common targets, but `*mut c_int` would match
+ /// exactly. Confirm before changing it, because that would break Rust callers.
+ pub type BackendPackedCFunc = extern "C" fn(
+ args: *const TVMValue,
+ type_codes: *const c_int,
+ num_args: c_int,
+ out_ret_value: *mut TVMValue,
+ out_ret_tcode: *mut u32,
+ ) -> c_int;
}
pub mod array;
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 699b4db..97e285c 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -185,4 +185,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}
p->stream << '}';
});
+
+// Registered as "ir.DebugPrint": streams `obj` through its operator<< and
+// returns the resulting text.
+TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef obj) {
+  std::ostringstream os;
+  os << obj;
+  return os.str();
+});
+
} // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index ad16f86..981d0c3 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -823,5 +823,12 @@ std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
return docs;
}
+// Registered as "ir.TextPrinter": returns the text rendering of `node` produced
+// by AsText, with meta-data printing disabled and no annotation callback.
+// Nothing is written to stdout; callers print the returned string if they need to.
+TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
+  return AsText(node, /*show_meta_data=*/false, /*annotate=*/nullptr);
+});
+
} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 9585904..6972d5a 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -151,7 +151,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
// only look unfold non-external calls.
BaseFunc base_func = m->Lookup(gv);
if (auto* n = base_func.as<FunctionNode>()) {
- auto cps_gv = GlobalVar(gv->name_hint + "_cps");
+ auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps");
cm->insert({gv, cps_gv});
m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
} else {
diff --git a/src/runtime/object.cc b/src/runtime/object.cc
index 00be440..dc5f1ce 100644
--- a/src/runtime/object.cc
+++ b/src/runtime/object.cc
@@ -234,12 +234,25 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
API_END();
}
+// C API entry point: adds a reference to `obj` via ObjectInternal::ObjectRetain
+// (which ignores a null handle); the counterpart of TVMObjectFree. Wrapped in
+// API_BEGIN/API_END like the other entry points in this file.
+int TVMObjectRetain(TVMObjectHandle obj) {
+ API_BEGIN();
+ tvm::runtime::ObjectInternal::ObjectRetain(obj);
+ API_END();
+}
+
int TVMObjectFree(TVMObjectHandle obj) {
API_BEGIN();
tvm::runtime::ObjectInternal::ObjectFree(obj);
API_END();
}
+
+// C API entry point: writes 1 to `*is_derived` when TypeContext::DerivedFrom
+// reports that `child_type_index` derives from `parent_type_index`, else 0.
+int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) {
+  API_BEGIN();
+  const bool derived =
+      tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index);
+  is_derived[0] = derived ? 1 : 0;
+  API_END();
+}
+
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h
index 35642fb..f255b28 100644
--- a/src/runtime/object_internal.h
+++ b/src/runtime/object_internal.h
@@ -39,6 +39,15 @@ namespace runtime {
class ObjectInternal {
public:
/*!
+ * \brief Retain an object handle.
+ */
+  static void ObjectRetain(TVMObjectHandle obj) {
+    // A null handle is tolerated and left untouched.
+    if (obj == nullptr) return;
+    static_cast<Object*>(obj)->IncRef();
+  }
+
+ /*!
* \brief Free an object handle.
*/
static void ObjectFree(TVMObjectHandle obj) {