You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:14 UTC

[incubator-tvm] 18/23: WIP

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 49246bff342eb59757cecc34a9a9465a2e3c063d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Wed Oct 21 14:09:37 2020 -0700

    WIP
---
 rust/tvm-macros/src/external.rs |  5 ++-
 rust/tvm-macros/src/lib.rs      |  1 +
 rust/tvm/src/ir/module.rs       | 67 +++++++++++++----------------------------
 3 files changed, 26 insertions(+), 47 deletions(-)

diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 802d7ae..de8ada3 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -17,6 +17,7 @@
  * under the License.
  */
 use proc_macro2::Span;
+use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
@@ -109,7 +110,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
             .iter()
             .map(|ty_param| match ty_param {
                 syn::GenericParam::Type(param) => param.clone(),
-                _ => panic!(),
+                _ => abort! { ty_param,
+                    "Only supports type parameters."
+                }
             })
             .collect();
 
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 603e1ce..ab75c92 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -35,6 +35,7 @@ pub fn macro_impl(input: TokenStream) -> TokenStream {
     TokenStream::from(object::macro_impl(input))
 }
 
+#[proc_macro_error]
 #[proc_macro]
 pub fn external(input: TokenStream) -> TokenStream {
     external::macro_impl(input)
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 443915f..8918bdc 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -63,6 +63,8 @@ external! {
     #[name("parser.ParseExpr")]
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
     // Module methods
+    #[name("ir.Module_Add")]
+    fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -73,55 +75,28 @@ external! {
     fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc;
     #[name("ir.Module_Lookup_str")]
     fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
+    #[name("ir.Module_GetGlobalTypeVars")]
+    fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+    #[name("ir.Module_ContainGlobalVar")]
+    fn module_get_global_var(name: TVMString) -> bool;
+    #[name("ir.Module_ContainGlobalTypeVar")]
+    fn module_get_global_type_var(name: TVMString) -> bool;
+    #[name("ir.Module_LookupDef")]
+    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    #[name("ir.Module_LookupDef_str")]
+    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    #[name("ir.Module_LookupTag")]
+    fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+    #[name("ir.Module_FromExpr")]
+    fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
+    #[name("ir.Module_Import")]
+    fn module_import(module: IRModule, path: TVMString);
+    #[name("ir.Module_ImportFromStd")]
+    fn module_import_from_std(module: IRModule, path: TVMString);
 }
 
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
-//     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
+// Note: `ir.Module_Update` and `ir.Module_UpdateFunction` are deliberately not exposed here, as update is going to be removed.
 
-// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
-//     .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
-//     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
-//   return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
-//   return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
-//   return mod->LookupTag(tag);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
-//     .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs,
-//                        tvm::Map<GlobalTypeVar, TypeData> type_defs) {
-//       return IRModule::FromExpr(e, funcs, type_defs);
-//     });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
-//   mod->Update(from);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
-//     .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
-//   mod->Import(path);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
-//   mod->ImportFromStd(path);
-// });
-
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-//     .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
-//       auto* node = static_cast<const IRModuleNode*>(ref.get());
-//       p->stream << "IRModuleNode( " << node->functions << ")";
-//     });
 
 impl IRModule {
     pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>