You are viewing a plain-text version of this content. (The canonical link to it was a hyperlink on the word "here" and was not preserved in this copy.)
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/06/08 21:01:23 UTC
[incubator-tvm] 01/04: Reworking errors and proc macros
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit fc6fac254d02f91dae146c81c618cd17f8bf9d3c
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sat Jun 6 21:54:58 2020 -0700
Reworking errors and proc macros
---
rust/macros/src/lib.rs | 7 +++++
rust/tvm-rt/src/errors.rs | 5 +++-
rust/tvm-rt/src/function.rs | 53 ++++++++++--------------------------
rust/tvm-rt/src/ndarray.rs | 23 ++++++++++------
rust/tvm-rt/src/object/mod.rs | 2 +-
rust/tvm-rt/src/object/object_ptr.rs | 2 +-
rust/tvm-rt/src/to_function.rs | 2 +-
rust/tvm-rt/src/value.rs | 6 +---
rust/tvm/src/ir/array.rs | 5 ++--
rust/tvm/src/lib.rs | 9 ++----
10 files changed, 48 insertions(+), 66 deletions(-)
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index e9ddc25..d0ac1ca 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -18,6 +18,8 @@
*/
use proc_macro::TokenStream;
+
+mod external;
mod import_module;
mod object;
@@ -31,3 +33,8 @@ pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
}
+
+#[proc_macro]
+pub fn external(input: TokenStream) -> TokenStream {
+ external::macro_impl(input)
+}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 41e873f..f081258 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -48,7 +48,10 @@ pub enum NDArrayError {
#[error("a shape error occurred in the Rust ndarray library")]
ShapeError(#[from] ndarray::ShapeError),
#[error("Expected type `{expected}` but found `{actual}`")]
- DataTypeMismatch { expected: DataType, actual: DataType }
+ DataTypeMismatch {
+ expected: DataType,
+ actual: DataType,
+ },
}
#[derive(Debug, Error)]
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 17f5f6e..b0122ff 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -25,6 +25,9 @@
//!
//! See the tests and examples repository for more examples.
+use anyhow::Result;
+use lazy_static::lazy_static;
+use std::convert::TryFrom;
use std::{
collections::BTreeMap,
ffi::{CStr, CString},
@@ -33,9 +36,6 @@ use std::{
ptr, slice, str,
sync::Mutex,
};
-use std::convert::{TryFrom};
-use anyhow::Result;
-use lazy_static::lazy_static;
pub use tvm_sys::{ffi, ArgValue, RetValue};
@@ -194,7 +194,10 @@ impl TryFrom<RetValue> for Function {
fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> {
match ret_value {
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
- _ => Err(Error::downcast(format!("{:?}", ret_value), "FunctionHandle"))
+ _ => Err(Error::downcast(
+ format!("{:?}", ret_value),
+ "FunctionHandle",
+ )),
}
}
}
@@ -211,7 +214,10 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
- _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+ _ => Err(Error::downcast(
+ format!("{:?}", arg_value),
+ "FunctionHandle",
+ )),
}
}
}
@@ -222,7 +228,10 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
- _ => Err(Error::downcast(format!("{:?}", arg_value), "FunctionHandle")),
+ _ => Err(Error::downcast(
+ format!("{:?}", arg_value),
+ "FunctionHandle",
+ )),
}
}
}
@@ -286,38 +295,6 @@ where
Ok(())
}
-#[macro_export]
-macro_rules! external_func_impl {
- ($name:ident , $($ty_param:tt)* , ( $($arg:ident : $ty:ty),* ), $ret_type:ty, $ext_name:literal) => {
- ::paste::item! {
- #[allow(non_upper_case_globals)]
- static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> =
- ::once_cell::sync::Lazy::new(|| {
- $crate::Function::get($ext_name)
- .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry."))
- });
- }
-
- pub fn $name<$($ty_param),*>($($arg : $ty),*) -> anyhow::Result<$ret_type> w,* {
- let func_ref: &$crate::Function = ::paste::expr! { &*[<global_ $name>] };
- let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>> = func_ref.to_boxed_fn();
- let res: $ret_type = func_ref($($arg),*)?;
- Ok(res)
- }
- }
-}
-
-
-#[macro_export]
-macro_rules! external_func {
- (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => {
- $crate::external_func_impl!($name, , ( $($arg : $ty),* ) , $ret_type, $ext_name);
- };
- (fn $name:ident < $($ty_param:ident),* > ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => {
- $crate::external_func_impl!($name, $($ty_param:ident),* , ( $($arg : $ty),* ) , $ret_type, $ext_name);
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index f97b3a4..593154d 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -47,9 +47,9 @@
//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer
//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx
-use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use std::convert::TryInto;
use std::ffi::c_void;
+use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr};
use crate::errors::NDArrayError;
@@ -190,7 +190,9 @@ impl NDArray {
/// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data);
/// ```
pub fn to_vec<T>(&self) -> Result<Vec<T>, NDArrayError> {
- if self.shape().is_some() { return Err(NDArrayError::EmptyArray); }
+ if self.shape().is_some() {
+ return Err(NDArrayError::EmptyArray);
+ }
let earr = NDArray::empty(
self.shape().ok_or(NDArrayError::MissingShape)?,
Context::cpu(0),
@@ -241,11 +243,10 @@ impl NDArray {
/// Copies the NDArray to another target NDArray.
pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray, NDArrayError> {
if self.dtype() != target.dtype() {
- return Err(
- NDArrayError::DataTypeMismatch {
- expected: self.dtype(),
- actual: target.dtype()
- });
+ return Err(NDArrayError::DataTypeMismatch {
+ expected: self.dtype(),
+ actual: target.dtype(),
+ });
}
check_call!(ffi::TVMArrayCopyFromTo(
@@ -307,7 +308,9 @@ macro_rules! impl_from_ndarray_rustndarray {
type Error = NDArrayError;
fn try_from(nd: &NDArray) -> Result<ArrayD<$type>, Self::Error> {
- if nd.shape().is_some() { return Err(NDArrayError::MissingShape); }
+ if nd.shape().is_some() {
+ return Err(NDArrayError::MissingShape);
+ }
assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
Ok(Array::from_shape_vec(
&*nd.shape().ok_or(NDArrayError::MissingShape)?,
@@ -320,7 +323,9 @@ macro_rules! impl_from_ndarray_rustndarray {
type Error = NDArrayError;
fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>, Self::Error> {
- if nd.shape().is_some() { return Err(NDArrayError::MissingShape) };
+ if nd.shape().is_some() {
+ return Err(NDArrayError::MissingShape);
+ };
assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch");
Ok(Array::from_shape_vec(
&*nd.shape().ok_or(NDArrayError::MissingShape)?,
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 32da18e..9dcf836 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -2,8 +2,8 @@ use std::convert::TryFrom;
use std::convert::TryInto;
use std::ffi::CString;
-use crate::external_func;
use crate::errors::Error;
+use crate::external_func;
use tvm_sys::{ArgValue, RetValue};
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 8e91878..ead37e3 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -194,7 +194,7 @@ impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?;
optr.downcast()
}
- _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle"))
+ _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")),
}
}
}
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index dac37c8..0527b0c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -32,8 +32,8 @@ use std::{
ptr, slice,
};
-use crate::errors::Error;
use super::Function;
+use crate::errors::Error;
pub use tvm_sys::{ffi, ArgValue, RetValue};
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index d9436b1..1812c0c 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -25,11 +25,7 @@ use std::convert::TryFrom;
// use std::ffi::c_void;
use crate::{ArgValue, Module, NDArray, RetValue};
-use tvm_sys::{
- errors::ValueDowncastError,
- ffi::{TVMModuleHandle},
- try_downcast,
-};
+use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
macro_rules! impl_handle_val {
($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs
index a426474..bd12252 100644
--- a/rust/tvm/src/ir/array.rs
+++ b/rust/tvm/src/ir/array.rs
@@ -1,14 +1,13 @@
-use std::convert::{TryFrom};
+use std::convert::TryFrom;
use std::marker::PhantomData;
use crate::runtime::object::{ObjectRef, ToObjectRef};
-use tvm_rt::RetValue;
use tvm_rt::external_func;
+use tvm_rt::RetValue;
use anyhow::Result;
-
#[derive(Clone)]
pub struct Array<T: ToObjectRef> {
object: ObjectRef,
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index 9315f7c..64252a4 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -30,14 +30,9 @@
//!
//! Checkout the `examples` repository for more details.
-pub use crate::{
- errors::*,
- function::Function,
- module::Module,
- ndarray::NDArray,
-};
+pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray};
-pub use tvm_rt::{Context, DeviceType, DataType};
+pub use tvm_rt::{Context, DataType, DeviceType};
pub use tvm_rt::context;
pub use tvm_rt::errors;