You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:15 UTC

[incubator-tvm] 19/23: WIP

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit a9ee3cb34c020a4debe75fc9a194303f22d00892
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 11:48:34 2020 -0700

    WIP
---
 rust/tvm-macros/Cargo.toml      |  2 +-
 rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++++++++--------
 rust/tvm-macros/src/lib.rs      |  1 +
 rust/tvm-rt/src/object/mod.rs   |  2 +-
 rust/tvm/src/ir/module.rs       | 16 +++++++++++----
 5 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml
index 63b8472..8e97d3b 100644
--- a/rust/tvm-macros/Cargo.toml
+++ b/rust/tvm-macros/Cargo.toml
@@ -33,5 +33,5 @@ proc-macro = true
 goblin = "^0.2"
 proc-macro2 = "^1.0"
 quote = "^1.0"
-syn = { version = "1.0.17", features = ["full", "extra-traits"] }
+syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] }
 proc-macro-error = "^1.0"
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index de8ada3..44a242c 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,9 +21,28 @@ use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
-use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+
+struct ExternalItem {
+    attrs: Vec<Attribute>,
+    visibility: Visibility,
+    sig: Signature,
+}
+
+impl Parse for ExternalItem {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let item = ExternalItem {
+            attrs: input.call(Attribute::parse_outer)?,
+            visibility: input.parse()?,
+            sig: input.parse()?,
+        };
+        let _semi: Semi = input.parse()?;
+        Ok(item)
+    }
+}
 
 struct External {
+    visibility: Visibility,
     tvm_name: String,
     ident: Ident,
     generics: Generics,
@@ -33,7 +52,8 @@ struct External {
 
 impl Parse for External {
     fn parse(input: ParseStream) -> Result<Self> {
-        let method: TraitItemMethod = input.parse()?;
+        let method: ExternalItem = input.parse()?;
+        let visibility = method.visibility;
         assert_eq!(method.attrs.len(), 1);
         let sig = method.sig;
         let tvm_name = method.attrs[0].parse_meta()?;
@@ -48,8 +68,7 @@ impl Parse for External {
             }
             _ => panic!(),
         };
-        assert_eq!(method.default, None);
-        assert!(method.semi_token != None);
+
         let ident = sig.ident;
         let generics = sig.generics;
         let inputs = sig
@@ -61,6 +80,7 @@ impl Parse for External {
         let ret_type = sig.output;
 
         Ok(External {
+            visibility,
             tvm_name,
             ident,
             generics,
@@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let mut items = Vec::new();
 
     for external in &ext_input.externs {
+        let visibility = &external.visibility;
         let name = &external.ident;
         let global_name = format!("global_{}", external.ident);
         let global_name = Ident::new(&global_name, Span::call_site());
@@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                         let ty: Type = *pat_type.ty.clone();
                         (ident, ty)
                     }
-                    _ => panic!(),
+                    _ =>  abort! { pat_type,
+                        "Only supports type parameters."
+                    }
                 },
-                _ => panic!(),
+                pat => abort! {
+                        pat, "invalid pattern type for function";
+
+                        note = "{:?} is not allowed here", pat;
+                    }
             })
             .unzip();
 
         let ret_type = match &external.ret_type {
             ReturnType::Type(_, rtype) => *rtype.clone(),
-            _ => panic!(),
+            ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
         };
 
         let global = quote! {
@@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
         items.push(global);
 
         let wrapper = quote! {
-            pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
+            #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
                 let func_ref: #tvm_rt_crate::Function = #global_name.clone();
                 let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
                 let res: #ret_type = func_ref(#(#args),*)?;
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index ab75c92..32f2839 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -18,6 +18,7 @@
  */
 
 use proc_macro::TokenStream;
+use proc_macro_error::proc_macro_error;
 
 mod external;
 mod import_module;
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 46e0342..e48c017 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -88,7 +88,7 @@ pub trait IsObjectRef:
 
 external! {
     #[name("ir.DebugPrint")]
-    fn debug_print(object: ObjectRef) -> CString;
+    pub fn debug_print(object: ObjectRef) -> CString;
     #[name("node.StructuralHash")]
     fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
     #[name("node.StructuralEqual")]
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 8918bdc..3b60b0c 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -31,8 +31,10 @@ use crate::runtime::{external, Object, ObjectRef};
 use super::expr::GlobalVar;
 use super::function::BaseFunc;
 use super::source_map::SourceMap;
+use super::{ty::GlobalTypeVar, relay};
 
 // TODO(@jroesch): define type
+
 type TypeData = ObjectRef;
 type GlobalTypeVar = ObjectRef;
 
@@ -64,7 +66,7 @@ external! {
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
     // Module methods
     #[name("ir.Module_Add")]
-    fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+    fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -78,15 +80,15 @@ external! {
     #[name("ir.Module_GetGlobalTypeVars")]
     fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
     #[name("ir.Module_ContainGlobalVar")]
-    fn module_get_global_var(name: TVMString) -> bool;
+    fn module_contains_global_var(name: TVMString) -> bool;
     #[name("ir.Module_ContainGlobalTypeVar")]
-    fn module_get_global_type_var(name: TVMString) -> bool;
+    fn module_contains_global_type_var(name: TVMString) -> bool;
     #[name("ir.Module_LookupDef")]
     fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
     #[name("ir.Module_LookupDef_str")]
     fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
     #[name("ir.Module_LookupTag")]
-    fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+    fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
     #[name("ir.Module_FromExpr")]
     fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
     #[name("ir.Module_Import")]
@@ -145,3 +147,9 @@ impl IRModule {
         module_lookup_str(self.clone(), name.into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    // #[test]
+    // fn
+}