You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost in this plain-text rendering.)
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:15 UTC
[incubator-tvm] 19/23: WIP
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit a9ee3cb34c020a4debe75fc9a194303f22d00892
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 11:48:34 2020 -0700
WIP
---
rust/tvm-macros/Cargo.toml | 2 +-
rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++++++++--------
rust/tvm-macros/src/lib.rs | 1 +
rust/tvm-rt/src/object/mod.rs | 2 +-
rust/tvm/src/ir/module.rs | 16 +++++++++++----
5 files changed, 50 insertions(+), 14 deletions(-)
diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml
index 63b8472..8e97d3b 100644
--- a/rust/tvm-macros/Cargo.toml
+++ b/rust/tvm-macros/Cargo.toml
@@ -33,5 +33,5 @@ proc-macro = true
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
-syn = { version = "1.0.17", features = ["full", "extra-traits"] }
+syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] }
proc-macro-error = "^1.0"
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index de8ada3..44a242c 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,9 +21,28 @@ use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
-use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Signature, TraitItemMethod, Type, Visibility};
+/// One declaration inside `external! { ... }`: outer attributes, an optional visibility, a function signature, and a terminating `;`.
+struct ExternalItem {
+ attrs: Vec<Attribute>,
+ visibility: Visibility,
+ sig: Signature,
+}
+
+impl Parse for ExternalItem {
+ fn parse(input: ParseStream) -> Result<Self> {
+ let item = ExternalItem {
+ attrs: input.call(Attribute::parse_outer)?,
+ visibility: input.parse()?,
+ sig: input.parse()?,
+ };
+ let _semi: Semi = input.parse()?; // declarations have no body: require the `;` and discard it
+ Ok(item)
+ }
+}
struct External {
+ visibility: Visibility,
tvm_name: String,
ident: Ident,
generics: Generics,
@@ -33,7 +52,8 @@ struct External {
impl Parse for External {
fn parse(input: ParseStream) -> Result<Self> {
- let method: TraitItemMethod = input.parse()?;
+ let method: ExternalItem = input.parse()?;
+ let visibility = method.visibility;
assert_eq!(method.attrs.len(), 1);
let sig = method.sig;
let tvm_name = method.attrs[0].parse_meta()?;
@@ -48,8 +68,7 @@ impl Parse for External {
}
_ => panic!(),
};
- assert_eq!(method.default, None);
- assert!(method.semi_token != None);
+
let ident = sig.ident;
let generics = sig.generics;
let inputs = sig
@@ -61,6 +80,7 @@ impl Parse for External {
let ret_type = sig.output;
Ok(External {
+ visibility,
tvm_name,
ident,
generics,
@@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut items = Vec::new();
for external in &ext_input.externs {
+ let visibility = &external.visibility;
let name = &external.ident;
let global_name = format!("global_{}", external.ident);
let global_name = Ident::new(&global_name, Span::call_site());
@@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ty: Type = *pat_type.ty.clone();
(ident, ty)
}
- _ => panic!(),
+ _ => abort! { pat_type,
+ "Only supports type parameters."
+ }
},
- _ => panic!(),
+ pat => abort! {
+ pat, "invalid pattern type for function";
+
+ note = "{:?} is not allowed here", pat;
+ }
})
.unzip();
let ret_type = match &external.ret_type {
ReturnType::Type(_, rtype) => *rtype.clone(),
- _ => panic!(),
+ ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
};
let global = quote! {
@@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);
let wrapper = quote! {
- pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
+ #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
let res: #ret_type = func_ref(#(#args),*)?;
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index ab75c92..32f2839 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -18,6 +18,7 @@
*/
use proc_macro::TokenStream;
+use proc_macro_error::proc_macro_error;
mod external;
mod import_module;
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 46e0342..e48c017 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -88,7 +88,7 @@ pub trait IsObjectRef:
external! {
#[name("ir.DebugPrint")]
- fn debug_print(object: ObjectRef) -> CString;
+ pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
#[name("node.StructuralEqual")]
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 8918bdc..3b60b0c 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -31,8 +31,10 @@ use crate::runtime::{external, Object, ObjectRef};
use super::expr::GlobalVar;
use super::function::BaseFunc;
use super::source_map::SourceMap;
+use super::relay; // not `ty::GlobalTypeVar`: it would clash with the local `GlobalTypeVar` alias below (E0255)
// TODO(@jroesch): define type
+
type TypeData = ObjectRef;
type GlobalTypeVar = ObjectRef;
@@ -64,7 +66,7 @@ external! {
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
// Module methods
#[name("ir.Module_Add")]
- fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+ fn module_add(module: IRModule, var: GlobalVar, expr: relay::Expr, update: bool) -> ();
#[name("ir.Module_AddDef")]
fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
#[name("ir.Module_GetGlobalVar")]
@@ -78,15 +80,15 @@ external! {
#[name("ir.Module_GetGlobalTypeVars")]
fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
#[name("ir.Module_ContainGlobalVar")]
- fn module_get_global_var(name: TVMString) -> bool;
+ fn module_contains_global_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_ContainGlobalTypeVar")]
- fn module_get_global_type_var(name: TVMString) -> bool;
+ fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_LookupDef")]
fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
#[name("ir.Module_LookupDef_str")]
fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
#[name("ir.Module_LookupTag")]
- fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+ fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
#[name("ir.Module_FromExpr")]
fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
#[name("ir.Module_Import")]
@@ -145,3 +147,9 @@ impl IRModule {
module_lookup_str(self.clone(), name.into())
}
}
+
+#[cfg(test)]
+mod tests {
+ // TODO: add tests exercising the IRModule bindings in this file
+ // (e.g. `module_add` and `module_lookup_str`).
+}