You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:02 UTC
[incubator-tvm] 06/23: Update CMake and delete old API
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit db245535a4ea5b8bed3cdddd26a94825d2fcc9b5
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 01:03:03 2020 -0700
Update CMake and delete old API
---
CMakeLists.txt | 1 +
cmake/modules/RustExt.cmake | 25 ++-
include/tvm/parser/source_map.h | 2 -
rust/compiler-ext/Cargo.toml | 5 +-
rust/compiler-ext/src/lib.rs | 334 ++--------------------------------------
rust/tvm-rt/Cargo.toml | 15 +-
rust/tvm-sys/Cargo.toml | 1 +
rust/tvm-sys/build.rs | 1 +
rust/tvm/Cargo.toml | 22 ++-
rust/tvm/src/bin/tyck.rs | 1 -
rust/tvm/src/ir/diagnostics.rs | 42 +++--
rust/tvm/src/ir/mod.rs | 2 +-
rust/tvm/src/ir/relay/mod.rs | 3 +-
rust/tvm/src/ir/source_map.rs | 61 ++++++++
rust/tvm/src/ir/span.rs | 95 +++++++++---
src/ir/expr.cc | 11 ++
src/parser/source_map.cc | 11 --
17 files changed, 237 insertions(+), 395 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f82754..58bb2f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
+include(cmake/modules/RustExt.cmake)
include(CheckCXXCompilerFlag)
if(NOT MSVC)
diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake
index 45e46bd..2ad726e9 100644
--- a/cmake/modules/RustExt.cmake
+++ b/cmake/modules/RustExt.cmake
@@ -1,7 +1,14 @@
-if(USE_RUST_EXT)
- set(RUST_SRC_DIR "rust")
- set(CARGO_OUT_DIR "rust/target"
- set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib")
+if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF)
+ set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust")
+ set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target")
+
+ if(USE_RUST_EXT STREQUAL "STATIC")
+ set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a")
+ elseif(USE_RUST_EXT STREQUAL "DYNAMIC")
+ set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so")
+ else()
+ message(FATAL_ERROR "invalid setting for RUST_EXT")
+ endif()
add_custom_command(
OUTPUT "${COMPILER_EXT_PATH}"
@@ -9,5 +16,11 @@ if(USE_RUST_EXT)
MAIN_DEPENDENCY "${RUST_SRC_DIR}"
WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")
- target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
-endif(USE_RUST_EXT)
+ add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}")
+
+ # TODO(@jroesch, @tkonolige): move this to CMake target
+ # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
+ list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH})
+
+ add_definitions(-DRUST_COMPILER_EXT=1)
+endif()
diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h
index 424af5c..a160c22 100644
--- a/include/tvm/parser/source_map.h
+++ b/include/tvm/parser/source_map.h
@@ -103,8 +103,6 @@ class SourceMap : public ObjectRef {
TVM_DLL SourceMap() : SourceMap(Map<SourceName, Source>()) {}
- TVM_DLL static SourceMap Global();
-
void Add(const Source& source);
SourceMapNode* operator->() {
diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml
index 76d10eb..3b13bc5 100644
--- a/rust/compiler-ext/Cargo.toml
+++ b/rust/compiler-ext/Cargo.toml
@@ -6,8 +6,11 @@ edition = "2018"
# TODO(@jroesch): would be cool to figure out how to statically link instead.
[lib]
-crate-type = ["cdylib"]
+crate-type = ["staticlib", "cdylib"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+tvm = { path = "../tvm", default-features = false, features = ["static-linking"] }
+log = "*"
+env_logger = "*"
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 58bdd0c..3e37d21 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -17,321 +17,19 @@
* under the License.
*/
- use std::os::raw::c_int;
- use tvm::initialize;
- use tvm::ir::{tir, PrimExpr};
- use tvm::runtime::function::register_override;
- use tvm::runtime::map::Map;
- use tvm::runtime::object::{IsObject, IsObjectRef};
-
- use ordered_float::NotNan;
-
- mod interval;
- mod math;
-
- use math::{BoundsMap, Expr, RecExpr};
- use tvm::ir::arith::ConstIntBound;
- use tvm_rt::{ObjectRef, array::Array};
-
- macro_rules! downcast_match {
- ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
- $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
- { $default }
- }
- }
-
- #[derive(Default)]
- struct VarMap {
- vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>,
- objs: Vec<ObjectRef>,
- }
-
- impl VarMap {
- // FIXME this should eventually do the right thing for TVM variables
- // right now it depends on them having unique names
- fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
- let sym = egg::Symbol::from(var.name_hint.as_str().unwrap());
- for (_, sym2) in &self.vars {
- if sym == *sym2 {
- return sym;
- }
- }
-
- self.vars.push((var, sym));
- sym
- }
-
- fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
- for (v, sym2) in &self.vars {
- if sym == *sym2 {
- return v.clone();
- }
- }
- panic!("Should have found a var")
- }
-
- fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
- let i = self.objs.len();
- self.objs.push(obj.upcast());
- i
- }
-
- fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
- self.objs[i].clone().downcast().expect("bad downcast")
- }
- }
-
- fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
- fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
- macro_rules! r {
- ($e:expr) => {
- build(vars, &$e, recexpr)
- };
- }
-
- let dt = recexpr.add(Expr::DataType(p.datatype));
- let e = downcast_match!(p; {
- tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
- tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
- tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
-
- tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
- tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
- tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
- tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
-
- tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
- tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
-
- tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
- tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
-
- tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
- tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
- tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
- tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
- tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
- tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
-
- tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
- tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
- tir::Not => Expr::Not([dt, r!(p.value)]),
-
- tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
-
- tir::Let => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
- Expr::Let([dt, sym, r!(p.value), r!(p.body)])
- }
- tir::Var => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
- Expr::Var([dt, sym])
- }
- tir::IntImm => {
- let int = recexpr.add(Expr::Int(p.value));
- Expr::IntImm([dt, int])
- }
- tir::FloatImm => {
- let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap()));
- Expr::FloatImm([dt, float])
- }
- tir::Cast => Expr::Cast([dt, r!(p.value)]),
-
- tir::Call => {
- let op = vars.push_obj(p.op.clone());
- let mut arg_ids = vec![dt];
- for i in 0..p.args.len() {
- let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
- arg_ids.push(r!(arg));
- }
- Expr::Call(op, arg_ids)
- },
- tir::Load => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
- Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
- },
- else => {
- println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap());
- Expr::Object(vars.push_obj(p.clone()))
- }
- });
-
- recexpr.add(e)
- }
-
- let mut recexpr = Default::default();
- build(vars, prim, &mut recexpr);
- recexpr
- }
-
- fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
- fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
- let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]);
- let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
- let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
- Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- },
- Expr::IntImm([dt, v]) => {
- let value = nodes[usize::from(*v)].to_int().unwrap();
- tir::IntImm::new(get_dt(dt), value).upcast()
- }
- Expr::FloatImm([dt, v]) => {
- let value = nodes[usize::from(*v)].to_float().unwrap();
- tir::FloatImm::new(get_dt(dt), value).upcast()
- }
- Expr::Let([dt, s, value, body]) => {
- let var = match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- };
- tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast()
- }
- Expr::Load([dt, s, value, body]) => {
- let var = match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- };
- tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast()
- }
-
- Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
-
- Expr::Ramp([dt, a, b, c]) => {
- let len = &nodes[usize::from(*c)];
- let i = len
- .to_int()
- .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len));
- tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast()
- }
- Expr::Broadcast([dt, val, lanes]) => {
- let lanes = &nodes[usize::from(*lanes)];
- let lanes = lanes
- .to_int()
- .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes));
- println!("dt: {}", get_dt(dt));
- tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast()
- }
-
- Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
- Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
- Expr::Call(expr, args) => {
- let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
- let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
- tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast()
- }
-
- Expr::Object(i) => vars.get_obj(*i),
- node => panic!("I don't know how to extract {:?}", node),
- };
- assert_ne!(prim.datatype.bits(), 0);
- assert_ne!(prim.datatype.lanes(), 0);
- prim
- }
- build(vars, recexpr.as_ref())
- }
-
- fn run(
- input: PrimExpr,
- expected: Option<PrimExpr>,
- map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, String> {
- use egg::{CostFunction, Extractor};
-
- let mut bounds = BoundsMap::default();
- for (k, v) in map {
- if let Ok(var) = k.downcast_clone::<tir::Var>() {
- let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
- bounds.insert(sym, (v.min_value, v.max_value));
- } else {
- println!("Non var in bounds map: {}", tvm::ir::as_text(k));
- }
- }
-
- let mut vars = VarMap::default();
- let expr = to_egg(&mut vars, &input);
- let mut runner = math::default_runner();
- runner.egraph.analysis.bounds = bounds;
-
- let mut runner = runner.with_expr(&expr).run(&math::rules());
- // runner.print_report();
- let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
- let root = runner.egraph.find(runner.roots[0]);
- let (cost, best) = extractor.find_best(root);
- if let Some(expected) = expected {
- let mut expected_vars = VarMap::default();
- let expected_expr = to_egg(&mut expected_vars, &expected);
- let expected_root = runner.egraph.add_expr(&expected_expr);
- if expected_root != root {
- return Err(format!(
- "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
- expected_expr.pretty(40),
- best.pretty(40)
- ));
- }
- let expected_cost = math::CostFn.cost_rec(&expected_expr);
- if expected_cost != cost {
- let msg = format!(
- "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
- expected_cost,
- expected_expr.pretty(40),
- cost,
- best.pretty(40)
- );
- if cost < expected_cost {
- println!("egg wins: {}", msg)
- } else {
- return Err(msg);
- }
- }
- }
- log::info!(" returning... {}", best.pretty(60));
- Ok(from_egg(&vars, &best))
- }
-
- fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
- log::debug!("map: {:?}", map);
- run(prim, None, map).map_err(tvm::Error::CallFailed)
- }
-
- fn simplify_and_check(
- prim: PrimExpr,
- check: PrimExpr,
- map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, tvm::Error> {
- log::debug!("check map: {:?}", map);
- run(prim, Some(check), map).map_err(tvm::Error::CallFailed)
- }
-
- initialize!({
- let _ = env_logger::try_init();
- // NOTE this print prevents a segfault (on Linux) for now...
- println!("Initializing simplifier... ");
- register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
- register_override(simplify_and_check, "egg.simplify_and_check", true)
- .expect("failed to initialize simplifier");
- log::debug!("done!");
- });
-
\ No newline at end of file
+use env_logger;
+use tvm;
+use tvm::runtime::function::register_override;
+
+/// Smoke-test function registered with TVM as `rust_ext.test_fn`.
+///
+/// Prints a greeting so a caller can confirm the Rust extension is loaded.
+fn test_fn() -> Result<(), tvm::Error> {
+    println!("Hello from Rust!");
+    Ok(())
+}
+
+/// Initializer for the Rust compiler extension, called from C++.
+///
+/// When TVM is built with `RUST_COMPILER_EXT`, `src/ir/expr.cc` declares this
+/// as `extern "C" int compiler_ext_initialize()` and calls it from a static
+/// initializer, so it must use the C ABI and an unmangled symbol name.
+/// It registers `test_fn` under the name `rust_ext.test_fn`.
+///
+/// Returns `0` on success and `-1` if registration fails. Failures are logged
+/// instead of panicking, because a panic must not unwind across the
+/// `extern "C"` boundary into C++.
+#[no_mangle]
+pub extern "C" fn compiler_ext_initialize() -> i32 {
+    // `try_init` fails if a logger is already installed; that is fine here.
+    let _ = env_logger::try_init();
+    match register_override(test_fn, "rust_ext.test_fn", true) {
+        Ok(_) => {
+            log::debug!("done!");
+            0
+        }
+        Err(err) => {
+            log::error!("failed to register rust_ext.test_fn: {}", err);
+            -1
+        }
+    }
+}
diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml
index acece5a..9660943 100644
--- a/rust/tvm-rt/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -28,19 +28,26 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
+[features]
+default = ["dynamic-linking"]
+dynamic-linking = ["tvm-sys/bindings"]
+static-linking = []
+blas = ["ndarray/blas"]
+
[dependencies]
thiserror = "^1.0"
ndarray = "0.12"
num-traits = "0.2"
-tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
tvm-macros = { version = "0.1", path = "../tvm-macros" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
memoffset = "0.5.6"
+[dependencies.tvm-sys]
+version = "0.1"
+default-features = false
+path = "../tvm-sys/"
+
[dev-dependencies]
anyhow = "^1.0"
-
-[features]
-blas = ["ndarray/blas"]
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
index 4e3fc98..c25a5bf 100644
--- a/rust/tvm-sys/Cargo.toml
+++ b/rust/tvm-sys/Cargo.toml
@@ -23,6 +23,7 @@ license = "Apache-2.0"
edition = "2018"
[features]
+default = ["bindings"]
bindings = []
[dependencies]
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 05806c0..2d86c4b 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -60,6 +60,7 @@ fn main() -> Result<()> {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm");
+ println!("cargo:rustc-link-lib=dylib=llvm-10");
println!("cargo:rustc-link-search={}/build", tvm_home);
}
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 71a4b93..153a195 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -28,14 +28,24 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
+[features]
+default = ["python", "dynamic-linking"]
+dynamic-linking = ["tvm-rt/dynamic-linking"]
+static-linking = ["tvm-rt/static-linking"]
+blas = ["ndarray/blas"]
+python = ["pyo3"]
+
+[dependencies.tvm-rt]
+version = "0.1"
+default-features = false
+path = "../tvm-rt/"
+
[dependencies]
thiserror = "^1.0"
anyhow = "^1.0"
lazy_static = "1.1"
ndarray = "0.12"
num-traits = "0.2"
-tvm-rt = { version = "0.1", path = "../tvm-rt/" }
-tvm-sys = { version = "0.1", path = "../tvm-sys/" }
tvm-macros = { version = "*", path = "../tvm-macros/" }
paste = "0.1"
mashup = "0.1"
@@ -44,8 +54,6 @@ pyo3 = { version = "0.11.1", optional = true }
codespan-reporting = "0.9.5"
structopt = { version = "0.3" }
-[features]
-default = ["python"]
-
-blas = ["ndarray/blas"]
-python = ["pyo3"]
+[[bin]]
+name = "tyck"
+required-features = ["dynamic-linking"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index 9300412..b869012 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -6,7 +6,6 @@ use structopt::StructOpt;
use tvm::ir::diagnostics::codespan;
use tvm::ir::IRModule;
-
#[derive(Debug, StructOpt)]
#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
struct Opt {
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index d306185..b76e43f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -17,17 +17,20 @@
* under the License.
*/
+use super::module::IRModule;
+use super::span::Span;
+use crate::runtime::function::Result;
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
+use crate::runtime::{
+ array::Array,
+ function::{self, Function, ToFunction},
+ string::String as TString,
+};
/// The diagnostic interface to TVM, used for reporting and rendering
/// diagnostic information by the compiler. This module exposes
/// three key abstractions: a Diagnostic, the DiagnosticContext,
/// and the DiagnosticRenderer.
-
-use tvm_macros::{Object, external};
-use super::module::IRModule;
-use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
-use crate::runtime::function::Result;
-use super::span::Span;
+use tvm_macros::{external, Object};
type SourceName = ObjectRef;
@@ -134,7 +137,6 @@ pub struct DiagnosticRendererNode {
// memory layout
}
-
// def render(self, ctx):
// """
// Render the provided context.
@@ -169,7 +171,8 @@ pub struct DiagnosticContextNode {
/// and contains a renderer.
impl DiagnosticContext {
pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
- where F: Fn(DiagnosticContext) -> () + 'static
+ where
+ F: Fn(DiagnosticContext) -> () + 'static,
{
let renderer = diagnostic_renderer(render_func.to_function()).unwrap();
let node = DiagnosticContextNode {
@@ -210,21 +213,16 @@ impl DiagnosticContext {
// If the render_func is None it will remove the current custom renderer
// and return to default behavior.
fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> () + 'static
+where
+ F: Fn(DiagnosticContext) -> () + 'static,
{
-
match opt_func {
None => clear_renderer(),
Some(func) => {
let func = func.to_function();
- let render_factory = move || {
- diagnostic_renderer(func.clone()).unwrap()
- };
+ let render_factory = move || diagnostic_renderer(func.clone()).unwrap();
- function::register_override(
- render_factory,
- "diagnostics.OverrideRenderer",
- true)?;
+ function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?;
Ok(())
}
@@ -243,9 +241,9 @@ pub mod codespan {
End,
}
- struct SpanToBytes {
- inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
- }
+ // struct SpanToBytes {
+ // inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
+ // }
struct ByteRange<FileId> {
file_id: FileId,
@@ -276,7 +274,7 @@ pub mod codespan {
.with_message(message)
.with_code("EXXX")
.with_labels(vec![
- Label::primary(file_id, 328..331).with_message(inner_message),
+ Label::primary(file_id, 328..331).with_message(inner_message)
]);
diagnostic
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 8450bd7..401b6c2 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -19,8 +19,8 @@
pub mod arith;
pub mod attrs;
-pub mod expr;
pub mod diagnostics;
+pub mod expr;
pub mod function;
pub mod module;
pub mod op;
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index e539221..4b09128 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -28,6 +28,7 @@ use super::attrs::Attrs;
use super::expr::BaseExprNode;
use super::function::BaseFuncNode;
use super::ty::{Type, TypeNode};
+use super::span::Span;
use tvm_macros::Object;
use tvm_rt::NDArray;
@@ -51,7 +52,7 @@ impl ExprNode {
span: ObjectRef::null(),
checked_type: Type::from(TypeNode {
base: Object::base_object::<TypeNode>(),
- span: ObjectRef::null(),
+ span: Span::empty(),
}),
}
}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index e69de29..e6c0371 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use super::span::SourceName;
+use crate::runtime::map::Map;
+use crate::runtime::object::Object;
+use crate::runtime::string::String as TString;
+use tvm_macros::Object;
+
+/// A program source in any language.
+///
+/// Could represent the source from an ML framework or a source of an IRModule.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Source"]
+#[ref_name = "Source"]
+pub struct SourceNode {
+    /// Base object header.
+    pub base: Object,
+    /// The source name.
+    pub source_name: SourceName,
+    /// The raw source.
+    pub source: TString,
+    // NOTE(review): the C++ `SourceNode` declares one more trailing field,
+    // `std::vector<std::pair<int, int>> line_map` (a mapping of line breaks
+    // into the raw source). It has no ABI-compatible Rust equivalent, so it is
+    // intentionally omitted. This is presumably only sound while Rust never
+    // allocates a `SourceNode` itself and only views nodes created in C++.
+}
+
+// class Source : public ObjectRef {
+// public:
+// TVM_DLL Source(SourceName src_name, std::string source);
+// TVM_DLL tvm::String GetLine(int line);
+
+// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
+// };
+
+/// A mapping from a unique source name to source fragments.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceMap"]
+#[ref_name = "SourceMap"]
+pub struct SourceMapNode {
+    /// Base object header.
+    pub base: Object,
+    /// The source mapping.
+    pub source_map: Map<SourceName, Source>,
+}
diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs
index d2e19a2..c54fd51 100644
--- a/rust/tvm/src/ir/span.rs
+++ b/rust/tvm/src/ir/span.rs
@@ -1,22 +1,75 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-use crate::runtime::ObjectRef;
-
-pub type Span = ObjectRef;
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::runtime::{ObjectRef, Object, String as TString};
+use tvm_macros::Object;
+
+/// A source file name, contained in a Span.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceName"]
+#[ref_name = "SourceName"]
+pub struct SourceNameNode {
+    /// Base object header.
+    pub base: Object,
+    /// The source file name.
+    pub name: TString,
+}
+
+// /*!
+// * \brief The source name of a file span.
+// * \sa SourceNameNode, Span
+// */
+// class SourceName : public ObjectRef {
+// public:
+// /*!
+// * \brief Get an SourceName for a given operator name.
+// * Will raise an error if the source name has not been registered.
+// * \param name Name of the operator.
+// * \return SourceName valid throughout program lifetime.
+// */
+// TVM_DLL static SourceName Get(const String& name);
+
+// TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
+// };
+
+/// Span information for diagnostic purposes.
+///
+/// NOTE(review): `#[repr(C)]` suggests this struct is read directly from
+/// objects allocated by the C++ runtime; if so, field order and types must
+/// keep matching the C++ `SpanNode` — do not reorder fields.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Span"]
+#[ref_name = "Span"]
+pub struct SpanNode {
+ /// Base object header.
+ pub base: Object,
+ /// The source name.
+ pub source_name: SourceName,
+ /// The line number.
+ pub line: i32,
+ /// The column offset.
+ pub column: i32,
+ /// The end line number.
+ pub end_line: i32,
+ /// The end column number.
+ pub end_column: i32,
+}
+
+impl Span {
+    /// Returns a null `Span`, used for nodes that carry no source location.
+    ///
+    /// This is the typed equivalent of an `ObjectRef::null()` placeholder span.
+    /// It must not panic, because node constructors (e.g. relay `ExprNode`)
+    /// call it unconditionally when building a default node.
+    pub fn empty() -> Span {
+        Span(None)
+    }
+}
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 67e5cea..5110eef 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -192,4 +192,15 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
return ss.str();
});
+
+
} // namespace tvm
+
+#ifdef RUST_COMPILER_EXT
+
+// Initializer exported (unmangled, C ABI) by the Rust `compiler-ext` crate.
+extern "C" int compiler_ext_initialize();
+
+// Run the Rust extension's initializer during static initialization of this
+// translation unit, so the functions it registers exist once the library is
+// loaded. The returned status code is currently not inspected.
+static int rust_compiler_ext_init_status = compiler_ext_initialize();
+
+#endif
diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc
index 7ac978c..7340f69 100644
--- a/src/parser/source_map.cc
+++ b/src/parser/source_map.cc
@@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) {
return line_text;
}
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-// .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
-// auto* node = static_cast<const SourceNameNode*>(ref.get());
-// p->stream << "SourceName(" << node->name << ", " << node << ")";
-// });
-
TVM_REGISTER_NODE_TYPE(SourceMapNode);
SourceMap::SourceMap(Map<SourceName, Source> source_map) {
@@ -91,11 +85,6 @@ SourceMap::SourceMap(Map<SourceName, Source> source_map) {
data_ = std::move(n);
}
-// TODO(@jroesch): fix this
-static SourceMap global_source_map = SourceMap(Map<SourceName, Source>());
-
-SourceMap SourceMap::Global() { return global_source_map; }
-
void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); }
TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) {