You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:40:58 UTC

[incubator-tvm] 02/23: Codespan example almost working

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 77ba30993a7883c142b05e511e8a5a7a91116b2f
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 9 23:59:51 2020 -0700

    Codespan example almost working
---
 rust/tvm-sys/src/packed_func.rs  |   1 +
 rust/tvm/Cargo.toml              |   2 +
 rust/tvm/src/bin/tyck.rs         |  24 ++++++++
 rust/tvm/src/ir/diagnostics.rs   | 121 +++++++++++++++++++++++++++++----------
 rust/tvm/src/ir/relay/visitor.rs |  24 ++++++++
 5 files changed, 143 insertions(+), 29 deletions(-)

diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs
index f7b289c..7b8d529 100644
--- a/rust/tvm-sys/src/packed_func.rs
+++ b/rust/tvm-sys/src/packed_func.rs
@@ -101,6 +101,7 @@ macro_rules! TVMPODValue {
                         TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
                         TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
                         TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)),
                         TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
                         TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
                         TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 55fc179..71a4b93 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -41,6 +41,8 @@ paste = "0.1"
 mashup = "0.1"
 once_cell = "^1.3.1"
 pyo3 = { version = "0.11.1", optional = true }
+codespan-reporting = "0.9.5"
+structopt = { version = "0.3" }
 
 [features]
 default = ["python"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
new file mode 100644
index 0000000..9300412
--- /dev/null
+++ b/rust/tvm/src/bin/tyck.rs
@@ -0,0 +1,24 @@
+use std::path::PathBuf;
+
+use anyhow::Result;
+use structopt::StructOpt;
+
+use tvm::ir::diagnostics::codespan;
+use tvm::ir::IRModule;
+
+#[derive(Debug, StructOpt)]
+#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
+struct Opt {
+    /// Input file
+    #[structopt(parse(from_os_str))]
+    input: PathBuf,
+}
+
+/// Installs the codespan diagnostic renderer, then parses the file named on the command line.
+fn main() -> Result<()> {
+    codespan::init().expect("failed to install the codespan diagnostic renderer");
+    let opt = Opt::from_args();
+    println!("{:?}", &opt);
+    let _module = IRModule::parse_file(opt.input)?;
+    Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index 799a10c..e434d3f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -24,13 +24,31 @@
 
 use tvm_macros::{Object, external};
 use super::module::IRModule;
-use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectRef};
+use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString};
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
 use crate::runtime::function::Result;
 use super::span::Span;
 
 type SourceName = ObjectRef;
 
+// Bindings to TVM global functions, looked up by the names given in `#[name(...)]`.
+external! {
+    #[name("diagnostics.GetRenderer")]
+    fn get_renderer() -> DiagnosticRenderer;
+
+    #[name("diagnostics.DiagnosticRenderer")]
+    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
+
+    #[name("diagnostics.Emit")]
+    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
+
+    #[name("diagnostics.DiagnosticContextRender")]
+    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
+
+    #[name("diagnostics.ClearRenderer")]
+    fn clear_renderer() -> ();
+}
+
 /// The diagnostic level, controls the printing of the message.
 #[repr(C)]
 pub enum DiagnosticLevel {
@@ -171,26 +189,20 @@ pub struct DiagnosticContextNode {
     pub renderer: DiagnosticRenderer,
 }
 
-// Get the the diagnostic renderer.
-external! {
-    #[name("node.ArrayGetItem")]
-    fn get_renderer() -> DiagnosticRenderer;
-
-    #[name("diagnostics.DiagnosticRenderer")]
-    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
-
-    #[name("diagnostics.Emit")]
-    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
-
-    #[name("diagnostics.DiagnosticContextRender")]
-    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
-}
-
 /// A diagnostic context which records active errors
 /// and contains a renderer.
 impl DiagnosticContext {
-    pub fn new(module: IRModule, renderer: DiagnosticRenderer) {
-        todo!()
+    /// Creates an empty diagnostic context for `module` that renders with `render_func`.
+    pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
+    where F: Fn(DiagnosticContext) -> () + 'static
+    {
+        let render = diagnostic_renderer(render_func.to_function()).unwrap();
+        DiagnosticContext(Some(ObjectPtr::new(DiagnosticContextNode {
+            base: Object::base_object::<DiagnosticContextNode>(),
+            module,
+            diagnostics: Array::from_vec(vec![]).unwrap(),
+            renderer: render,
+        })))
     }
 
     pub fn default(module: IRModule) -> DiagnosticContext {
@@ -223,17 +235,68 @@ impl DiagnosticContext {
 //     If the render_func is None it will remove the current custom renderer
 //     and return to default behavior.
 fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> ()
+where F: Fn(DiagnosticContext) -> () + 'static
 {
-    todo!()
-    // fn ()
-    // diagnostic_renderer(func)
-    // if render_func:
 
-    //     def _render_factory():
-    //         return DiagnosticRenderer(render_func)
+    let func = match opt_func {
+        Some(func) => func.to_function(),
+        // No custom renderer requested: drop back to the default behavior.
+        None => return clear_renderer(),
+    };
+
+    // Every invocation of the factory wraps the same `func` in a fresh renderer.
+    let render_factory = move || diagnostic_renderer(func.clone()).unwrap();
+
+    // `true`: allow overriding an existing registration (cf. Python `override=True`).
+    function::register_override(
+        render_factory,
+        "diagnostics.OverrideRenderer",
+        true,
+    )?;
+    Ok(())
+}
+
+/// Diagnostic rendering backed by the `codespan-reporting` crate.
+pub mod codespan {
+    use super::*;
+    use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+    use codespan_reporting::files::SimpleFiles;
+    use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+    /// Converts a TVM `Diagnostic` into a `codespan-reporting` diagnostic.
+    /// Only the severity and message come from `diag`; the label and code are placeholders.
+    pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+        let severity = match diag.level {
+            DiagnosticLevel::Error => Severity::Error,
+            DiagnosticLevel::Warning => Severity::Warning,
+            DiagnosticLevel::Note => Severity::Note,
+            DiagnosticLevel::Help => Severity::Help,
+            DiagnosticLevel::Bug => Severity::Bug,
+        };
+
+        let file_id = "foo".into(); // TODO: use diag.span.source_name
+        let message: String = diag.message.as_str().unwrap().into();
+        let inner_message: String = "expected `String`, found `Nat`".into();
+        CDiagnostic::new(severity)
+            .with_message(message)
+            .with_code("EXXX")
+            .with_labels(vec![
+                Label::primary(file_id, 328..331).with_message(inner_message),
+            ])
+    }
+
+    /// Installs a Rust renderer override (work in progress: it always panics when run).
+    pub fn init() -> Result<()> {
+        let mut files: SimpleFiles<String, String> = SimpleFiles::new();
+        let render_fn = move |diag_ctx: DiagnosticContext| {
+            // let source_map = diag_ctx.module.source_map;
+            // Clone: the array cannot be moved out from behind `DiagnosticContext`.
+            for _diagnostic in diag_ctx.diagnostics.clone() {
+                // TODO: convert with `to_diagnostic` and emit through codespan.
+            }
+            panic!("render_fn");
+        };
 
-    //     register_func("diagnostics.OverrideRenderer", _render_factory, override=True)
-    // else:
-    //     _ffi_api.ClearRenderer()
+        override_renderer(Some(render_fn))
+    }
 }
diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs
new file mode 100644
index 0000000..3166174
--- /dev/null
+++ b/rust/tvm/src/ir/relay/visitor.rs
@@ -0,0 +1,24 @@
+use super::Expr;
+
+/// Tries `$id.downcast_clone::<T>()` for each listed type `T` in order, rebinding
+/// `$id` to the downcast value in the first arm that succeeds; otherwise evaluates `$default`.
+macro_rules! downcast_match {
+    ($id:ident; { $($t:ty => $arm:expr,)* else => $default:expr }) => {
+        $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )*
+        { $default }
+    }
+}
+
+/// A visitor over Relay expressions that may update its own state while walking.
+trait ExprVisitorMut {
+    /// Visits `expr`, dispatching on its concrete node type.
+    ///
+    /// No node types are handled yet, so every call currently panics.
+    fn visit(&mut self, expr: Expr) {
+        downcast_match!(expr; {
+            else => {
+                panic!()
+            }
+        });
+    }
+}