You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/03/22 05:30:47 UTC
[incubator-tvm] 01/01: Fix up the final pieces
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch rust-stablize
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 5939ce700ca24f10dab71fa0d0c5fab19fd2be4f
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sat Mar 21 22:30:29 2020 -0700
Fix up the final pieces
---
rust/Cargo.toml | 1 -
rust/frontend/tests/callback/Cargo.toml | 1 +
rust/frontend/tests/callback/src/bin/error.rs | 5 +-
rust/macros/Cargo.toml | 9 +-
rust/macros/src/lib.rs | 123 ++++++++++++++++++++--
rust/macros_raw/Cargo.toml | 36 -------
rust/macros_raw/src/lib.rs | 141 --------------------------
rust/runtime/tests/test_nn/build.rs | 3 +-
rust/runtime/tests/test_tvm_basic/build.rs | 16 ++-
rust/runtime/tests/test_tvm_basic/src/main.rs | 2 +-
10 files changed, 141 insertions(+), 196 deletions(-)
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 190a6eb..8467f6a 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -19,7 +19,6 @@
members = [
"common",
"macros",
- "macros_raw",
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso",
diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/frontend/tests/callback/Cargo.toml
index a452572..dfe80cc 100644
--- a/rust/frontend/tests/callback/Cargo.toml
+++ b/rust/frontend/tests/callback/Cargo.toml
@@ -19,6 +19,7 @@
name = "callback"
version = "0.0.0"
authors = ["TVM Contributors"]
+edition = "2018"
[dependencies]
ndarray = "0.12"
diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs
index 29bfd9a..c9f9a6f 100644
--- a/rust/frontend/tests/callback/src/bin/error.rs
+++ b/rust/frontend/tests/callback/src/bin/error.rs
@@ -19,10 +19,7 @@
use std::panic;
-#[macro_use]
-extern crate tvm_frontend as tvm;
-
-use tvm::{errors::Error, *};
+use tvm_frontend::{errors::Error, *};
fn main() {
register_global_func! {
diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml
index ff4f7d8..784b35e 100644
--- a/rust/macros/Cargo.toml
+++ b/rust/macros/Cargo.toml
@@ -19,13 +19,18 @@
name = "tvm-macros"
version = "0.1.1"
license = "Apache-2.0"
-description = "Proc macros used by the TVM crates."
+description = "Procedural macros of the TVM crate."
repository = "https://github.com/apache/incubator-tvm"
readme = "README.md"
keywords = ["tvm"]
authors = ["TVM Contributors"]
edition = "2018"
+[lib]
+proc-macro = true
[dependencies]
-tvm-macros-raw = { path = "../macros_raw" }
+goblin = "0.0.24"
+proc-macro2 = "^1.0"
+quote = "1.0"
+syn = "1.0"
diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs
index efd85d0..d1d86b6 100644
--- a/rust/macros/src/lib.rs
+++ b/rust/macros/src/lib.rs
@@ -17,12 +17,170 @@
  * under the License.
  */
 
-#[macro_use]
-extern crate tvm_macros_raw;
+extern crate proc_macro;
 
-#[macro_export]
-macro_rules! import_module {
-    ($module_path:literal) => {
-        $crate::import_module_raw!(file!(), $module_path);
+use std::path::{Path, PathBuf};
+
+use quote::quote;
+use syn::parse::{Parse, ParseStream, Result};
+use syn::LitStr;
+
+/// Parsed arguments of `import_module!`: one string literal giving the path
+/// of a TVM-compiled object file, relative to the invoking crate's
+/// `CARGO_MANIFEST_DIR`.
+struct ImportModule {
+    module_path: LitStr,
+}
+
+impl Parse for ImportModule {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let module_path: LitStr = input.parse()?;
+        Ok(ImportModule { module_path })
+    }
+}
+
+/// Returns identifiers for the symbols of the object file in `buffer` that
+/// `import_module!` should wrap.
+///
+/// ELF symbols are kept unless they are untyped, file symbols, unnamed, or end
+/// in `tvm_module_ctx`. Mach-O symbols must be global, defined in a section,
+/// and not end in `tvm_module_ctx`; a leading `_` is stripped from them.
+///
+/// # Panics
+///
+/// Panics if goblin cannot parse `buffer` (the message names `path`) or if the
+/// object is neither ELF nor a single-architecture Mach-O binary.
+fn wrapped_symbols(buffer: &[u8], path: &Path) -> Vec<syn::Ident> {
+    use goblin::elf::sym::{STT_FILE, STT_NOTYPE};
+
+    let object = goblin::Object::parse(buffer).unwrap_or_else(|err| {
+        panic!("Unable to parse TVM object file `{}`: {}", path.display(), err)
+    });
+    let ident = |name: &str| syn::Ident::new(name, proc_macro2::Span::call_site());
+
+    match object {
+        goblin::Object::Elf(elf) => elf
+            .syms
+            .iter()
+            .filter_map(|sym| {
+                let st_type = sym.st_type();
+                if st_type == STT_NOTYPE || st_type == STT_FILE {
+                    return None;
+                }
+                match elf.strtab.get(sym.st_name) {
+                    // The Mach-O branch skips the module context symbol;
+                    // do the same so both formats yield the same wrappers.
+                    Some(Ok(name)) if !name.is_empty() && !name.ends_with("tvm_module_ctx") => {
+                        Some(ident(name))
+                    }
+                    _ => None,
+                }
+            })
+            .collect(),
+        goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => obj
+            .symbols()
+            .filter_map(|sym| match sym {
+                Ok((name, ref nlist))
+                    if nlist.is_global()
+                        && nlist.n_sect != 0
+                        && !name.ends_with("tvm_module_ctx") =>
+                {
+                    // Mach objects prepend a _ to globals.
+                    let name = if name.starts_with('_') { &name[1..] } else { name };
+                    Some(ident(name))
+                }
+                _ => None,
+            })
+            .collect(),
+        _ => panic!("Unsupported object format."),
+    }
+}
+
+/// Generates safe-to-call Rust wrappers for the symbols of a TVM object file.
+///
+/// `import_module!("lib/test.o")` reads the object file while the macro is
+/// expanded, resolving the path relative to the invoking crate's
+/// `CARGO_MANIFEST_DIR`, and expands to:
+///
+/// * a private `ext` module declaring each symbol as an `extern "C"` function
+///   taking `(args, type_codes, num_args)` and returning a C `int`;
+/// * for each symbol, a `pub fn` taking `&[TVMArgValue]` and returning
+///   `Result<TVMRetValue, FuncCallError>`, which marshals the arguments and
+///   maps a non-zero exit code to a `FuncCallError`.
+///
+/// The expansion names items from `tvm_runtime`, so the invoking crate must
+/// depend on it.
+///
+/// # Panics
+///
+/// Panics (failing compilation) if `CARGO_MANIFEST_DIR` is unset, or if the
+/// object file cannot be read, cannot be parsed, or has an unsupported format.
+#[proc_macro]
+pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let import_module_args = syn::parse_macro_input!(input as ImportModule);
+
+    let manifest = std::env::var("CARGO_MANIFEST_DIR")
+        .expect("variable should always be set by Cargo.");
+    let path = PathBuf::from(manifest).join(import_module_args.module_path.value());
+
+    let buffer = std::fs::read(&path).unwrap_or_else(|err| {
+        panic!("Unable to find TVM object file at `{}`: {}", path.display(), err)
+    });
+    let fn_names = wrapped_symbols(&buffer, &path);
+
+    // rustc cannot see that a proc macro read this file, so it would not
+    // re-expand the macro after the object is rebuilt. Routing the path
+    // through `include_bytes!` records it in the dep-info Cargo checks.
+    // Non-UTF-8 paths cannot be written as a literal and stay untracked.
+    let track_object = path.to_str().map(|p| {
+        quote! {
+            const _: &[u8] = include_bytes!(#p);
+        }
+    });
+
+    let extern_fns = quote! {
+        mod ext {
+            extern "C" {
+                #(
+                    pub(super) fn #fn_names(
+                        args: *const tvm_runtime::ffi::TVMValue,
+                        type_codes: *const std::os::raw::c_int,
+                        num_args: std::os::raw::c_int
+                    ) -> std::os::raw::c_int;
+                )*
+            }
+        }
     };
+
+    let fns = quote! {
+        use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
+        #track_object
+        #extern_fns
+
+        #(
+            pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
+                let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
+                    .iter()
+                    .map(|arg| {
+                        let (val, code) = arg.to_tvm_value();
+                        (val, code as i32)
+                    })
+                    .unzip();
+                // SAFETY: both pointers come from vectors of length
+                // `values.len()` that stay alive until the call returns.
+                // NOTE(review): assumes the callee reads at most `num_args`
+                // entries from each array — confirm against the TVM ABI.
+                let exit_code = unsafe {
+                    ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
+                };
+                if exit_code == 0 {
+                    Ok(TVMRetValue::default())
+                } else {
+                    Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
+                }
+            }
+        )*
+    };
+
+    proc_macro::TokenStream::from(fns)
 }
diff --git a/rust/macros_raw/Cargo.toml b/rust/macros_raw/Cargo.toml
deleted file mode 100644
index 9b3d3e9..0000000
--- a/rust/macros_raw/Cargo.toml
+++ /dev/null
@@ -1,36 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-[package]
-name = "tvm-macros-raw"
-version = "0.1.1"
-license = "Apache-2.0"
-description = "Proc macros used by the TVM crates."
-repository = "https://github.com/apache/incubator-tvm"
-readme = "README.md"
-keywords = ["tvm"]
-authors = ["TVM Contributors"]
-edition = "2018"
-
-[lib]
-proc-macro = true
-
-[dependencies]
-goblin = "0.0.24"
-proc-macro2 = "^1.0"
-quote = "1.0"
-syn = "1.0"
diff --git a/rust/macros_raw/src/lib.rs b/rust/macros_raw/src/lib.rs
deleted file mode 100644
index f518f88..0000000
--- a/rust/macros_raw/src/lib.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-extern crate proc_macro;
-
-use std::{fs::File, io::Read};
-use syn::parse::{Parse, ParseStream, Result};
-use syn::{Token, LitStr};
-use quote::quote;
-
-use std::path::PathBuf;
-
-struct ImportModule {
- importing_file: LitStr,
- module_path: LitStr,
-}
-
-impl Parse for ImportModule {
- fn parse(input: ParseStream) -> Result<Self> {
- let importing_file: LitStr = input.parse()?;
- input.parse::<Token![,]>()?;
- let module_path: LitStr = input.parse()?;
- Ok(ImportModule {
- importing_file,
- module_path,
- })
- }
-}
-
-#[proc_macro]
-pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
- let import_module_args = syn::parse_macro_input!(input as ImportModule);
-
- let mut path = PathBuf::new();
- path = path.join(import_module_args.importing_file.value());
- path.pop(); // remove the filename
- path.push(import_module_args.module_path.value());
-
- let mut fd = File::open(&path)
- .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display()));
- let mut buffer = Vec::new();
- fd.read_to_end(&mut buffer).unwrap();
-
- let fn_names = match goblin::Object::parse(&buffer).unwrap() {
- goblin::Object::Elf(elf) => elf
- .syms
- .iter()
- .filter_map(|s| {
- if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" {
- return None;
- }
- match elf.strtab.get(s.st_name) {
- Some(Ok(name)) if name != "" => {
- Some(syn::Ident::new(name, proc_macro2::Span::call_site()))
- }
- _ => None,
- }
- })
- .collect::<Vec<_>>(),
- goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => {
- obj.symbols()
- .filter_map(|s| match s {
- Ok((name, ref nlist))
- if nlist.is_global()
- && nlist.n_sect != 0
- && !name.ends_with("tvm_module_ctx") =>
- {
- Some(syn::Ident::new(
- if name.starts_with('_') {
- // Mach objects prepend a _ to globals.
- &name[1..]
- } else {
- &name
- },
- proc_macro2::Span::call_site(),
- ))
- }
- _ => None,
- })
- .collect::<Vec<_>>()
- }
- _ => panic!("Unsupported object format."),
- };
-
- let extern_fns = quote! {
- mod ext {
- extern "C" {
- #(
- pub(super) fn #fn_names(
- args: *const tvm_runtime::ffi::TVMValue,
- type_codes: *const std::os::raw::c_int,
- num_args: std::os::raw::c_int
- ) -> std::os::raw::c_int;
- )*
- }
- }
- };
-
- let fns = quote! {
- use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError};
- #extern_fns
-
- #(
- pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> {
- let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
- .into_iter()
- .map(|arg| {
- let (val, code) = arg.to_tvm_value();
- (val, code as i32)
- })
- .unzip();
- let exit_code = unsafe {
- ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32)
- };
- if exit_code == 0 {
- Ok(TVMRetValue::default())
- } else {
- Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string()))
- }
- }
- )*
- };
-
- proc_macro::TokenStream::from(fns)
-}
diff --git a/rust/runtime/tests/test_nn/build.rs b/rust/runtime/tests/test_nn/build.rs
index 2d0b066..f072a90 100644
--- a/rust/runtime/tests/test_nn/build.rs
+++ b/rust/runtime/tests/test_nn/build.rs
@@ -44,7 +44,8 @@ fn main() {
.unwrap_or("")
);
- let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap());
+ let file = File::create(format!("{}/libtestnn.a", out_dir)).unwrap();
+ let mut builder = Builder::new(file);
builder.append_path(format!("{}/graph.o", out_dir)).unwrap();
- println!("cargo:rustc-link-lib=static=graph");
+ println!("cargo:rustc-link-lib=static=testnn"); // must match libtestnn.a created above
diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs
index 3439f9c..ade9e02 100644
--- a/rust/runtime/tests/test_tvm_basic/build.rs
+++ b/rust/runtime/tests/test_tvm_basic/build.rs
@@ -33,7 +33,7 @@ fn main() {
}
let obj_file = out_dir.join("test.o");
- let lib_file = out_dir.join("libtest.a");
+ let lib_file = out_dir.join("libtest_basic.a");
let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
@@ -53,9 +53,18 @@ fn main() {
.unwrap_or("")
);
- let mut builder = Builder::new(File::create(lib_file).unwrap());
- builder.append_path(obj_file).unwrap();
+ let mut builder = Builder::new(File::create(&lib_file).unwrap());
+ builder.append_path(&obj_file).unwrap();
+ drop(builder); // close the archive file before ranlib modifies it
- println!("cargo:rustc-link-lib=static=test");
+ // Add a symbol index to the archive; NOTE(review): presumably the archive written by `Builder` lacks one — confirm.
+ let status = Command::new("ranlib")
+ .arg(&lib_file)
+ .status()
+ .expect("failed to execute `ranlib`; is it installed and on PATH?");
+
+ assert!(status.success(), "`ranlib {}` failed: {}", lib_file.display(), status);
+
+ println!("cargo:rustc-link-lib=static=test_basic");
println!("cargo:rustc-link-search=native={}", out_dir.display());
}
diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs
index a83078e..653cb43 100644
--- a/rust/runtime/tests/test_tvm_basic/src/main.rs
+++ b/rust/runtime/tests/test_tvm_basic/src/main.rs
@@ -25,7 +25,7 @@ use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};
mod tvm_mod {
- import_module!("../lib/test.o");
+ import_module!("lib/test.o"); // resolved against CARGO_MANIFEST_DIR, not relative to this file
}
fn main() {