You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/06/08 21:01:25 UTC
[incubator-tvm] 03/04: Finish removing anyhow and work with new external! macro
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch rust-tvm
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 0c55c39477979c75f2bf3e2e9974d90fde74fa26
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Mon Jun 8 13:56:28 2020 -0700
Finish removing anyhow and work with new external! macro
---
rust/tvm-rt/src/context.rs | 12 ++++++++----
rust/tvm-rt/src/errors.rs | 14 ++++++++------
rust/tvm-rt/src/function.rs | 12 ++++++------
rust/tvm-rt/src/module.rs | 10 +++++-----
rust/tvm-rt/src/ndarray.rs | 2 +-
rust/tvm-rt/src/to_boxed_fn.rs | 29 ++++++++++++++++-------------
rust/tvm-rt/src/to_function.rs | 30 +++++++++++++++---------------
7 files changed, 59 insertions(+), 50 deletions(-)
diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs
index 0c01d91..b1bdab5 100644
--- a/rust/tvm-rt/src/context.rs
+++ b/rust/tvm-rt/src/context.rs
@@ -1,13 +1,17 @@
-pub use tvm_sys::context::*;
-use tvm_sys::ffi;
use std::os::raw::c_void;
use std::ptr;
+use crate::errors::Error;
+
+use tvm_sys::ffi;
+
+pub use tvm_sys::context::*;
+
trait ContextExt {
/// Checks whether the context exists or not.
fn exist(&self) -> bool;
- fn sync(&self) -> anyhow::Result<()>;
+ fn sync(&self) -> Result<(), Error>;
fn max_threads_per_block(&self) -> isize;
fn warp_size(&self) -> isize;
fn max_shared_memory_per_block(&self) -> isize;
@@ -44,7 +48,7 @@ impl ContextExt for Context {
}
/// Synchronize the context stream.
- fn sync(&self) -> anyhow::Result<()> {
+ fn sync(&self) -> Result<(), Error> {
check_call!(ffi::TVMSynchronize(
self.device_type as i32,
self.device_id as i32,
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 414484d..197c875 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -21,12 +21,6 @@ use crate::DataType;
use thiserror::Error;
#[derive(Debug, Error)]
-#[error("Handle `{name}` is null.")]
-pub struct NullHandleError {
- pub name: String,
-}
-
-#[derive(Debug, Error)]
#[error("Function was not set in `function::Builder`")]
pub struct FunctionNotFoundError;
@@ -62,6 +56,14 @@ pub enum Error {
Null,
#[error("failed to load module due to invalid path {0}")]
ModuleLoadPath(String),
+ #[error("failed to convert String into CString due to embedded nul character")]
+ ToCString(#[from] std::ffi::NulError),
+ #[error("failed to convert CString into String")]
+ FromCString(#[from] std::ffi::IntoStringError),
+ #[error("Handle `{0}` is null.")]
+ NullHandle(String),
+ #[error("{0}")]
+ NDArray(#[from] NDArrayError),
}
impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 4b34bc1..cca918a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -138,7 +138,7 @@ impl Function {
}
/// Calls the function that created from `Builder`.
- pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue, Error> {
+ pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> {
let num_args = arg_buf.len();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();
@@ -192,7 +192,7 @@ impl From<Function> for RetValue {
impl TryFrom<RetValue> for Function {
type Error = Error;
- fn try_from(ret_value: RetValue) -> Result<Function, Self::Error> {
+ fn try_from(ret_value: RetValue) -> Result<Function> {
match ret_value {
RetValue::FuncHandle(handle) => Ok(Function::new(handle)),
_ => Err(Error::downcast(
@@ -212,7 +212,7 @@ impl<'a> From<Function> for ArgValue<'a> {
impl<'a> TryFrom<ArgValue<'a>> for Function {
type Error = Error;
- fn try_from(arg_value: ArgValue<'a>) -> Result<Function, Self::Error> {
+ fn try_from(arg_value: ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(handle)),
_ => Err(Error::downcast(
@@ -226,7 +226,7 @@ impl<'a> TryFrom<ArgValue<'a>> for Function {
impl<'a> TryFrom<&ArgValue<'a>> for Function {
type Error = Error;
- fn try_from(arg_value: &ArgValue<'a>) -> Result<Function, Self::Error> {
+ fn try_from(arg_value: &ArgValue<'a>) -> Result<Function> {
match arg_value {
ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)),
_ => Err(Error::downcast(
@@ -264,7 +264,7 @@ impl<'a> TryFrom<&ArgValue<'a>> for Function {
/// let ret = boxed_fn(10, 20, 30).unwrap();
/// assert_eq!(ret, 60);
/// ```
-pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<(), Error>
+pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()>
where
F: ToFunction<I, O>,
F: Typed<I, O>,
@@ -275,7 +275,7 @@ where
/// Register a function with explicit control over whether to override an existing registration or not.
///
/// See `register` for more details on how to use the registration API.
-pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<(), Error>
+pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()>
where
F: ToFunction<I, O>,
F: Typed<I, O>,
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
index b8b56f4..c161af5 100644
--- a/rust/tvm-rt/src/module.rs
+++ b/rust/tvm-rt/src/module.rs
@@ -78,9 +78,9 @@ impl Module {
));
if !fhandle.is_null() {
- return Err(errors::NullHandleError {
- name: name.into_string()?.to_string()
- })
+ return Err(errors::Error::NullHandle(
+ name.into_string()?.to_string()
+ ));
}
Ok(Function::new(fhandle))
@@ -98,13 +98,13 @@ impl Module {
.extension()
.unwrap_or_else(|| std::ffi::OsStr::new(""))
.to_str()
- .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display()))
+ .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?
)?;
let cpath = CString::new(
path.as_ref()
.to_str()
- .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display()))
+ .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?
)?;
let module = load_from_file(cpath, ext)?;
diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs
index 593154d..9a17502 100644
--- a/rust/tvm-rt/src/ndarray.rs
+++ b/rust/tvm-rt/src/ndarray.rs
@@ -147,7 +147,7 @@ impl NDArray {
}
/// Shows whether the underlying ndarray is contiguous in memory or not.
- pub fn is_contiguous(&self) -> anyhow::Result<bool> {
+ pub fn is_contiguous(&self) -> Result<bool, crate::errors::Error> {
Ok(match self.strides() {
None => true,
Some(strides) => {
diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs
index d2dde67..12e4351 100644
--- a/rust/tvm-rt/src/to_boxed_fn.rs
+++ b/rust/tvm-rt/src/to_boxed_fn.rs
@@ -29,9 +29,7 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
use crate::{Module, errors};
-use super::function::Function;
-
-type Result<T> = std::result::Result<T, errors::Error>;
+use super::function::{Function, Result};
pub trait ToBoxedFn {
fn to_boxed_fn(func: &'static Function) -> Box<Self>;
@@ -39,9 +37,10 @@ pub trait ToBoxedFn {
use std::convert::{TryFrom, TryInto};
-impl<O> ToBoxedFn for dyn Fn() -> Result<O>
+impl<E, O> ToBoxedFn for dyn Fn() -> Result<O>
where
- O: TryFrom<RetValue, Error = errors::Error>,
+ errors::Error: From<E>,
+ O: TryFrom<RetValue, Error = E>,
{
fn to_boxed_fn(func: &'static Function) -> Box<Self> {
Box::new(move || {
@@ -53,10 +52,11 @@ where
}
}
-impl<A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
+impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O>
where
+ errors::Error: From<E>,
A: Into<ArgValue<'static>>,
- O: TryFrom<RetValue, Error = errors::Error>,
+ O: TryFrom<RetValue, Error = E>,
{
fn to_boxed_fn(func: &'static Function) -> Box<Self> {
Box::new(move |a: A| {
@@ -69,11 +69,12 @@ where
}
}
-impl<A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
+impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O>
where
+ errors::Error: From<E>,
A: Into<ArgValue<'static>>,
B: Into<ArgValue<'static>>,
- O: TryFrom<RetValue, Error = errors::Error>,
+ O: TryFrom<RetValue, Error = E>,
{
fn to_boxed_fn(func: &'static Function) -> Box<Self> {
Box::new(move |a: A, b: B| {
@@ -87,12 +88,13 @@ where
}
}
-impl<A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
+impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O>
where
+ errors::Error: From<E>,
A: Into<ArgValue<'static>>,
B: Into<ArgValue<'static>>,
C: Into<ArgValue<'static>>,
- O: TryFrom<RetValue, Error = errors::Error>,
+ O: TryFrom<RetValue, Error = E>,
{
fn to_boxed_fn(func: &'static Function) -> Box<Self> {
Box::new(move |a: A, b: B, c: C| {
@@ -107,13 +109,14 @@ where
}
}
-impl<A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
+impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O>
where
+ errors::Error: From<E>,
A: Into<ArgValue<'static>>,
B: Into<ArgValue<'static>>,
C: Into<ArgValue<'static>>,
D: Into<ArgValue<'static>>,
- O: TryFrom<RetValue, Error = errors::Error>,
+ O: TryFrom<RetValue, Error = E>,
{
fn to_boxed_fn(func: &'static Function) -> Box<Self> {
Box::new(move |a: A, b: B, c: C, d: D| {
diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs
index 0527b0c..9d8065c 100644
--- a/rust/tvm-rt/src/to_function.rs
+++ b/rust/tvm-rt/src/to_function.rs
@@ -32,7 +32,7 @@ use std::{
ptr, slice,
};
-use super::Function;
+use super::{Function, function::Result};
use crate::errors::Error;
pub use tvm_sys::{ffi, ArgValue, RetValue};
@@ -46,20 +46,20 @@ pub use tvm_sys::{ffi, ArgValue, RetValue};
///
/// And the implementation of it to `ToFunction`.
pub trait Typed<I, O> {
- fn args(i: &[ArgValue<'static>]) -> Result<I, Error>;
+ fn args(i: &[ArgValue<'static>]) -> Result<I>;
fn ret(o: O) -> RetValue;
}
-impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F
+impl<'a, F> Typed<&'a [ArgValue<'static>], Result<RetValue>> for F
where
- F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>,
+ F: Fn(&'a [ArgValue]) -> Result<RetValue>,
{
- fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>], Error> {
+ fn args(args: &[ArgValue<'static>]) -> Result<&'a [ArgValue<'static>]> {
// this is BAD but just hacking for time being
Ok(unsafe { std::mem::transmute(args) })
}
- fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue {
+ fn ret(ret_value: Result<RetValue>) -> RetValue {
ret_value.unwrap()
}
}
@@ -68,7 +68,7 @@ impl<F, O: Into<RetValue>> Typed<(), O> for F
where
F: Fn() -> O,
{
- fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<(), Error> {
+ fn args(_args: &[ArgValue<'static>]) -> Result<()> {
debug_assert!(_args.len() == 0);
Ok(())
}
@@ -84,7 +84,7 @@ where
Error: From<E>,
A: TryFrom<ArgValue<'static>, Error = E>,
{
- fn args(args: &[ArgValue<'static>]) -> Result<(A,), Error> {
+ fn args(args: &[ArgValue<'static>]) -> Result<(A,)> {
debug_assert!(args.len() == 1);
let a: A = args[0].clone().try_into()?;
Ok((a,))
@@ -102,7 +102,7 @@ where
A: TryFrom<ArgValue<'static>, Error = E>,
B: TryFrom<ArgValue<'static>, Error = E>,
{
- fn args(args: &[ArgValue<'static>]) -> Result<(A, B), Error> {
+ fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> {
debug_assert!(args.len() == 2);
let a: A = args[0].clone().try_into()?;
let b: B = args[1].clone().try_into()?;
@@ -122,7 +122,7 @@ where
B: TryFrom<ArgValue<'static>, Error = E>,
C: TryFrom<ArgValue<'static>, Error = E>,
{
- fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C), Error> {
+ fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> {
debug_assert!(args.len() == 3);
let a: A = args[0].clone().try_into()?;
let b: B = args[1].clone().try_into()?;
@@ -140,7 +140,7 @@ pub trait ToFunction<I, O>: Sized {
fn into_raw(self) -> *mut Self::Handle;
- fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error>
+ fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue>
where
Self: Typed<I, O>;
@@ -242,7 +242,7 @@ pub trait ToFunction<I, O>: Sized {
// }
// impl Typed<&[ArgValue<'static>], ()> for RawFunction {
-// fn args(i: &[ArgValue<'static>]) -> anyhow::Result<&[ArgValue<'static>]> {
+// fn args(i: &[ArgValue<'static>]) -> Result<&[ArgValue<'static>]> {
// Ok(i)
// }
@@ -279,7 +279,7 @@ where
Box::into_raw(ptr)
}
- fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue, Error>
+ fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
where
F: Typed<(), O>,
{
@@ -302,7 +302,7 @@ macro_rules! to_function_instance {
Box::into_raw(ptr)
}
- fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue, Error> where F: Typed<($($param,)+), O> {
+ fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> {
// Ideally we shouldn't need to clone, probably doesn't really matter.
let args = F::args(args)?;
let out = unsafe {
@@ -338,7 +338,7 @@ mod tests {
f.to_function()
}
- // fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> {
+ // fn func_args(args: &[ArgValue<'static>]) -> Result<RetValue> {
// Ok(10.into())
// }