You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/03/22 05:30:47 UTC

[incubator-tvm] 01/01: Fix up the final pieces

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch rust-stablize
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 5939ce700ca24f10dab71fa0d0c5fab19fd2be4f
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sat Mar 21 22:30:29 2020 -0700

    Fix up the final pieces
---
 rust/Cargo.toml                               |   1 -
 rust/frontend/tests/callback/Cargo.toml       |   1 +
 rust/frontend/tests/callback/src/bin/error.rs |   5 +-
 rust/macros/Cargo.toml                        |   9 +-
 rust/macros/src/lib.rs                        | 123 ++++++++++++++++++++--
 rust/macros_raw/Cargo.toml                    |  36 -------
 rust/macros_raw/src/lib.rs                    | 141 --------------------------
 rust/runtime/tests/test_nn/build.rs           |   3 +-
 rust/runtime/tests/test_tvm_basic/build.rs    |  16 ++-
 rust/runtime/tests/test_tvm_basic/src/main.rs |   2 +-
 10 files changed, 141 insertions(+), 196 deletions(-)

diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 190a6eb..8467f6a 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -19,7 +19,6 @@
 members = [
 	"common",
 	"macros",
-	"macros_raw",
 	"runtime",
 	"runtime/tests/test_tvm_basic",
 	"runtime/tests/test_tvm_dso",
diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/frontend/tests/callback/Cargo.toml
index a452572..dfe80cc 100644
--- a/rust/frontend/tests/callback/Cargo.toml
+++ b/rust/frontend/tests/callback/Cargo.toml
@@ -19,6 +19,7 @@
 name = "callback"
 version = "0.0.0"
 authors = ["TVM Contributors"]
+edition = "2018"
 
 [dependencies]
 ndarray = "0.12"
diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs
index 29bfd9a..c9f9a6f 100644
--- a/rust/frontend/tests/callback/src/bin/error.rs
+++ b/rust/frontend/tests/callback/src/bin/error.rs
@@ -19,10 +19,7 @@
 
 use std::panic;
 
-#[macro_use]
-extern crate tvm_frontend as tvm;
-
-use tvm::{errors::Error, *};
+use tvm_frontend::{errors::Error, *};
 
 fn main() {
     register_global_func! {
diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml
index ff4f7d8..784b35e 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/macros/Cargo.toml
@@ -19,13 +19,18 @@
 name = "tvm-macros"
 version = "0.1.1"
 license = "Apache-2.0"
-description = "Proc macros used by the TVM crates."
+description = "Procedural macros of the TVM crate."
 repository = "https://github.com/apache/incubator-tvm"
 readme = "README.md"
 keywords = ["tvm"]
 authors = ["TVM Contributors"]
 edition = "2018"
 
+[lib]
+proc-macro = true
 
 [dependencies]
-tvm-macros-raw = { path = "../macros_raw" }
+goblin = "0.0.24"
+proc-macro2 = "^1.0"
+quote = "1.0"
+syn = "1.0"
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index efd85d0..d1d86b6 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -17,12 +17,123 @@
  * under the License.
  */
 
-#[macro_use]
-extern crate tvm_macros_raw;
+extern crate proc_macro;
 
-#[macro_export]
-macro_rules! import_module {
-    ($module_path:literal) => {
-        $crate::import_module_raw!(file!(), $module_path);
+use std::{fs::File, io::Read};
+use syn::parse::{Parse, ParseStream, Result};
+use syn::{LitStr};
+use quote::quote;
+
+use std::path::PathBuf;
+
+struct ImportModule {
+    importing_file: LitStr,
+}
+
+impl Parse for ImportModule {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let importing_file: LitStr = input.parse()?;
+        Ok(ImportModule {
+            importing_file,
+        })
+    }
+}
+
+#[proc_macro]
+pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let import_module_args = syn::parse_macro_input!(input as ImportModule);
+
+    let manifest = std::env::var("CARGO_MANIFEST_DIR")
+        .expect("variable should always be set by Cargo.");
+
+    let mut path = PathBuf::new();
+    path.push(manifest);
+    path = path.join(import_module_args.importing_file.value());
+
+    let mut fd = File::open(&path)
+        .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
+    let mut buffer = Vec::new();
+    fd.read_to_end(&mut buffer).unwrap();
+
+    let fn_names = match goblin::Object::parse(&buffer).unwrap() {
+        goblin::Object::Elf(elf) => elf
+            .syms
+            .iter()
+            .filter_map(|s| {
+                if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
+                    return None;
+                }
+                match elf.strtab.get(s.st_name) {
+                    Some(Ok(name)) if name != "" => {
+                        Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
+                    }
+                    _ => None,
+                }
+            })
+            .collect::<Vec<_>>(),
+        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
+            obj.symbols()
+                .filter_map(|s| match s {
+                    Ok((name, ref nlist))
+                        if nlist.is_global()
+                            && nlist.n_sect != 0
+                            && !name.ends_with("tvm_module_ctx") =>
+                    {
+                        Some(syn::Ident::new(
+                            if name.starts_with('_') {
+                                // Mach objects prepend a _ to globals.
+                                &name[1..]
+                            } else {
+                                &name
+                            },
+                            proc_macro2::Span::call_site(),
+                        ))
+                    }
+                    _ => None,
+                })
+                .collect::<Vec<_>>()
+        }
+        _ => panic!("Unsupported object format."),
+    };
+
+    let extern_fns = quote! {
+        mod ext {
+            extern "C" {
+                #(
+                    pub(super) fn #fn_names(
+                        args: *const tvm_runtime::ffi::TVMValue,
+                        type_codes: *const std::os::raw::c_int,
+                        num_args: std::os::raw::c_int
+                    ) -> std::os::raw::c_int;
+                )*
+            }
+        }
     };
+
+    let fns = quote! {
+        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
+        #extern_fns
+
+        #(
+            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
+                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+                   .into_iter()
+                   .map(|arg| {
+                       let (val, code) = arg.to_tvm_value();
+                       (val, code as i32)
+                   })
+                   .unzip();
+                let exit_code = unsafe {
+                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
+                };
+                if exit_code == 0 {
+                    Ok(TVMRetValue::default())
+                } else {
+                    Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
+                }
+            }
+        )*
+    };
+
+    proc_macro::TokenStream::from(fns)
 }
diff --git a/rust/macros_raw/Cargo.toml b/rust/macros_raw/Cargo.toml
deleted file mode 100644
index 9b3d3e9..0000000
--- a/rust/macros_raw/Cargo.toml
+++ /dev/null
@@ -1,36 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-[package]
-name = "tvm-macros-raw"
-version = "0.1.1"
-license = "Apache-2.0"
-description = "Proc macros used by the TVM crates."
-repository = "https://github.com/apache/incubator-tvm"
-readme = "README.md"
-keywords = ["tvm"]
-authors = ["TVM Contributors"]
-edition = "2018"
-
-[lib]
-proc-macro = true
-
-[dependencies]
-goblin = "0.0.24"
-proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
diff --git a/rust/macros_raw/src/lib.rs b/rust/macros_raw/src/lib.rs
deleted file mode 100644
index f518f88..0000000
--- a/rust/macros_raw/src/lib.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-extern crate proc_macro;
-
-use std::{fs::File, io::Read};
-use syn::parse::{Parse, ParseStream, Result};
-use syn::{Token, LitStr};
-use quote::quote;
-
-use std::path::PathBuf;
-
-struct ImportModule {
-    importing_file: LitStr,
-    module_path: LitStr,
-}
-
-impl Parse for ImportModule {
-    fn parse(input: ParseStream) -> Result<Self> {
-        let importing_file: LitStr = input.parse()?;
-        input.parse::<Token![,]>()?;
-        let module_path: LitStr = input.parse()?;
-        Ok(ImportModule {
-            importing_file,
-            module_path,
-        })
-    }
-}
-
-#[proc_macro]
-pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
-    let import_module_args = syn::parse_macro_input!(input as ImportModule);
-
-    let mut path = PathBuf::new();
-    path = path.join(import_module_args.importing_file.value());
-    path.pop(); // remove the filename
-    path.push(import_module_args.module_path.value());
-
-    let mut fd = File::open(&path)
-        .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
-    let mut buffer = Vec::new();
-    fd.read_to_end(&mut buffer).unwrap();
-
-    let fn_names = match goblin::Object::parse(&buffer).unwrap() {
-        goblin::Object::Elf(elf) => elf
-            .syms
-            .iter()
-            .filter_map(|s| {
-                if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
-                    return None;
-                }
-                match elf.strtab.get(s.st_name) {
-                    Some(Ok(name)) if name != "" => {
-                        Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
-                    }
-                    _ => None,
-                }
-            })
-            .collect::<Vec<_>>(),
-        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
-            obj.symbols()
-                .filter_map(|s| match s {
-                    Ok((name, ref nlist))
-                        if nlist.is_global()
-                            && nlist.n_sect != 0
-                            && !name.ends_with("tvm_module_ctx") =>
-                    {
-                        Some(syn::Ident::new(
-                            if name.starts_with('_') {
-                                // Mach objects prepend a _ to globals.
-                                &name[1..]
-                            } else {
-                                &name
-                            },
-                            proc_macro2::Span::call_site(),
-                        ))
-                    }
-                    _ => None,
-                })
-                .collect::<Vec<_>>()
-        }
-        _ => panic!("Unsupported object format."),
-    };
-
-    let extern_fns = quote! {
-        mod ext {
-            extern "C" {
-                #(
-                    pub(super) fn #fn_names(
-                        args: *const tvm_runtime::ffi::TVMValue,
-                        type_codes: *const std::os::raw::c_int,
-                        num_args: std::os::raw::c_int
-                    ) -> std::os::raw::c_int;
-                )*
-            }
-        }
-    };
-
-    let fns = quote! {
-        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
-        #extern_fns
-
-        #(
-            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
-                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
-                   .into_iter()
-                   .map(|arg| {
-                       let (val, code) = arg.to_tvm_value();
-                       (val, code as i32)
-                   })
-                   .unzip();
-                let exit_code = unsafe {
-                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
-                };
-                if exit_code == 0 {
-                    Ok(TVMRetValue::default())
-                } else {
-                    Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
-                }
-            }
-        )*
-    };
-
-    proc_macro::TokenStream::from(fns)
-}
diff --git a/rust/runtime/tests/test_nn/build.rs b/rust/runtime/tests/test_nn/build.rs
index 2d0b066..f072a90 100644
--- a/rust/runtime/tests/test_nn/build.rs
+++ b/rust/runtime/tests/test_nn/build.rs
@@ -44,7 +44,8 @@ fn main() {
             .unwrap_or("")
     );
 
-    let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
+    let file = File::create(format!("{}/libtestnn.a", out_dir)).unwrap();
+    let mut builder = Builder::new(file);
     builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
 
     println!("cargo:rustc-link-lib=static=graph");
diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs
index 3439f9c..ade9e02 100644
--- a/rust/runtime/tests/test_tvm_basic/build.rs
+++ b/rust/runtime/tests/test_tvm_basic/build.rs
@@ -33,7 +33,7 @@ fn main() {
     }
 
     let obj_file = out_dir.join("test.o");
-    let lib_file = out_dir.join("libtest.a");
+    let lib_file = out_dir.join("libtest_basic.a");
 
     let output = Command::new(concat!(
         env!("CARGO_MANIFEST_DIR"),
@@ -53,9 +53,17 @@ fn main() {
             .unwrap_or("")
     );
 
-    let mut builder = Builder::new(File::create(lib_file).unwrap());
-    builder.append_path(obj_file).unwrap();
+    let mut builder = Builder::new(File::create(&lib_file).unwrap());
+    builder.append_path(&obj_file).unwrap();
+    drop(builder);
 
-    println!("cargo:rustc-link-lib=static=test");
+    let status = Command::new("ranlib")
+        .arg(&lib_file)
+        .status()
+        .expect("fdjlksafjdsa");
+
+    assert!(status.success());
+
+    println!("cargo:rustc-link-lib=static=test_basic");
     println!("cargo:rustc-link-search=native={}", out_dir.display());
 }
diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs
index a83078e..653cb43 100644
--- a/rust/runtime/tests/test_tvm_basic/src/main.rs
+++ b/rust/runtime/tests/test_tvm_basic/src/main.rs
@@ -25,7 +25,7 @@ use ndarray::Array;
 use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
 
 mod tvm_mod {
-    import_module!("../lib/test.o");
+    import_module!("lib/test.o");
 }
 
 fn main() {