You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:41:05 UTC

[incubator-tvm] 09/23: Improve Rust bindings

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 4cd1bbc79a71873676927c22528f5911e3f4072d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 20:24:43 2020 -0700

    Improve Rust bindings
---
 rust/tvm/src/ir/diagnostics/codespan.rs            | 131 +++++++++++++++++++++
 .../src/ir/{diagnostics.rs => diagnostics/mod.rs}  |  69 +----------
 rust/tvm/src/ir/source_map.rs                      |   3 +-
 rust/tvm/test.rly                                  |   3 +-
 src/ir/diagnostic.cc                               |   1 +
 5 files changed, 138 insertions(+), 69 deletions(-)

diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
new file mode 100644
index 0000000..80a8784
--- /dev/null
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -0,0 +1,131 @@
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+use codespan_reporting::files::SimpleFiles;
+use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+use crate::ir::source_map::*;
+use super::*;
+
+enum StartOrEnd {
+    Start,
+    End,
+}
+
+enum FileSpanToByteRange {
+    AsciiSource,
+    Utf8 {
+        /// Map character regions which are larger then 1-byte to length.
+        lengths: HashMap<isize, isize>,
+        source: String,
+    }
+}
+
+impl FileSpanToByteRange {
+    fn new(source: String) -> FileSpanToByteRange {
+        let mut last_index = 0;
+        let mut is_ascii = true;
+        if source.is_ascii() {
+            FileSpanToByteRange::AsciiSource
+        } else {
+            panic!()
+        }
+
+        // for (index, _) in source.char_indices() {
+        //     if last_index - 1 != last_index {
+        //         is_ascii = false;
+        //     } else {
+        //         panic!();
+        //     }
+        //     last_index = index;
+        // }
+    }
+}
+
+struct SpanToByteRange {
+    map: HashMap<String, FileSpanToByteRange>
+}
+
+impl SpanToByteRange {
+    fn new() -> SpanToByteRange {
+        SpanToByteRange { map: HashMap::new() }
+    }
+
+    pub fn add_source(&mut self, source: Source) {
+        let source_name: String = source.source_name.name.as_str().expect("foo").into();
+
+        if self.map.contains_key(&source_name) {
+            panic!()
+        } else {
+            let source = source.source.as_str().expect("fpp").into();
+            self.map.insert(source_name, FileSpanToByteRange::new(source));
+        }
+    }
+}
+
+struct ByteRange<FileId> {
+    file_id: FileId,
+    start_pos: usize,
+    end_pos: usize,
+}
+
+
+pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+    let severity = match diag.level {
+        DiagnosticLevel::Error => Severity::Error,
+        DiagnosticLevel::Warning => Severity::Warning,
+        DiagnosticLevel::Note => Severity::Note,
+        DiagnosticLevel::Help => Severity::Help,
+        DiagnosticLevel::Bug => Severity::Bug,
+    };
+
+    let file_id = "foo".into(); // diag.span.source_name;
+
+    let message: String = diag.message.as_str().unwrap().into();
+    let inner_message: String = "expected `String`, found `Nat`".into();
+    let diagnostic = CDiagnostic::new(severity)
+        .with_message(message)
+        .with_code("EXXX")
+        .with_labels(vec![
+            Label::primary(file_id, 328..331).with_message(inner_message)
+        ]);
+
+    diagnostic
+}
+
+struct DiagnosticState {
+    files: SimpleFiles<String, String>,
+    span_map: SpanToByteRange,
+}
+
+impl DiagnosticState {
+    fn new() -> DiagnosticState {
+        DiagnosticState {
+            files: SimpleFiles::new(),
+            span_map: SpanToByteRange::new(),
+        }
+    }
+}
+
+fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
+    let source_map = diag_ctx.module.source_map.clone();
+        for diagnostic in diag_ctx.diagnostics.clone() {
+            match source_map.source_map.get(&diagnostic.span.source_name) {
+                Err(err) => panic!(),
+                Ok(source) => state.span_map.add_source(source),
+            }
+            println!("Diagnostic: {}", diagnostic.message);
+        }
+}
+
+pub fn init() -> Result<()> {
+    let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
+    let render_fn = move |diag_ctx: DiagnosticContext| {
+        // let mut guard = diag_state.lock().unwrap();
+        // renderer(&mut *guard, diag_ctx);
+    };
+
+    override_renderer(Some(render_fn))?;
+    Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics/mod.rs
similarity index 76%
rename from rust/tvm/src/ir/diagnostics.rs
rename to rust/tvm/src/ir/diagnostics/mod.rs
index 4975a45..fce214a 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics/mod.rs
@@ -18,7 +18,7 @@
  */
 
 use super::module::IRModule;
-use super::span::Span;
+use super::span::*;
 use crate::runtime::function::Result;
 use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
 use crate::runtime::{
@@ -32,7 +32,7 @@ use crate::runtime::{
 /// and the DiagnosticRenderer.
 use tvm_macros::{external, Object};
 
-type SourceName = ObjectRef;
+pub mod codespan;
 
 // Get the the diagnostic renderer.
 external! {
@@ -229,68 +229,3 @@ where
         }
     }
 }
-
-pub mod codespan {
-    use super::*;
-
-    use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
-    use codespan_reporting::files::SimpleFiles;
-    use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
-
-    enum StartOrEnd {
-        Start,
-        End,
-    }
-
-    // struct SpanToBytes {
-    //     inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
-    // }
-
-    struct ByteRange<FileId> {
-        file_id: FileId,
-        start_pos: usize,
-        end_pos: usize,
-    }
-
-    // impl SpanToBytes {
-    //     fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange<FileId> {
-
-    //     }
-    // }
-
-    pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
-        let severity = match diag.level {
-            DiagnosticLevel::Error => Severity::Error,
-            DiagnosticLevel::Warning => Severity::Warning,
-            DiagnosticLevel::Note => Severity::Note,
-            DiagnosticLevel::Help => Severity::Help,
-            DiagnosticLevel::Bug => Severity::Bug,
-        };
-
-        let file_id = "foo".into(); // diag.span.source_name;
-
-        let message: String = diag.message.as_str().unwrap().into();
-        let inner_message: String = "expected `String`, found `Nat`".into();
-        let diagnostic = CDiagnostic::new(severity)
-            .with_message(message)
-            .with_code("EXXX")
-            .with_labels(vec![
-                Label::primary(file_id, 328..331).with_message(inner_message)
-            ]);
-
-        diagnostic
-    }
-
-    pub fn init() -> Result<()> {
-        let mut files: SimpleFiles<String, String> = SimpleFiles::new();
-        let render_fn = move |diag_ctx: DiagnosticContext| {
-            let source_map = diag_ctx.module.source_map.clone();
-            for diagnostic in diag_ctx.diagnostics.clone() {
-                println!("Diagnostic: {}", diagnostic.message);
-            }
-        };
-
-        override_renderer(Some(render_fn))?;
-        Ok(())
-    }
-}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index 56c0830..ebe7e46 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -19,6 +19,7 @@
 
 use crate::runtime::map::Map;
 use crate::runtime::object::Object;
+use crate::runtime::string::{String as TString};
 
 use super::span::{SourceName, Span};
 
@@ -37,7 +38,7 @@ pub struct SourceNode {
     pub source_name: SourceName,
 
     /// The raw source. */
-    source: String,
+    pub source: TString,
 
    // A mapping of line breaks into the raw source.
    // std::vector<std::pair<int, int>> line_map;
diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly
index d8b7c69..e9407b0 100644
--- a/rust/tvm/test.rly
+++ b/rust/tvm/test.rly
@@ -1,2 +1,3 @@
 #[version = "0.0.5"]
-fn @main(%x: int32) -> float32 { %x }
+
+def @main(%x: int32) -> float32 { %x }
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index f9299e3..a4fee1e 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -113,6 +113,7 @@ TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender")
     });
 
 DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) {
+  CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function";
   auto n = make_object<DiagnosticContextNode>();
   n->module = module;
   n->renderer = renderer;