You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:08 UTC
[incubator-tvm] 12/23: Rust Diagnostics work
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 6e1346748e08255f220c3e6cf72c59a8a3f6ef29
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 14:17:01 2020 -0700
Rust Diagnostics work
---
rust/tvm-rt/src/errors.rs | 15 ++++
rust/tvm-rt/src/function.rs | 7 +-
rust/tvm/src/bin/tyck.rs | 13 ++--
rust/tvm/src/ir/diagnostics/codespan.rs | 126 ++++++++++++++++++++++----------
4 files changed, 117 insertions(+), 44 deletions(-)
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index c884c56..3de9f3c 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,6 +68,21 @@ pub enum Error {
Infallible(#[from] std::convert::Infallible),
#[error("a panic occurred while executing a Rust packed function")]
Panic,
+ #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+ DiagnosticError(String),
+ #[error("{0}")]
+ Raw(String),
+}
+
+impl Error {
+    /// Classifies a raw error message returned by the TVM C API.
+    ///
+    /// A message of the form `DiagnosticError:<details>` becomes
+    /// [`Error::DiagnosticError`] holding everything after the first `:`;
+    /// any other message (including one with no `:` at all) is wrapped
+    /// unchanged in [`Error::Raw`].
+    pub fn from_raw_tvm(raw: &str) -> Error {
+        // Split on the first ':' only, so colons inside the details survive.
+        let mut parts = raw.splitn(2, ':');
+        match (parts.next(), parts.next()) {
+            (Some("DiagnosticError"), Some(details)) => Error::DiagnosticError(details.into()),
+            _ => Error::Raw(raw.into()),
+        }
+    }
}
impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index c7aebdd..173b60a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -133,7 +133,12 @@ impl Function {
};
if ret_code != 0 {
- return Err(Error::CallFailed(crate::get_last_error().into()));
+ let raw_error = crate::get_last_error();
+ let error = match Error::from_raw_tvm(raw_error) {
+ Error::Raw(string) => Error::CallFailed(string),
+ e => e,
+ };
+ return Err(error);
}
let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index fbab027..13470e7 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -4,7 +4,8 @@ use anyhow::Result;
use structopt::StructOpt;
use tvm::ir::diagnostics::codespan;
-use tvm::ir::IRModule;
+use tvm::ir::{self, IRModule};
+use tvm::runtime::Error;
#[derive(Debug, StructOpt)]
#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
@@ -18,11 +19,11 @@ fn main() -> Result<()> {
codespan::init().expect("Rust based diagnostics");
let opt = Opt::from_args();
println!("{:?}", &opt);
- let module = IRModule::parse_file(opt.input);
-
- // for (k, v) in module.functions {
- // println!("Function name: {:?}", v);
- // }
+    let _module = match IRModule::parse_file(opt.input) {
+        // The diagnostics were already rendered (the renderer installed by
+        // `codespan::init` writes them to stderr), so skip anyhow's generic
+        // report but still exit with a failure status so callers such as
+        // scripts and CI can tell that type checking failed.
+        Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => std::process::exit(1),
+        Err(e) => return Err(e.into()),
+        Ok(module) => module,
+    };
Ok(())
}
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 80a8784..9fc1ee0 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+use codespan_reporting::term::{self, ColorArg};
use crate::ir::source_map::*;
use super::*;
@@ -13,8 +14,14 @@ enum StartOrEnd {
End,
}
+/// A byte range `start_pos..end_pos` within the file identified by `file_id`.
+///
+/// Derives `Debug` so a range can be printed when a lookup is investigated,
+/// and `PartialEq`/`Eq` so ranges can be compared directly.
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct ByteRange<FileId> {
+    /// Identifier of the file the range refers to.
+    file_id: FileId,
+    /// Byte offset at which the range starts.
+    start_pos: usize,
+    /// Byte offset at which the range ends.
+    end_pos: usize,
+}
+
enum FileSpanToByteRange {
- AsciiSource,
+ AsciiSource(Vec<usize>),
Utf8 {
/// Map character regions which are larger then 1-byte to length.
lengths: HashMap<isize, isize>,
@@ -27,7 +34,12 @@ impl FileSpanToByteRange {
let mut last_index = 0;
let mut is_ascii = true;
if source.is_ascii() {
- FileSpanToByteRange::AsciiSource
+ let line_lengths =
+ source
+ .lines()
+ .map(|line| line.len())
+ .collect();
+ FileSpanToByteRange::AsciiSource(line_lengths)
} else {
panic!()
}
@@ -41,6 +53,21 @@ impl FileSpanToByteRange {
// last_index = index;
// }
}
+
+    /// Translates a TVM `Span` into a byte range within this source file.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the span's source name is not valid UTF-8, or if this file
+    /// was not indexed as an ASCII source.
+    fn lookup(&self, span: &Span) -> ByteRange<String> {
+        use FileSpanToByteRange::*;
+
+        let source_name: String = span.source_name.name.as_str().unwrap().into();
+
+        match self {
+            AsciiSource(ref line_lengths) => {
+                // `line_lengths` is built with `str::lines`, which strips the line
+                // terminator, so each preceding line contributes its length plus
+                // one byte for the '\n' that was removed. Without that byte every
+                // offset drifted left by one per preceding line.
+                // `take` (rather than slicing) avoids a panic when a span points
+                // at a line past the end of the table.
+                // NOTE(review): assumes '\n' line endings (a "\r\n" file needs two
+                // bytes per line) and keeps treating `column` as a 0-based byte
+                // offset — confirm against the parser's column convention.
+                let offset = |line: usize, column: usize| -> usize {
+                    let preceding = line.saturating_sub(1);
+                    let text_bytes: usize = line_lengths.iter().take(preceding).sum();
+                    text_bytes + preceding + column
+                };
+                ByteRange {
+                    file_id: source_name,
+                    start_pos: offset(span.line as usize, span.column as usize),
+                    end_pos: offset(span.end_line as usize, span.end_column as usize),
+                }
+            }
+            _ => unimplemented!("byte ranges are only computed for ASCII sources"),
+        }
+    }
}
struct SpanToByteRange {
@@ -62,41 +89,22 @@ impl SpanToByteRange {
self.map.insert(source_name, FileSpanToByteRange::new(source));
}
}
-}
-
-struct ByteRange<FileId> {
- file_id: FileId,
- start_pos: usize,
- end_pos: usize,
-}
-
-
-pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
- let severity = match diag.level {
- DiagnosticLevel::Error => Severity::Error,
- DiagnosticLevel::Warning => Severity::Warning,
- DiagnosticLevel::Note => Severity::Note,
- DiagnosticLevel::Help => Severity::Help,
- DiagnosticLevel::Bug => Severity::Bug,
- };
-
- let file_id = "foo".into(); // diag.span.source_name;
- let message: String = diag.message.as_str().unwrap().into();
- let inner_message: String = "expected `String`, found `Nat`".into();
- let diagnostic = CDiagnostic::new(severity)
- .with_message(message)
- .with_code("EXXX")
- .with_labels(vec![
- Label::primary(file_id, 328..331).with_message(inner_message)
- ]);
+    /// Translates `span` into a byte range in the file it refers to.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the span's source name is not valid UTF-8, or if no source
+    /// with that name has been registered via `add_source`.
+    pub fn lookup(&self, span: &Span) -> ByteRange<String> {
+        let source_name: String = span
+            .source_name
+            .name
+            .as_str()
+            .expect("source name is not valid UTF-8")
+            .into();
- diagnostic
+
+        match self.map.get(&source_name) {
+            Some(file_span_to_bytes) => file_span_to_bytes.lookup(span),
+            None => panic!("no source named `{}` has been registered", source_name),
+        }
+    }
}
+/// State kept by the Rust diagnostic renderer between render calls.
struct DiagnosticState {
+    /// Registered sources in the form `codespan_reporting` renders from.
files: SimpleFiles<String, String>,
+    /// Per-file tables that translate TVM spans into byte ranges.
span_map: SpanToByteRange,
+    // TODO: unify with the source name.
+    // Maps a source name to the id `files` assigned to it.
+ source_to_id: HashMap<String, usize>,
}
impl DiagnosticState {
@@ -104,26 +112,70 @@ impl DiagnosticState {
DiagnosticState {
files: SimpleFiles::new(),
span_map: SpanToByteRange::new(),
+ source_to_id: HashMap::new(),
}
}
+
+    /// Registers `source` with the span index and the codespan file database,
+    /// recording the file id assigned to its name.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the source text or its name is not valid UTF-8.
+    fn add_source(&mut self, source: Source) {
+        // Copy both strings out first: `span_map.add_source` takes `source` by value.
+        let contents: String = source.source.as_str().unwrap().into();
+        let name: String = source.source_name.name.as_str().unwrap().into();
+        self.span_map.add_source(source);
+        let id = self.files.add(name.clone(), contents);
+        self.source_to_id.insert(name, id);
+    }
+
+    /// Converts a TVM diagnostic into a `codespan_reporting` diagnostic whose
+    /// primary label covers the diagnostic's span.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the source name or message is not valid UTF-8, or if the
+    /// diagnostic's source was not registered with `add_source` beforehand.
+    fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic<usize> {
+        let severity = match diag.level {
+            DiagnosticLevel::Error => Severity::Error,
+            DiagnosticLevel::Warning => Severity::Warning,
+            DiagnosticLevel::Note => Severity::Note,
+            DiagnosticLevel::Help => Severity::Help,
+            DiagnosticLevel::Bug => Severity::Bug,
+        };
+
+        let source_name: String = diag
+            .span
+            .source_name
+            .name
+            .as_str()
+            .expect("source name is not valid UTF-8")
+            .into();
+        let file_id = *self.source_to_id.get(&source_name).unwrap_or_else(|| {
+            panic!("diagnostic refers to unregistered source `{}`", source_name)
+        });
+
+        let message: String = diag
+            .message
+            .as_str()
+            .expect("diagnostic message is not valid UTF-8")
+            .into();
+        let byte_range = self.span_map.lookup(&diag.span);
+
+        // NOTE(review): "EXXX" looks like a placeholder error code — confirm.
+        CDiagnostic::new(severity)
+            .with_message(message)
+            .with_code("EXXX")
+            .with_labels(vec![Label::primary(
+                file_id,
+                byte_range.start_pos..byte_range.end_pos,
+            )])
+    }
}
fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
let source_map = diag_ctx.module.source_map.clone();
- for diagnostic in diag_ctx.diagnostics.clone() {
- match source_map.source_map.get(&diagnostic.span.source_name) {
- Err(err) => panic!(),
- Ok(source) => state.span_map.add_source(source),
+    let writer = StandardStream::stderr(ColorChoice::Always);
+    let config = codespan_reporting::term::Config::default();
+    // Lock stderr once for the whole batch rather than once per diagnostic.
+    let mut out = writer.lock();
+    // Register each distinct source once per render. Re-adding it for every
+    // diagnostic stored a fresh copy of the whole file in `state.files` each time.
+    let mut registered = std::collections::HashSet::new();
+    for diagnostic in diag_ctx.diagnostics.clone() {
+        let source_name: String = diagnostic
+            .span
+            .source_name
+            .name
+            .as_str()
+            .expect("source name is not valid UTF-8")
+            .into();
+        if registered.insert(source_name) {
+            match source_map.source_map.get(&diagnostic.span.source_name) {
+                Err(err) => panic!(
+                    "diagnostic refers to a source missing from the module's source map: {}",
+                    err
+                ),
+                Ok(source) => state.add_source(source),
}
+        }
+        let diagnostic = state.to_diagnostic(diagnostic);
+        term::emit(&mut out, &config, &state.files, &diagnostic)
+            .expect("failed to write diagnostic to stderr");
- println!("Diagnostic: {}", diagnostic.message);
}
}
pub fn init() -> Result<()> {
let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
let render_fn = move |diag_ctx: DiagnosticContext| {
- // let mut guard = diag_state.lock().unwrap();
- // renderer(&mut *guard, diag_ctx);
+ let mut guard = diag_state.lock().unwrap();
+ renderer(&mut *guard, diag_ctx);
};
override_renderer(Some(render_fn))?;