You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:08 UTC

[incubator-tvm] 12/23: Rust Diagnostics work

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 6e1346748e08255f220c3e6cf72c59a8a3f6ef29
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 14:17:01 2020 -0700

    Rust Diagnostics work
---
 rust/tvm-rt/src/errors.rs               |  15 ++++
 rust/tvm-rt/src/function.rs             |   7 +-
 rust/tvm/src/bin/tyck.rs                |  13 ++--
 rust/tvm/src/ir/diagnostics/codespan.rs | 126 ++++++++++++++++++++++----------
 4 files changed, 117 insertions(+), 44 deletions(-)

diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index c884c56..3de9f3c 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,6 +68,21 @@ pub enum Error {
     Infallible(#[from] std::convert::Infallible),
     #[error("a panic occurred while executing a Rust packed function")]
     Panic,
+    #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+    DiagnosticError(String),
+    #[error("{0}")]
+    Raw(String),
+}
+
+impl Error {
+    /// Classify a raw TVM error message: text of the form `DiagnosticError:<rest>` becomes
+    /// `Error::DiagnosticError(<rest>)`; anything else is wrapped unchanged in `Error::Raw`.
+    pub fn from_raw_tvm(raw: &str) -> Error {
+        match raw.strip_prefix("DiagnosticError:") {
+            Some(rest) => Error::DiagnosticError(rest.into()),
+            None => Error::Raw(raw.into()),
+        }
+    }
 }
 
 impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index c7aebdd..173b60a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -133,7 +133,12 @@ impl Function {
         };
 
         if ret_code != 0 {
-            return Err(Error::CallFailed(crate::get_last_error().into()));
+            let raw_error = crate::get_last_error();
+            let error = match Error::from_raw_tvm(raw_error) {
+                Error::Raw(string) => Error::CallFailed(string),
+                e => e,
+            };
+            return Err(error);
         }
 
         let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index fbab027..13470e7 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -4,7 +4,8 @@ use anyhow::Result;
 use structopt::StructOpt;
 
 use tvm::ir::diagnostics::codespan;
-use tvm::ir::IRModule;
+use tvm::ir::{self, IRModule};
+use tvm::runtime::Error;
 
 #[derive(Debug, StructOpt)]
 #[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
@@ -18,11 +19,11 @@ fn main() -> Result<()> {
     codespan::init().expect("Rust based diagnostics");
     let opt = Opt::from_args();
     println!("{:?}", &opt);
-    let module = IRModule::parse_file(opt.input);
-
-    // for (k, v) in module.functions {
-    //     println!("Function name: {:?}", v);
-    // }
+    let _module = match IRModule::parse_file(opt.input) {
+        Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) },
+        Err(e) => { return Err(e.into()); },
+        Ok(module) => module
+    };
 
     Ok(())
 }
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 80a8784..9fc1ee0 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
 use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
 use codespan_reporting::files::SimpleFiles;
 use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+use codespan_reporting::term::{self, ColorArg};
 
 use crate::ir::source_map::*;
 use super::*;
@@ -13,8 +14,14 @@ enum StartOrEnd {
     End,
 }
 
+struct ByteRange<F> {
+    file_id: F,       // which file the range refers to
+    start_pos: usize, // byte offset of the first byte (inclusive)
+    end_pos: usize,   // byte offset one past the last byte (exclusive)
+}
+
 enum FileSpanToByteRange {
-    AsciiSource,
+    AsciiSource(Vec<usize>),
     Utf8 {
         /// Map character regions which are larger then 1-byte to length.
         lengths: HashMap<isize, isize>,
@@ -27,7 +34,12 @@ impl FileSpanToByteRange {
         let mut last_index = 0;
         let mut is_ascii = true;
         if source.is_ascii() {
-            FileSpanToByteRange::AsciiSource
+            let line_lengths =
+                source
+                    .lines()
+                    .map(|line| line.len())
+                    .collect();
+            FileSpanToByteRange::AsciiSource(line_lengths)
         } else {
             panic!()
         }
@@ -41,6 +53,21 @@ impl FileSpanToByteRange {
         //     last_index = index;
         // }
     }
+
+    fn lookup(&self, span: &Span) -> ByteRange<String> {
+        use FileSpanToByteRange::*;
+        // `lines()` strips terminators, so add one byte per preceding line (assumes `\n` endings).
+        let offset = |lines: &[usize]| lines.iter().map(|len| len + 1).sum::<usize>();
+        let source_name: String = span.source_name.name.as_str().unwrap().into();
+        match self {
+            AsciiSource(ref line_lengths) => {
+                let start_pos = offset(&line_lengths[0..(span.line - 1) as usize]) + span.column as usize;
+                let end_pos = offset(&line_lengths[0..(span.end_line - 1) as usize]) + span.end_column as usize;
+                ByteRange { file_id: source_name, start_pos, end_pos }
+            },
+            _ => unreachable!("`new` only constructs `AsciiSource`; UTF-8 lookup is unimplemented"),
+        }
+    }
 }
 
 struct SpanToByteRange {
@@ -62,41 +89,22 @@ impl SpanToByteRange {
             self.map.insert(source_name, FileSpanToByteRange::new(source));
         }
     }
-}
-
-struct ByteRange<FileId> {
-    file_id: FileId,
-    start_pos: usize,
-    end_pos: usize,
-}
-
-
-pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
-    let severity = match diag.level {
-        DiagnosticLevel::Error => Severity::Error,
-        DiagnosticLevel::Warning => Severity::Warning,
-        DiagnosticLevel::Note => Severity::Note,
-        DiagnosticLevel::Help => Severity::Help,
-        DiagnosticLevel::Bug => Severity::Bug,
-    };
-
-    let file_id = "foo".into(); // diag.span.source_name;
 
-    let message: String = diag.message.as_str().unwrap().into();
-    let inner_message: String = "expected `String`, found `Nat`".into();
-    let diagnostic = CDiagnostic::new(severity)
-        .with_message(message)
-        .with_code("EXXX")
-        .with_labels(vec![
-            Label::primary(file_id, 328..331).with_message(inner_message)
-        ]);
+    /// Map `span` to a byte range in its previously registered source file.
+    pub fn lookup(&self, span: &Span) -> ByteRange<String> {
 
-    diagnostic
+        let source_name: String = span.source_name.name.as_str().expect("span source name is not a valid string").into();
+        self.map.get(&source_name).map(|file| file.lookup(span)).unwrap_or_else(|| {
+            panic!("no source registered for `{}`; call `add_source` first", source_name)
+        })
+    }
 }
 
 struct DiagnosticState {
     files: SimpleFiles<String, String>,
     span_map: SpanToByteRange,
+    // Source name -> id of that file in `files`. TODO: unify with the source-name keys of `span_map`.
+    source_to_id: HashMap<String, usize>,
 }
 
 impl DiagnosticState {
@@ -104,26 +112,70 @@ impl DiagnosticState {
         DiagnosticState {
             files: SimpleFiles::new(),
             span_map: SpanToByteRange::new(),
+            source_to_id: HashMap::new(),
         }
     }
+
+    fn add_source(&mut self, source: Source) {
+        let source_name: String = source.source_name.name.as_str().unwrap().into();
+        // `renderer` calls this once per diagnostic; add each file only once so `files` has no duplicates.
+        if self.source_to_id.contains_key(&source_name) { return; }
+        self.source_to_id.insert(source_name.clone(), self.files.add(source_name, source.source.as_str().unwrap().into()));
+        self.span_map.add_source(source);
+    }
+
+    /// Build a codespan diagnostic for `diag` with one primary label at its span.
+    fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic<usize> {
+        // TVM diagnostic levels correspond one-to-one to codespan severities.
+        let severity = match diag.level {
+            DiagnosticLevel::Error => Severity::Error,
+            DiagnosticLevel::Warning => Severity::Warning,
+            DiagnosticLevel::Note => Severity::Note,
+            DiagnosticLevel::Help => Severity::Help,
+            DiagnosticLevel::Bug => Severity::Bug,
+        };
+
+        // The span's source must already have been registered through `add_source`.
+        let name: String = diag.span.source_name.name.as_str().unwrap().into();
+        let file_id = *self.source_to_id.get(&name).unwrap();
+        let message: String = diag.message.as_str().unwrap().into();
+        let range = self.span_map.lookup(&diag.span);
+
+        // A single primary label over the half-open range `start_pos..end_pos`.
+        let label = Label::primary(file_id, range.start_pos..range.end_pos);
+
+        CDiagnostic::new(severity)
+            .with_message(message)
+            .with_code("EXXX")
+            .with_labels(vec![label])
+    }
 }
 
 fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
     let source_map = diag_ctx.module.source_map.clone();
-        for diagnostic in diag_ctx.diagnostics.clone() {
-            match source_map.source_map.get(&diagnostic.span.source_name) {
-                Err(err) => panic!(),
-                Ok(source) => state.span_map.add_source(source),
+    // Every diagnostic is rendered to stderr with colors always enabled.
+    let writer = StandardStream::stderr(ColorChoice::Always);
+    let config = codespan_reporting::term::Config::default();
+    for diagnostic in diag_ctx.diagnostics.clone() {
+        match source_map.source_map.get(&diagnostic.span.source_name) {
+            // Format the error: a non-literal `panic!` payload is deprecated (an error in Rust 2021).
+            Err(err) => panic!("diagnostic span refers to a source missing from the source map: {}", err),
+            Ok(source) => {
+                // Register the span's source, then render against the registered files.
+                state.add_source(source);
+                let diagnostic = state.to_diagnostic(diagnostic);
+                term::emit(&mut writer.lock(), &config, &state.files, &diagnostic)
+                    .expect("failed to write diagnostic to stderr");
             }
-            println!("Diagnostic: {}", diagnostic.message);
         }
+    }
 }
 
 pub fn init() -> Result<()> {
     let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
     let render_fn = move |diag_ctx: DiagnosticContext| {
-        // let mut guard = diag_state.lock().unwrap();
-        // renderer(&mut *guard, diag_ctx);
+        let mut guard = diag_state.lock().unwrap();
+        renderer(&mut *guard, diag_ctx);
     };
 
     override_renderer(Some(render_fn))?;