You are viewing a plain-text version of this content. The canonical link to it ("here" in the original page) was not preserved in this copy.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:40:56 UTC

[incubator-tvm] branch cargo-build created (now c933926)

This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a change to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git.


      at c933926  Post-rebase

This branch includes the following new commits:

     new 1097cbf  Add initial boilerplate for Rust diagnostic interface.
     new 77ba309  Codespan example almost working
     new cb37856  WIP
     new 131e40a  Hacking on Rust inside of TVM
     new b2b59c2  Borrow code from Egg
     new db24553  Update CMake and delete old API
     new e0f9801  Fix Linux build
     new 20c6a28  Clean up exporting to show off new diagnostics
     new 4cd1bbc  Improve Rust bindings
     new eeb86c6  Fix calling
     new 4261461  Fix
     new 6e13467  Rust Diagnostics work
     new 0cabfdc  Remove type checker
     new 04a9779  Format and cleanup
     new 6828374  Fix the extension code
     new 1874350  More cleanup
     new 8e295b7  Fix some CR
     new 49246bf  WIP
     new a9ee3cb  WIP
     new b8dcc35  WIP
     new 478326c  Debug segfault
     new e8bb83d  WIP
     new c933926  Post-rebase

The 23 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[incubator-tvm] 11/23: Fix

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 4261461974f2fb548d6b48e19fab07d34dd03a51
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Fri Oct 16 02:11:53 2020 -0700

    Fix
---
 rust/tvm/src/ir/module.rs | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 5156e74..11d6c49 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -16,6 +16,11 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+use std::io::Result as IOResult;
+use std::path::Path;
+
+use thiserror::Error;
+use tvm_macros::Object;
 
 use crate::runtime::array::Array;
 use crate::runtime::function::Result;
@@ -27,15 +32,19 @@ use super::expr::GlobalVar;
 use super::function::BaseFunc;
 use super::source_map::SourceMap;
 
-use std::io::Result as IOResult;
-use std::path::Path;
-
-use tvm_macros::Object;
 
 // TODO(@jroesch): define type
 type TypeData = ObjectRef;
 type GlobalTypeVar = ObjectRef;
 
+#[derive(Error, Debug)]
+pub enum Error {
+    #[error("{0}")]
+    IO(#[from] std::io::Error),
+    #[error("{0}")]
+    TVM(#[from] crate::runtime::Error),
+}
+
 #[repr(C)]
 #[derive(Object)]
 #[ref_name = "IRModule"]
@@ -116,19 +125,19 @@ external! {
 //     });
 
 impl IRModule {
-    pub fn parse<N, S>(file_name: N, source: S) -> IRModule
+    pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
     where
         N: Into<TVMString>,
         S: Into<TVMString>,
     {
-        parse_module(file_name.into(), source.into()).expect("failed to call parser")
+        parse_module(file_name.into(), source.into())
     }
 
-    pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> IOResult<IRModule> {
+    pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> std::result::Result<IRModule, Error> {
         let file_path = file_path.as_ref();
         let file_path_as_str = file_path.to_str().unwrap().to_string();
         let source = std::fs::read_to_string(file_path)?;
-        let module = IRModule::parse(file_path_as_str, source);
+        let module = IRModule::parse(file_path_as_str, source)?;
         Ok(module)
     }
 


[incubator-tvm] 23/23: Post-rebase

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit c9339267bf7ae641a7f9c3678c84cf1ff600959e
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sat Oct 31 15:39:05 2020 -0700

    Post-rebase
---
 rust/tvm/src/ir/module.rs | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index db32ce2..869c5e6 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -34,11 +34,8 @@ use super::function::BaseFunc;
 use super::source_map::SourceMap;
 use super::{ty::GlobalTypeVar, relay};
 
-use tvm_macros::Object;
-
 // TODO(@jroesch): define type
 type TypeData = ObjectRef;
-type GlobalTypeVar = ObjectRef;
 
 #[derive(Error, Debug)]
 pub enum Error {


[incubator-tvm] 15/23: Fix the extension code

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 68283745b9c283db611579205a3b925eb09e9faa
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:32:13 2020 -0700

    Fix the extension code
---
 src/contrib/rust_extension.cc | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/src/contrib/rust_extension.cc b/src/contrib/rust_extension.cc
new file mode 100644
index 0000000..075cbc6
--- /dev/null
+++ b/src/contrib/rust_extension.cc
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/rust_extension.cc
+ * \brief Expose Rust extensions initialization.
+ */
+#ifdef RUST_COMPILER_EXT
+
+extern "C" {
+  int compiler_ext_initialize();
+  static int test = compiler_ext_initialize();
+}
+
+#endif


[incubator-tvm] 13/23: Remove type checker

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 0cabfdcee309b12d8907fc3abe2ba8e8718ecac6
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 15:56:10 2020 -0700

    Remove type checker
---
 tests/python/relay/test_type_infer2.py | 419 ---------------------------------
 1 file changed, 419 deletions(-)

diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py
deleted file mode 100644
index 6758d96..0000000
--- a/tests/python/relay/test_type_infer2.py
+++ /dev/null
@@ -1,419 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Test that type checker correcly computes types
-   for expressions.
-"""
-import pytest
-import tvm
-
-from tvm import IRModule, te, relay, parser
-from tvm.relay import op, transform, analysis
-from tvm.relay import Any
-
-
-def infer_mod(mod, annotate_spans=True):
-    if annotate_spans:
-        mod = relay.transform.AnnotateSpans()(mod)
-
-    mod = transform.InferType()(mod)
-    return mod
-
-
-def infer_expr(expr, annotate_spans=True):
-    mod = IRModule.from_expr(expr)
-    mod = infer_mod(mod, annotate_spans)
-    mod = transform.InferType()(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def assert_has_type(expr, typ, mod=None):
-    if not mod:
-        mod = tvm.IRModule({})
-
-    mod["main"] = expr
-    mod = infer_mod(mod)
-    checked_expr = mod["main"]
-    checked_type = checked_expr.checked_type
-    if checked_type != typ:
-        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
-
-
-def initialize_box_adt(mod):
-    # initializes simple ADT for tests
-    box = relay.GlobalTypeVar("box")
-    tv = relay.TypeVar("tv")
-    constructor = relay.Constructor("constructor", [tv], box)
-    data = relay.TypeData(box, [tv], [constructor])
-    mod[box] = data
-    return box, constructor
-
-
-def test_monomorphic_let():
-    "Program: let %x = 1; %x"
-    # TODO(@jroesch): this seems whack.
-    sb = relay.ScopeBuilder()
-    x = relay.var("x", dtype="float64", shape=())
-    x = sb.let("x", relay.const(1.0, "float64"))
-    sb.ret(x)
-    xchecked = infer_expr(sb.get())
-    assert xchecked.checked_type == relay.scalar_type("float64")
-
-
-def test_single_op():
-    "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
-    x = relay.var("x", shape=[])
-    func = relay.Function([x], op.log(x))
-    ttype = relay.TensorType([], dtype="float32")
-    assert_has_type(func, relay.FuncType([ttype], ttype))
-
-
-def test_add_broadcast_op():
-    """
-    Program:
-        fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
-            -> Tensor[(5, 10, 4), float32] {
-            %x + %y
-        }
-    """
-    x = relay.var("x", shape=(10, 4))
-    y = relay.var("y", shape=(5, 10, 1))
-    z = x + y
-    func = relay.Function([x, y], z)
-    t1 = relay.TensorType((10, 4), "float32")
-    t2 = relay.TensorType((5, 10, 1), "float32")
-    t3 = relay.TensorType((5, 10, 4), "float32")
-    expected_ty = relay.FuncType([t1, t2], t3)
-    assert_has_type(func, expected_ty)
-
-
-def test_dual_op():
-    """Program:
-    fn (%x : Tensor[(10, 10), float32]) {
-      let %t1 = log(x);
-      let %t2 = add(%t1, %x);
-      %t1
-    }
-    """
-    tp = relay.TensorType((10, 10), "float32")
-    x = relay.var("x", tp)
-    sb = relay.ScopeBuilder()
-    t1 = sb.let("t1", relay.log(x))
-    t2 = sb.let("t2", relay.add(t1, x))
-    sb.ret(t2)
-    f = relay.Function([x], sb.get())
-    fchecked = infer_expr(f)
-    assert fchecked.checked_type == relay.FuncType([tp], tp)
-
-
-def test_decl():
-    """Program:
-    def @f(%x : Tensor[(10, 10), float32]) {
-        log(%x)
-    }
-    """
-    tp = relay.TensorType((10, 10))
-    x = relay.var("x", tp)
-    f = relay.Function([x], relay.log(x))
-    fchecked = infer_expr(f)
-    assert fchecked.checked_type == relay.FuncType([tp], tp)
-
-
-def test_recursion():
-    """
-    Program:
-       def @f(%n: int32, %data: float32) -> float32 {
-          if (%n == 0) {
-              %data
-          } else {
-              @f(%n - 1, log(%data))
-          }
-       }
-    """
-    sb = relay.ScopeBuilder()
-    f = relay.GlobalVar("f")
-    ti32 = relay.scalar_type("int32")
-    tf32 = relay.scalar_type("float32")
-    n = relay.var("n", ti32)
-    data = relay.var("data", tf32)
-
-    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
-        sb.ret(data)
-    with sb.else_scope():
-        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
-    mod = tvm.IRModule()
-    mod[f] = relay.Function([n, data], sb.get())
-    mod = infer_mod(mod)
-    assert "@f(%1, %2)" in mod.astext()
-    assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32)
-
-
-def test_incomplete_call():
-    tt = relay.scalar_type("int32")
-    x = relay.var("x", tt)
-    f = relay.var("f")
-    func = relay.Function([x, f], relay.Call(f, [x]), tt)
-
-    ft = infer_expr(func)
-    f_type = relay.FuncType([tt], tt)
-    assert ft.checked_type == relay.FuncType([tt, f_type], tt)
-
-
-def test_higher_order_argument():
-    a = relay.TypeVar("a")
-    x = relay.Var("x", a)
-    id_func = relay.Function([x], x, a, [a])
-
-    b = relay.TypeVar("b")
-    f = relay.Var("f", relay.FuncType([b], b))
-    y = relay.Var("y", b)
-    ho_func = relay.Function([f, y], f(y), b, [b])
-
-    # id func should be an acceptable argument to the higher-order
-    # function even though id_func takes a type parameter
-    ho_call = ho_func(id_func, relay.const(0, "int32"))
-
-    hc = infer_expr(ho_call)
-    expected = relay.scalar_type("int32")
-    assert hc.checked_type == expected
-
-
-def test_higher_order_return():
-    a = relay.TypeVar("a")
-    x = relay.Var("x", a)
-    id_func = relay.Function([x], x, a, [a])
-
-    b = relay.TypeVar("b")
-    nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
-
-    ft = infer_expr(nested_id)
-    assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
-
-
-def test_higher_order_nested():
-    a = relay.TypeVar("a")
-    x = relay.Var("x", a)
-    id_func = relay.Function([x], x, a, [a])
-
-    choice_t = relay.FuncType([], relay.scalar_type("bool"))
-    f = relay.Var("f", choice_t)
-
-    b = relay.TypeVar("b")
-    z = relay.Var("z")
-    top = relay.Function(
-        [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
-    )
-
-    expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
-    ft = infer_expr(top)
-    assert ft.checked_type == expected
-
-
-def test_tuple():
-    tp = relay.TensorType((10,))
-    x = relay.var("x", tp)
-    res = relay.Tuple([x, x])
-    assert infer_expr(res).checked_type == relay.TupleType([tp, tp])
-
-
-def test_ref():
-    x = relay.var("x", "float32")
-    y = relay.var("y", "float32")
-    r = relay.RefCreate(x)
-    st = relay.scalar_type("float32")
-    assert infer_expr(r).checked_type == relay.RefType(st)
-    g = relay.RefRead(r)
-    assert infer_expr(g).checked_type == st
-    w = relay.RefWrite(r, y)
-    assert infer_expr(w).checked_type == relay.TupleType([])
-
-
-def test_free_expr():
-    x = relay.var("x", "float32")
-    y = relay.add(x, x)
-    yy = infer_expr(y, annotate_spans=False)
-    assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
-    assert yy.checked_type == relay.scalar_type("float32")
-    assert x.vid.same_as(yy.args[0].vid)
-
-
-def test_type_args():
-    x = relay.var("x", shape=(10, 10))
-    y = relay.var("y", shape=(1, 10))
-    z = relay.add(x, y)
-    ty_z = infer_expr(z)
-    ty_args = ty_z.type_args
-    assert len(ty_args) == 2
-    assert ty_args[0].dtype == "float32"
-    assert ty_args[1].dtype == "float32"
-    sh1 = ty_args[0].shape
-    sh2 = ty_args[1].shape
-    assert sh1[0].value == 10
-    assert sh1[1].value == 10
-    assert sh2[0].value == 1
-    assert sh2[1].value == 10
-
-
-def test_global_var_recursion():
-    mod = tvm.IRModule({})
-    gv = relay.GlobalVar("main")
-    x = relay.var("x", shape=[])
-    tt = relay.scalar_type("float32")
-
-    func = relay.Function([x], relay.Call(gv, [x]), tt)
-    mod[gv] = func
-    mod = infer_mod(mod)
-    func_ty = mod["main"].checked_type
-
-    assert func_ty == relay.FuncType([tt], tt)
-
-
-def test_equal():
-    i = relay.var("i", shape=[], dtype="int32")
-    eq = op.equal(i, relay.const(0, dtype="int32"))
-    func = relay.Function([i], eq)
-    ft = infer_expr(func)
-    expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool"))
-    assert ft.checked_type == expected
-
-    assert ft.checked_type == relay.FuncType(
-        [relay.scalar_type("int32")], relay.scalar_type("bool")
-    )
-
-
-def test_constructor_type():
-    mod = tvm.IRModule()
-    box, constructor = initialize_box_adt(mod)
-
-    a = relay.TypeVar("a")
-    x = relay.Var("x", a)
-    func = relay.Function([x], constructor(x), box(a), [a])
-    mod["main"] = func
-    mod = infer_mod(mod)
-    func_ty = mod["main"].checked_type
-    box = mod.get_global_type_var("box")
-    expected = relay.FuncType([a], box(a), [a])
-    assert func_ty == expected
-
-
-def test_constructor_call():
-    mod = tvm.IRModule()
-    box, constructor = initialize_box_adt(mod)
-
-    box_unit = constructor(relay.Tuple([]))
-    box_constant = constructor(relay.const(0, "float32"))
-
-    func = relay.Function([], relay.Tuple([box_unit, box_constant]))
-    mod["main"] = func
-    mod = infer_mod(mod)
-    ret_type = mod["main"].checked_type.ret_type.fields
-    # NB(@jroesch): when we annotate spans the ast fragments before
-    # annotation the previous fragments will no longer be directly equal.
-    box = mod.get_global_type_var("box")
-    expected1 = box(relay.TupleType([]))
-    expected2 = box(relay.TensorType((), "float32"))
-    assert ret_type[0] == expected1
-    assert ret_type[1] == expected2
-
-
-def test_adt_match():
-    mod = tvm.IRModule()
-    box, constructor = initialize_box_adt(mod)
-
-    v = relay.Var("v", relay.TensorType((), "float32"))
-    match = relay.Match(
-        constructor(relay.const(0, "float32")),
-        [
-            relay.Clause(
-                relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
-            ),
-            # redundant but shouldn't matter to typechecking
-            relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
-        ],
-    )
-
-    func = relay.Function([], match)
-    mod["main"] = func
-    mod = infer_mod(mod)
-    actual = mod["main"].checked_type.ret_type
-    assert actual == relay.TupleType([])
-
-
-def test_adt_match_type_annotations():
-    mod = tvm.IRModule()
-    box, constructor = initialize_box_adt(mod)
-
-    # the only type annotation is inside the match pattern var
-    # but that should be enough info
-    tt = relay.TensorType((2, 2), "float32")
-    x = relay.Var("x")
-    mv = relay.Var("mv", tt)
-    match = relay.Match(
-        constructor(x),
-        [
-            relay.Clause(
-                relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
-            )
-        ],
-    )
-
-    mod["main"] = relay.Function([x], match)
-    mod = infer_mod(mod)
-    ft = mod["main"].checked_type
-    assert ft == relay.FuncType([tt], relay.TupleType([]))
-
-
-def test_let_polymorphism():
-    id = relay.Var("id")
-    xt = relay.TypeVar("xt")
-    x = relay.Var("x", xt)
-    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
-    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
-    body = infer_expr(body)
-    int32 = relay.TensorType((), "int32")
-    tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
-
-
-def test_if():
-    choice_t = relay.FuncType([], relay.scalar_type("bool"))
-    f = relay.Var("f", choice_t)
-    true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32"))
-    false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32"))
-    top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
-    ft = infer_expr(top)
-    tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
-
-
-def test_type_arg_infer():
-    code = """
-#[version = "0.0.5"]
-def @id[A](%x: A) -> A {
-  %x
-}
-def @main(%f: float32) -> float32 {
-  @id(%f)
-}
-"""
-    mod = tvm.parser.fromtext(code)
-    mod = transform.InferType()(mod)
-    tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
-
-
-if __name__ == "__main__":
-    import sys
-
-    pytest.main(sys.argv)


[incubator-tvm] 21/23: Debug segfault

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 478326c8d7c2ccce10b77aa28e4b22bafbfdf877
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sun Oct 25 17:26:47 2020 -0700

    Debug segfault
---
 python/tvm/__init__.py                         |  2 ++
 python/tvm/relay/__init__.py                   |  3 +-
 python/tvm/relay/analysis/__init__.py          |  2 +-
 python/tvm/relay/analysis/analysis.py          |  6 ++--
 python/tvm/relay/analysis/annotated_regions.py |  2 +-
 python/tvm/relay/analysis/call_graph.py        |  4 +--
 python/tvm/relay/analysis/sparse_dense.py      | 15 ++++-----
 python/tvm/relay/build_module.py               |  6 +++-
 python/tvm/relay/op/op.py                      | 40 ++++++++++++------------
 python/tvm/relay/transform/__init__.py         |  2 +-
 python/tvm/relay/transform/memory_alloc.py     |  7 ++---
 python/tvm/relay/transform/transform.py        |  5 +--
 python/tvm/topi/cuda/__init__.py               |  2 --
 python/tvm/topi/cuda/sparse.py                 |  3 +-
 rust/tvm/Cargo.toml                            |  2 +-
 rust/tvm/src/ir/relay/mod.rs                   |  1 -
 rust/tvm/src/python.rs                         | 43 +++++++++++++++++++++++---
 17 files changed, 91 insertions(+), 54 deletions(-)

diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 569e8f0..60f81f4 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -67,6 +67,8 @@ from . import support
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
+def cleanup():
+    _ffi.base._LIB = None
 
 def tvm_wrap_excepthook(exception_hook):
     """Wrap given excepthook with TVM additional work."""
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index cd96ecc..7e6ed4f 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -60,8 +60,7 @@ from . import qnn
 from .scope_builder import ScopeBuilder
 
 # Load Memory Passes
-from .transform import memory_alloc
-from .transform import memory_plan
+from .transform import memory_alloc, memory_plan
 
 # Required to traverse large programs
 setrecursionlimit(10000)
diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index b4ea7f3..4ea4de7 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -26,7 +26,7 @@ from .annotated_regions import AnnotatedRegionSet
 from . import call_graph
 from .call_graph import CallGraph
 
-# Feature
+# # Feature
 from . import feature
 from . import sparse_dense
 
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 7e49461..48e9ce0 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -20,9 +20,9 @@
 This file contains the set of passes for Relay, which exposes an interface for
 configuring the passes and scripting them in Python.
 """
-from tvm.ir import IRModule
-from tvm.relay import transform, build_module
-from tvm.runtime.ndarray import cpu
+from ...ir import IRModule
+from ...relay import transform, build_module
+from ...runtime.ndarray import cpu
 
 from . import _ffi_api
 from .feature import Feature
diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py
index 437b97b..a18ccb9 100644
--- a/python/tvm/relay/analysis/annotated_regions.py
+++ b/python/tvm/relay/analysis/annotated_regions.py
@@ -17,7 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Regions used in Relay."""
 
-from tvm.runtime import Object
+from ...runtime import Object
 from . import _ffi_api
 
 
diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py
index 966659a..fd9704d 100644
--- a/python/tvm/relay/analysis/call_graph.py
+++ b/python/tvm/relay/analysis/call_graph.py
@@ -17,8 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Call graph used in Relay."""
 
-from tvm.ir import IRModule
-from tvm.runtime import Object
+from ...ir import IRModule
+from ...runtime import Object
 from ..expr import GlobalVar
 from . import _ffi_api
 
diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py
index d521748..51fab34 100644
--- a/python/tvm/relay/analysis/sparse_dense.py
+++ b/python/tvm/relay/analysis/sparse_dense.py
@@ -22,8 +22,8 @@ to block sparse model
 """
 from collections import namedtuple
 import numpy as np
-import scipy.sparse as sp
-import tvm
+
+from ... import nd, runtime
 from . import _ffi_api
 
 
@@ -73,6 +73,7 @@ def process_params(expr, params, block_size, sparsity_threshold):
     ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
         return names of qualified dense weight and the shape in BSR format
     """
+    import scipy.sparse as sp
     memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
     weight_names = _search_dense_op_weight(expr)
     for name in weight_names:
@@ -89,11 +90,11 @@ def process_params(expr, params, block_size, sparsity_threshold):
                 + list(sparse_weight.indices.shape)
                 + list(sparse_weight.indptr.shape)
             )
-            params[name + ".data"] = tvm.nd.array(sparse_weight.data)
-            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
-            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+            params[name + ".data"] = nd.array(sparse_weight.data)
+            params[name + ".indices"] = nd.array(sparse_weight.indices)
+            params[name + ".indptr"] = nd.array(sparse_weight.indptr)
     ret = SparseAnalysisResult(
-        weight_name=tvm.runtime.convert(memo.weight_name),
-        weight_shape=tvm.runtime.convert(memo.weight_shape),
+        weight_name=runtime.convert(memo.weight_name),
+        weight_shape=runtime.convert(memo.weight_shape),
     )
     return ret
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 35bd8e6..e93d654 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -24,7 +24,7 @@ import numpy as np
 from tvm.ir import IRModule
 
 from tvm.tir import expr as tvm_expr
-from .. import nd as _nd, autotvm
+from .. import nd as _nd, autotvm, register_func
 from ..target import Target
 from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
@@ -186,6 +186,10 @@ class BuildModule(object):
             ret[key] = value.data
         return ret
 
+@register_func("tvm.relay.build")
+def build1(mod, target=None, target_host=None, params=None, mod_name="default"):
+    import pdb; pdb.set_trace()
+    return build(mod, target, target_host, params, mod_name)
 
 def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     """Helper function that builds a Relay function to run on TVM graph
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 755659a..f780dc9 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -16,10 +16,8 @@
 # under the License.
 # pylint: disable=unused-argument,invalid-name
 """The base node types for the Relay language."""
-import tvm._ffi
-import tvm.ir
-from tvm.driver import lower, build
-
+from ... import _ffi, ir
+from ...driver import lower, build
 from ...target import get_native_generic_func, GenericFunc
 from ...runtime import Object
 from . import _make
@@ -38,7 +36,7 @@ def get(op_name):
     op : Op
         The op of the corresponding name
     """
-    return tvm.ir.Op.get(op_name)
+    return ir.Op.get(op_name)
 
 
 class OpPattern(object):
@@ -65,7 +63,7 @@ class OpPattern(object):
     OPAQUE = 8
 
 
-@tvm._ffi.register_object("relay.OpImplementation")
+@_ffi.register_object("relay.OpImplementation")
 class OpImplementation(Object):
     """Operator implementation"""
 
@@ -112,12 +110,12 @@ class OpImplementation(Object):
         return _OpImplementationSchedule(self, attrs, outs, target)
 
 
-@tvm._ffi.register_object("relay.OpSpecialization")
+@_ffi.register_object("relay.OpSpecialization")
 class OpSpecialization(Object):
     """Operator specialization"""
 
 
-@tvm._ffi.register_object("relay.OpStrategy")
+@_ffi.register_object("relay.OpStrategy")
 class OpStrategy(Object):
     """Operator strategy"""
 
@@ -184,7 +182,7 @@ def register_compute(op_name, compute=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMCompute", compute, level)
+    return ir.register_op_attr(op_name, "FTVMCompute", compute, level)
 
 
 def register_strategy(op_name, fstrategy=None, level=10):
@@ -205,7 +203,7 @@ def register_strategy(op_name, fstrategy=None, level=10):
     if not isinstance(fstrategy, GenericFunc):
         assert hasattr(fstrategy, "generic_func_node")
         fstrategy = fstrategy.generic_func_node
-    return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
+    return ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
 
 
 def register_schedule(op_name, schedule, level=10):
@@ -286,7 +284,7 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
+    return ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
 
 
 def register_convert_op_layout(op_name, convert_layout=None, level=10):
@@ -303,7 +301,7 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
+    return ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
 
 
 def register_legalize(op_name, legal_op=None, level=10):
@@ -320,7 +318,7 @@ def register_legalize(op_name, legal_op=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
+    return ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
 
 
 def register_pattern(op_name, pattern, level=10):
@@ -337,7 +335,7 @@ def register_pattern(op_name, pattern, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "TOpPattern", pattern, level)
+    return ir.register_op_attr(op_name, "TOpPattern", pattern, level)
 
 
 def register_gradient(op_name, fgradient=None, level=10):
@@ -354,7 +352,7 @@ def register_gradient(op_name, fgradient=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
+    return ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
 
 
 def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
@@ -376,7 +374,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
         The priority level
     """
     get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
-    return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
+    return ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
 
 
 def register_external_compiler(op_name, fexternal=None, level=10):
@@ -395,15 +393,15 @@ def register_external_compiler(op_name, fexternal=None, level=10):
     level : int
         The priority level
     """
-    return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
+    return ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
 
 
-@tvm._ffi.register_func("relay.op.compiler._lower")
+@_ffi.register_func("relay.op.compiler._lower")
 def _lower(name, schedule, inputs, outputs):
     return lower(schedule, list(inputs) + list(outputs), name=name)
 
 
-@tvm._ffi.register_func("relay.op.compiler._build")
+@_ffi.register_func("relay.op.compiler._build")
 def _build(lowered_funcs):
     return build(lowered_funcs, target="llvm")
 
@@ -420,7 +418,7 @@ def debug(expr, debug_func=None):
 
     if debug_func:
         name = "debugger_func{}".format(__DEBUG_COUNTER__)
-        tvm._ffi.register_func(name, debug_func)
+        _ffi.register_func(name, debug_func)
         __DEBUG_COUNTER__ += 1
     else:
         name = ""
@@ -428,4 +426,4 @@ def debug(expr, debug_func=None):
     return _make.debug(expr, name)
 
 
-tvm._ffi._init_api("relay.op", __name__)
+_ffi._init_api("relay.op", __name__)
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 1d0ea17..9684e42 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,4 +19,4 @@
 # transformation passes
 from .transform import *
 from .recast import recast
-from . import memory_alloc
+# from . import memory_alloc
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 66528c8..593a411 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -20,14 +20,13 @@ A pass for manifesting explicit memory allocations.
 """
 import numpy as np
 
-from tvm.ir.transform import PassContext, module_pass
-from tvm.relay.transform import InferType
-from tvm import nd, container
+from ... import DataType, register_func, nd, container, cpu
+from ...ir.transform import PassContext, module_pass
+from . import InferType
 from ..function import Function
 from ..expr_functor import ExprVisitor, ExprMutator
 from ..scope_builder import ScopeBuilder
 from .. import op
-from ... import DataType, register_func
 from .. import ty, expr
 from ..backend import compile_engine
 from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index f0f55f6..af1e718 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -23,11 +23,12 @@ import inspect
 import functools
 import warnings
 
+from ...ir import transform as tvm_transform
 import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
 
-from tvm import relay
+# from tvm import relay
 from . import _ffi_api
 
 
@@ -82,7 +83,7 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None
 
 
 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(tvm.ir.transform.Pass):
+class FunctionPass(tvm_transform.Pass):
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f..47badb5 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -17,8 +17,6 @@
 
 # pylint: disable=redefined-builtin, wildcard-import
 """CUDA specific declaration and schedules."""
-from __future__ import absolute_import as _abs
-
 from .conv1d import *
 from .conv1d_transpose_ncw import *
 from .conv2d import *
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index d125423..c2b99af 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -17,7 +17,6 @@
 
 """Sparse operators"""
 import numpy as np
-import scipy.sparse as sp
 
 import tvm
 from tvm import relay, te
@@ -326,6 +325,7 @@ def schedule_sparse_dense_padded(outs):
 
 def pad_sparse_matrix(matrix, blocksize):
     """Pad rows of sparse matrix matrix so that they are a multiple of blocksize."""
+    import scipy.sparse as sp
     assert isinstance(matrix, sp.bsr_matrix)
     new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype)
     bsr = matrix.blocksize[0]
@@ -362,6 +362,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     sparse_dense implementation for one that operates on a padded matrix. We
     also padd the matrix.
     """
+    import scipy.sparse as sp
     if (
         isinstance(inputs[1], relay.Constant)
         and isinstance(inputs[2], relay.Constant)
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 153a195..c1d8aa8 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -50,7 +50,7 @@ tvm-macros = { version = "*", path = "../tvm-macros/" }
 paste = "0.1"
 mashup = "0.1"
 once_cell = "^1.3.1"
-pyo3 = { version = "0.11.1", optional = true }
+pyo3 = { version = "^0.12", optional = true }
 codespan-reporting = "0.9.5"
 structopt = { version = "0.3" }
 
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 90b7a6a..a6ea684 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -24,7 +24,6 @@ use super::expr::BaseExprNode;
 use super::function::BaseFuncNode;
 use super::span::Span;
 use super::ty::{Type, TypeNode};
-use super::span::Span;
 
 use tvm_macros::Object;
 use tvm_rt::NDArray;
diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs
index 89558af..2b2d374 100644
--- a/rust/tvm/src/python.rs
+++ b/rust/tvm/src/python.rs
@@ -18,6 +18,21 @@
  */
 
 use pyo3::prelude::*;
+use once_cell::sync::OnceCell;
+
+// static TVM_PYTHON: OnceCell<Py<PyModule>> = OnceCell::new();
+
+// fn initialize() -> Py<PyModule> {
+//     TVM_PYTHON.get_or_init(|| {
+//         let gil = Python::acquire_gil();
+//         let py = gil.python();
+//         PyModule::new(py, "__tvm__rust__module__").map_err(|e| {
+//             // We can't display Python exceptions via std::fmt::Display,
+//             // so print the error here manually.
+//             e.print_and_set_sys_last_vars(py);
+//         }).expect("failed to initialize the Python interface").into()
+//     }).clone()
+// }
 
 /// Load the Python interpreter into the address space.
 ///
@@ -29,6 +44,8 @@ use pyo3::prelude::*;
 pub fn load() -> Result<String, ()> {
     let gil = Python::acquire_gil();
     let py = gil.python();
+    // let main_mod = initialize();
+    //let main_mod = main_mod.as_ref(py);
     load_python_tvm_(py).map_err(|e| {
         // We can't display Python exceptions via std::fmt::Display,
         // so print the error here manually.
@@ -36,12 +53,30 @@ pub fn load() -> Result<String, ()> {
     })
 }
 
-// const TVMC_CODE: &'static str = include_str!("tvmc.py");
+fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
+    let imported_mod = py.import(to_import)?;
+    PyModule::from_code(py,
+        r#"
+import tvm
+from tvm import relay
+tvm.cleanup()
+"#, "blah", "my_mod")?;
+    // py_mod.add(to_import, imported_mod)?;
+    Ok(imported_mod)
+}
+
+pub fn import(mod_to_import: &str) -> PyResult<()> {
+    let gil = Python::acquire_gil();
+    let py = gil.python();
+    // let main_mod = initialize();
+    // let main_mod = main_mod.as_ref(py);
+    import_python(py, mod_to_import)?;
+    Ok(())
+}
 
 fn load_python_tvm_(py: Python) -> PyResult<String> {
-    let sys = py.import("tvm")?;
-    let version: String = sys.get("__version__")?.extract()?;
-    // py.run(TVMC_CODE, None, None)?;
+    let imported_mod = import_python(py, "tvm")?;
+    let version: String = imported_mod.get("__version__")?.extract()?;
     Ok(version)
 }
 


[incubator-tvm] 04/23: Hacking on Rust inside of TVM

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 131e40afeb4bd89f0b378e0e7d8ad440d017d33a
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 15:25:56 2020 -0700

    Hacking on Rust inside of TVM
---
 CMakeLists.txt                         |   1 +
 cmake/modules/RustExt.cmake            |  13 +
 rust/Cargo.toml                        |   1 +
 rust/compiler-ext/Cargo.toml           |  13 +
 rust/compiler-ext/src/lib.rs           |   7 +
 rust/tvm/src/ir/source_map.rs          |   0
 rust/tvm/test.rly                      |   2 +
 tests/python/relay/test_type_infer2.py | 419 +++++++++++++++++++++++++++++++++
 8 files changed, 456 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index abf2b56..9f82754 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -79,6 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
 tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
 tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
 tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)
+tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions" OFF)
 
 # include directories
 include_directories(${CMAKE_INCLUDE_PATH})
diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake
new file mode 100644
index 0000000..45e46bd
--- /dev/null
+++ b/cmake/modules/RustExt.cmake
@@ -0,0 +1,13 @@
+if(USE_RUST_EXT)
+    set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust")
+    set(CARGO_OUT_DIR "${RUST_SRC_DIR}/target")
+    set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/${CMAKE_SHARED_LIBRARY_PREFIX}compiler_ext${CMAKE_SHARED_LIBRARY_SUFFIX}")
+
+    add_custom_command(
+        OUTPUT "${COMPILER_EXT_PATH}"
+        COMMAND cargo build --release
+        MAIN_DEPENDENCY "${RUST_SRC_DIR}"
+        WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")
+
+    target_link_libraries(tvm PRIVATE "${COMPILER_EXT_PATH}")
+endif(USE_RUST_EXT)
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 28312a5..7c092d8 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -29,4 +29,5 @@ members = [
 	"tvm-graph-rt/tests/test_tvm_dso",
 	"tvm-graph-rt/tests/test_wasm32",
 	"tvm-graph-rt/tests/test_nn",
+	"compiler-ext",
 ]
diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml
new file mode 100644
index 0000000..76d10eb
--- /dev/null
+++ b/rust/compiler-ext/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "compiler-ext"
+version = "0.1.0"
+authors = ["Jared Roesch <jr...@octoml.ai>"]
+edition = "2018"
+# TODO(@jroesch): would be cool to figure out how to statically link instead.
+
+[lib]
+crate-type = ["cdylib"]
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
new file mode 100644
index 0000000..31e1bb2
--- /dev/null
+++ b/rust/compiler-ext/src/lib.rs
@@ -0,0 +1,7 @@
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn it_works() {
+        assert_eq!(2 + 2, 4);
+    }
+}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
new file mode 100644
index 0000000..e69de29
diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly
new file mode 100644
index 0000000..d8b7c69
--- /dev/null
+++ b/rust/tvm/test.rly
@@ -0,0 +1,2 @@
+#[version = "0.0.5"]
+fn @main(%x: int32) -> float32 { %x }
diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py
new file mode 100644
index 0000000..6758d96
--- /dev/null
+++ b/tests/python/relay/test_type_infer2.py
@@ -0,0 +1,419 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test that the type checker correctly computes types
+   for expressions.
+"""
+import pytest
+import tvm
+
+from tvm import IRModule, te, relay, parser
+from tvm.relay import op, transform, analysis
+from tvm.relay import Any
+
+
+def infer_mod(mod, annotate_spans=True):
+    if annotate_spans:
+        mod = relay.transform.AnnotateSpans()(mod)
+
+    mod = transform.InferType()(mod)
+    return mod
+
+
+def infer_expr(expr, annotate_spans=True):
+    """Type-check ``expr`` in a fresh module (via infer_mod) and return the checked expression."""
+    mod = IRModule.from_expr(expr)
+    mod = infer_mod(mod, annotate_spans)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def assert_has_type(expr, typ, mod=None):
+    if not mod:
+        mod = tvm.IRModule({})
+
+    mod["main"] = expr
+    mod = infer_mod(mod)
+    checked_expr = mod["main"]
+    checked_type = checked_expr.checked_type
+    if checked_type != typ:
+        raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
+
+
+def initialize_box_adt(mod):
+    # initializes simple ADT for tests
+    box = relay.GlobalTypeVar("box")
+    tv = relay.TypeVar("tv")
+    constructor = relay.Constructor("constructor", [tv], box)
+    data = relay.TypeData(box, [tv], [constructor])
+    mod[box] = data
+    return box, constructor
+
+
+def test_monomorphic_let():
+    "Program: let %x = 1; %x"
+    # %x is bound by the let itself; its type must be inferred from the
+    # float64 constant it is bound to, with no separate type annotation.
+    sb = relay.ScopeBuilder()
+    x = sb.let("x", relay.const(1.0, "float64"))
+    sb.ret(x)
+    xchecked = infer_expr(sb.get())
+    assert xchecked.checked_type == relay.scalar_type("float64")
+
+
+def test_single_op():
+    "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
+    x = relay.var("x", shape=[])
+    func = relay.Function([x], op.log(x))
+    ttype = relay.TensorType([], dtype="float32")
+    assert_has_type(func, relay.FuncType([ttype], ttype))
+
+
+def test_add_broadcast_op():
+    """
+    Program:
+        fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
+            -> Tensor[(5, 10, 4), float32] {
+            %x + %y
+        }
+    """
+    x = relay.var("x", shape=(10, 4))
+    y = relay.var("y", shape=(5, 10, 1))
+    z = x + y
+    func = relay.Function([x, y], z)
+    t1 = relay.TensorType((10, 4), "float32")
+    t2 = relay.TensorType((5, 10, 1), "float32")
+    t3 = relay.TensorType((5, 10, 4), "float32")
+    expected_ty = relay.FuncType([t1, t2], t3)
+    assert_has_type(func, expected_ty)
+
+
+def test_dual_op():
+    """Program:
+    fn (%x : Tensor[(10, 10), float32]) {
+      let %t1 = log(%x);
+      let %t2 = add(%t1, %x);
+      %t2
+    }
+    """
+    tp = relay.TensorType((10, 10), "float32")
+    x = relay.var("x", tp)
+    sb = relay.ScopeBuilder()
+    t1 = sb.let("t1", relay.log(x))
+    t2 = sb.let("t2", relay.add(t1, x))
+    sb.ret(t2)
+    f = relay.Function([x], sb.get())
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_decl():
+    """Program:
+    def @f(%x : Tensor[(10, 10), float32]) {
+        log(%x)
+    }
+    """
+    tp = relay.TensorType((10, 10))
+    x = relay.var("x", tp)
+    f = relay.Function([x], relay.log(x))
+    fchecked = infer_expr(f)
+    assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_recursion():
+    """
+    Program:
+       def @f(%n: int32, %data: float32) -> float32 {
+          if (%n == 0) {
+              %data
+          } else {
+              @f(%n - 1, log(%data))
+          }
+       }
+    """
+    sb = relay.ScopeBuilder()
+    f = relay.GlobalVar("f")
+    ti32 = relay.scalar_type("int32")
+    tf32 = relay.scalar_type("float32")
+    n = relay.var("n", ti32)
+    data = relay.var("data", tf32)
+
+    with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
+        sb.ret(data)
+    with sb.else_scope():
+        sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
+    mod = tvm.IRModule()
+    mod[f] = relay.Function([n, data], sb.get())
+    mod = infer_mod(mod)
+    assert "@f(%1, %2)" in mod.astext()
+    assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32)
+
+
+def test_incomplete_call():
+    tt = relay.scalar_type("int32")
+    x = relay.var("x", tt)
+    f = relay.var("f")
+    func = relay.Function([x, f], relay.Call(f, [x]), tt)
+
+    ft = infer_expr(func)
+    f_type = relay.FuncType([tt], tt)
+    assert ft.checked_type == relay.FuncType([tt, f_type], tt)
+
+
+def test_higher_order_argument():
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
+    id_func = relay.Function([x], x, a, [a])
+
+    b = relay.TypeVar("b")
+    f = relay.Var("f", relay.FuncType([b], b))
+    y = relay.Var("y", b)
+    ho_func = relay.Function([f, y], f(y), b, [b])
+
+    # id func should be an acceptable argument to the higher-order
+    # function even though id_func takes a type parameter
+    ho_call = ho_func(id_func, relay.const(0, "int32"))
+
+    hc = infer_expr(ho_call)
+    expected = relay.scalar_type("int32")
+    assert hc.checked_type == expected
+
+
+def test_higher_order_return():
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
+    id_func = relay.Function([x], x, a, [a])
+
+    b = relay.TypeVar("b")
+    nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
+
+    ft = infer_expr(nested_id)
+    assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
+
+
+def test_higher_order_nested():
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
+    id_func = relay.Function([x], x, a, [a])
+
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
+
+    b = relay.TypeVar("b")
+    z = relay.Var("z")
+    top = relay.Function(
+        [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
+    )
+
+    expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
+    ft = infer_expr(top)
+    assert ft.checked_type == expected
+
+
+def test_tuple():
+    tp = relay.TensorType((10,))
+    x = relay.var("x", tp)
+    res = relay.Tuple([x, x])
+    assert infer_expr(res).checked_type == relay.TupleType([tp, tp])
+
+
+def test_ref():
+    x = relay.var("x", "float32")
+    y = relay.var("y", "float32")
+    r = relay.RefCreate(x)
+    st = relay.scalar_type("float32")
+    assert infer_expr(r).checked_type == relay.RefType(st)
+    g = relay.RefRead(r)
+    assert infer_expr(g).checked_type == st
+    w = relay.RefWrite(r, y)
+    assert infer_expr(w).checked_type == relay.TupleType([])
+
+
+def test_free_expr():
+    x = relay.var("x", "float32")
+    y = relay.add(x, x)
+    yy = infer_expr(y, annotate_spans=False)
+    assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
+    assert yy.checked_type == relay.scalar_type("float32")
+    assert x.vid.same_as(yy.args[0].vid)
+
+
+def test_type_args():
+    x = relay.var("x", shape=(10, 10))
+    y = relay.var("y", shape=(1, 10))
+    z = relay.add(x, y)
+    ty_z = infer_expr(z)
+    ty_args = ty_z.type_args
+    assert len(ty_args) == 2
+    assert ty_args[0].dtype == "float32"
+    assert ty_args[1].dtype == "float32"
+    sh1 = ty_args[0].shape
+    sh2 = ty_args[1].shape
+    assert sh1[0].value == 10
+    assert sh1[1].value == 10
+    assert sh2[0].value == 1
+    assert sh2[1].value == 10
+
+
+def test_global_var_recursion():
+    mod = tvm.IRModule({})
+    gv = relay.GlobalVar("main")
+    x = relay.var("x", shape=[])
+    tt = relay.scalar_type("float32")
+
+    func = relay.Function([x], relay.Call(gv, [x]), tt)
+    mod[gv] = func
+    mod = infer_mod(mod)
+    func_ty = mod["main"].checked_type
+
+    assert func_ty == relay.FuncType([tt], tt)
+
+
+def test_equal():
+    i = relay.var("i", shape=[], dtype="int32")
+    eq = op.equal(i, relay.const(0, dtype="int32"))
+    func = relay.Function([i], eq)
+    ft = infer_expr(func)
+    expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool"))
+    assert ft.checked_type == expected
+
+    assert ft.checked_type == relay.FuncType(
+        [relay.scalar_type("int32")], relay.scalar_type("bool")
+    )
+
+
+def test_constructor_type():
+    mod = tvm.IRModule()
+    box, constructor = initialize_box_adt(mod)
+
+    a = relay.TypeVar("a")
+    x = relay.Var("x", a)
+    func = relay.Function([x], constructor(x), box(a), [a])
+    mod["main"] = func
+    mod = infer_mod(mod)
+    func_ty = mod["main"].checked_type
+    box = mod.get_global_type_var("box")
+    expected = relay.FuncType([a], box(a), [a])
+    assert func_ty == expected
+
+
+def test_constructor_call():
+    mod = tvm.IRModule()
+    box, constructor = initialize_box_adt(mod)
+
+    box_unit = constructor(relay.Tuple([]))
+    box_constant = constructor(relay.const(0, "float32"))
+
+    func = relay.Function([], relay.Tuple([box_unit, box_constant]))
+    mod["main"] = func
+    mod = infer_mod(mod)
+    ret_type = mod["main"].checked_type.ret_type.fields
+    # NB(@jroesch): annotating spans rebuilds the AST, so fragments created
+    # before annotation are no longer directly equal; re-fetch `box` from mod.
+    box = mod.get_global_type_var("box")
+    expected1 = box(relay.TupleType([]))
+    expected2 = box(relay.TensorType((), "float32"))
+    assert ret_type[0] == expected1
+    assert ret_type[1] == expected2
+
+
+def test_adt_match():
+    mod = tvm.IRModule()
+    box, constructor = initialize_box_adt(mod)
+
+    v = relay.Var("v", relay.TensorType((), "float32"))
+    match = relay.Match(
+        constructor(relay.const(0, "float32")),
+        [
+            relay.Clause(
+                relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
+            ),
+            # redundant but shouldn't matter to typechecking
+            relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
+        ],
+    )
+
+    func = relay.Function([], match)
+    mod["main"] = func
+    mod = infer_mod(mod)
+    actual = mod["main"].checked_type.ret_type
+    assert actual == relay.TupleType([])
+
+
+def test_adt_match_type_annotations():
+    mod = tvm.IRModule()
+    box, constructor = initialize_box_adt(mod)
+
+    # the only type annotation is inside the match pattern var
+    # but that should be enough info
+    tt = relay.TensorType((2, 2), "float32")
+    x = relay.Var("x")
+    mv = relay.Var("mv", tt)
+    match = relay.Match(
+        constructor(x),
+        [
+            relay.Clause(
+                relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
+            )
+        ],
+    )
+
+    mod["main"] = relay.Function([x], match)
+    mod = infer_mod(mod)
+    ft = mod["main"].checked_type
+    assert ft == relay.FuncType([tt], relay.TupleType([]))
+
+
+def test_let_polymorphism():
+    id = relay.Var("id")
+    xt = relay.TypeVar("xt")
+    x = relay.Var("x", xt)
+    body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+    body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+    body = infer_expr(body)
+    int32 = relay.TensorType((), "int32")
+    tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
+
+def test_if():
+    choice_t = relay.FuncType([], relay.scalar_type("bool"))
+    f = relay.Var("f", choice_t)
+    true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32"))
+    false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32"))
+    top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
+    ft = infer_expr(top)
+    tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
+
+
+def test_type_arg_infer():
+    code = """
+#[version = "0.0.5"]
+def @id[A](%x: A) -> A {
+  %x
+}
+def @main(%f: float32) -> float32 {
+  @id(%f)
+}
+"""
+    mod = tvm.parser.fromtext(code)
+    mod = transform.InferType()(mod)
+    tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
+
+
+if __name__ == "__main__":
+    import sys
+
+    pytest.main(sys.argv)


[incubator-tvm] 02/23: Codespan example almost working

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 77ba30993a7883c142b05e511e8a5a7a91116b2f
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 9 23:59:51 2020 -0700

    Codespan example almost working
---
 rust/tvm-sys/src/packed_func.rs  |   1 +
 rust/tvm/Cargo.toml              |   2 +
 rust/tvm/src/bin/tyck.rs         |  24 ++++++++
 rust/tvm/src/ir/diagnostics.rs   | 121 +++++++++++++++++++++++++++++----------
 rust/tvm/src/ir/relay/visitor.rs |  24 ++++++++
 5 files changed, 143 insertions(+), 29 deletions(-)

diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs
index f7b289c..7b8d529 100644
--- a/rust/tvm-sys/src/packed_func.rs
+++ b/rust/tvm-sys/src/packed_func.rs
@@ -101,6 +101,7 @@ macro_rules! TVMPODValue {
                         TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
                         TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
                         TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+                        TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)),
                         TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
                         TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
                         TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 55fc179..71a4b93 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -41,6 +41,8 @@ paste = "0.1"
 mashup = "0.1"
 once_cell = "^1.3.1"
 pyo3 = { version = "0.11.1", optional = true }
+codespan-reporting = "0.9.5"
+structopt = { version = "0.3" }
 
 [features]
 default = ["python"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
new file mode 100644
index 0000000..9300412
--- /dev/null
+++ b/rust/tvm/src/bin/tyck.rs
@@ -0,0 +1,24 @@
+use std::path::PathBuf;
+
+use anyhow::Result;
+use structopt::StructOpt;
+
+use tvm::ir::diagnostics::codespan;
+use tvm::ir::IRModule;
+
+
+#[derive(Debug, StructOpt)]
+#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
+struct Opt {
+    /// Input file
+    #[structopt(parse(from_os_str))]
+    input: PathBuf,
+}
+
+fn main() -> Result<()> {
+    codespan::init().expect("Rust based diagnostics");
+    let opt = Opt::from_args();
+    println!("{:?}", &opt);
+    let file = IRModule::parse_file(opt.input)?;
+    Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index 799a10c..e434d3f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -24,13 +24,31 @@
 
 use tvm_macros::{Object, external};
 use super::module::IRModule;
-use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectRef};
+use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString};
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
 use crate::runtime::function::Result;
 use super::span::Span;
 
 type SourceName = ObjectRef;
 
+// Bindings to the global `diagnostics.*` packed functions.
+external! {
+    #[name("diagnostics.GetRenderer")]
+    fn get_renderer() -> DiagnosticRenderer;
+
+    #[name("diagnostics.DiagnosticRenderer")]
+    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
+
+    #[name("diagnostics.Emit")]
+    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
+
+    #[name("diagnostics.DiagnosticContextRender")]
+    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
+
+    #[name("diagnostics.ClearRenderer")]
+    fn clear_renderer() -> ();
+}
+
 /// The diagnostic level, controls the printing of the message.
 #[repr(C)]
 pub enum DiagnosticLevel {
@@ -171,26 +189,20 @@ pub struct DiagnosticContextNode {
     pub renderer: DiagnosticRenderer,
 }
 
-// Get the the diagnostic renderer.
-external! {
-    #[name("node.ArrayGetItem")]
-    fn get_renderer() -> DiagnosticRenderer;
-
-    #[name("diagnostics.DiagnosticRenderer")]
-    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
-
-    #[name("diagnostics.Emit")]
-    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
-
-    #[name("diagnostics.DiagnosticContextRender")]
-    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
-}
-
 /// A diagnostic context which records active errors
 /// and contains a renderer.
 impl DiagnosticContext {
-    pub fn new(module: IRModule, renderer: DiagnosticRenderer) {
-        todo!()
+    pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
+    where F: Fn(DiagnosticContext) -> () + 'static
+    {
+        let renderer = diagnostic_renderer(render_func.to_function()).unwrap();
+        let node = DiagnosticContextNode {
+            base: Object::base_object::<DiagnosticContextNode>(),
+            module,
+            diagnostics: Array::from_vec(vec![]).unwrap(),
+            renderer,
+        };
+        DiagnosticContext(Some(ObjectPtr::new(node)))
     }
 
     pub fn default(module: IRModule) -> DiagnosticContext {
@@ -223,17 +235,68 @@ impl DiagnosticContext {
 //     If the render_func is None it will remove the current custom renderer
 //     and return to default behavior.
 fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> ()
+where F: Fn(DiagnosticContext) -> () + 'static
 {
-    todo!()
-    // fn ()
-    // diagnostic_renderer(func)
-    // if render_func:
 
-    //     def _render_factory():
-    //         return DiagnosticRenderer(render_func)
+    match opt_func {
+        None => clear_renderer(),
+        Some(func) => {
+            let func = func.to_function();
+            let render_factory = move || {
+                diagnostic_renderer(func.clone()).unwrap()
+            };
+
+            function::register_override(
+                render_factory,
+                "diagnostics.OverrideRenderer",
+                true)?;
+
+            Ok(())
+        }
+    }
+}
+
+pub mod codespan {
+    use super::*;
+
+    use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+    use codespan_reporting::files::SimpleFiles;
+    use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+    pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+        let severity = match diag.level {
+            DiagnosticLevel::Error => Severity::Error,
+            DiagnosticLevel::Warning => Severity::Warning,
+            DiagnosticLevel::Note => Severity::Note,
+            DiagnosticLevel::Help => Severity::Help,
+            DiagnosticLevel::Bug => Severity::Bug,
+        };
+
+        let file_id = "foo".into(); // diag.span.source_name;
+
+        let message: String = diag.message.as_str().unwrap().into();
+        let inner_message: String = "expected `String`, found `Nat`".into();
+        let diagnostic = CDiagnostic::new(severity)
+            .with_message(message)
+            .with_code("EXXX")
+            .with_labels(vec![
+                Label::primary(file_id, 328..331).with_message(inner_message),
+            ]);
+
+        diagnostic
+    }
+
+    pub fn init() -> Result<()> {
+        let mut files: SimpleFiles<String, String> = SimpleFiles::new();
+        let render_fn = move |diag_ctx: DiagnosticContext| {
+            // let source_map = diag_ctx.module.source_map;
+            for diagnostic in diag_ctx.diagnostics {
+
+            }
+            panic!("render_fn");
+        };
 
-    //     register_func("diagnostics.OverrideRenderer", _render_factory, override=True)
-    // else:
-    //     _ffi_api.ClearRenderer()
+        override_renderer(Some(render_fn))?;
+        Ok(())
+    }
 }
diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs
new file mode 100644
index 0000000..3166174
--- /dev/null
+++ b/rust/tvm/src/ir/relay/visitor.rs
@@ -0,0 +1,24 @@
+use super::Expr;
+
+macro_rules! downcast_match {
+    ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
+        $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
+        { $default }
+    }
+}
+
+trait ExprVisitorMut {
+    fn visit(&mut self, expr: Expr) {
+        downcast_match!(expr; {
+            else => {
+                panic!()
+            }
+        });
+    }
+
+    fn visit(&mut self, expr: Expr);
+}
+
+// trait ExprTransformer {
+//     fn
+// }


[incubator-tvm] 14/23: Format and cleanup

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 04a9779359380ae383405f5f72a66e79631ee2d5
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:31:37 2020 -0700

    Format and cleanup
---
 python/tvm/ir/diagnostics/__init__.py   |  1 +
 rust/compiler-ext/src/lib.rs            |  3 +-
 rust/tvm-rt/src/array.rs                | 11 +++--
 rust/tvm-rt/src/errors.rs               |  4 +-
 rust/tvm-rt/src/function.rs             |  2 +-
 rust/tvm/src/bin/tyck.rs                |  8 +--
 rust/tvm/src/ir/diagnostics/codespan.rs | 87 +++++++++++++++++++--------------
 rust/tvm/src/ir/mod.rs                  |  2 +-
 rust/tvm/src/ir/module.rs               |  5 +-
 rust/tvm/src/ir/relay/mod.rs            |  2 +-
 rust/tvm/src/ir/relay/visitor.rs        | 24 ---------
 rust/tvm/src/ir/source_map.rs           | 13 ++---
 rust/tvm/src/ir/span.rs                 |  2 +-
 src/ir/expr.cc                          | 11 -----
 14 files changed, 79 insertions(+), 96 deletions(-)

diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py
index 0ad2a7a..3a6402c 100644
--- a/python/tvm/ir/diagnostics/__init__.py
+++ b/python/tvm/ir/diagnostics/__init__.py
@@ -37,6 +37,7 @@ def get_renderer():
     """
     return _ffi_api.GetRenderer()
 
+
 @tvm.register_func("diagnostics.override_renderer")
 def override_renderer(render_func):
     """
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index c136d06..346f40f 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -36,8 +36,7 @@ tvm::export!(test_fn, test_fn2);
 #[no_mangle]
 fn compiler_ext_initialize() -> i32 {
     let _ = env_logger::try_init();
-    tvm_export("rust_ext")
-        .expect("failed to initialize Rust compiler_ext");
+    tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext");
     log::debug!("done!");
     return 0;
 }
diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs
index 032ca79..66e32a7 100644
--- a/rust/tvm-rt/src/array.rs
+++ b/rust/tvm-rt/src/array.rs
@@ -18,8 +18,8 @@
  */
 
 use std::convert::{TryFrom, TryInto};
-use std::marker::PhantomData;
 use std::iter::{IntoIterator, Iterator};
+use std::marker::PhantomData;
 
 use crate::errors::Error;
 use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
@@ -93,8 +93,7 @@ impl<T: IsObjectRef> Iterator for IntoIter<T> {
 
     fn next(&mut self) -> Option<Self::Item> {
         if self.pos < self.size {
-            let item = self.array.get(self.pos)
-                .expect("should not fail");
+            let item = self.array.get(self.pos).expect("should not fail");
             self.pos += 1;
             Some(item)
         } else {
@@ -109,7 +108,11 @@ impl<T: IsObjectRef> IntoIterator for Array<T> {
 
     fn into_iter(self) -> Self::IntoIter {
         let size = self.len() as isize;
-        IntoIter { array: self, pos: 0, size: size }
+        IntoIter {
+            array: self,
+            pos: 0,
+            size: size,
+        }
     }
 }
 
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 3de9f3c..31ce385 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,7 +68,9 @@ pub enum Error {
     Infallible(#[from] std::convert::Infallible),
     #[error("a panic occurred while executing a Rust packed function")]
     Panic,
-    #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+    #[error(
+        "one or more error diagnostics were emitted, please check diagnostic render for output."
+    )]
     DiagnosticError(String),
     #[error("{0}")]
     Raw(String),
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 173b60a..4c6f56e 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -128,7 +128,7 @@ impl Function {
                 type_codes.as_mut_ptr() as *mut c_int,
                 num_args as c_int,
                 &mut ret_val as *mut _,
-                &mut ret_type_code as *mut _
+                &mut ret_type_code as *mut _,
             )
         };
 
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index 13470e7..e9c2663 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -20,9 +20,11 @@ fn main() -> Result<()> {
     let opt = Opt::from_args();
     println!("{:?}", &opt);
     let _module = match IRModule::parse_file(opt.input) {
-        Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) },
-        Err(e) => { return Err(e.into()); },
-        Ok(module) => module
+        Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()),
+        Err(e) => {
+            return Err(e.into());
+        }
+        Ok(module) => module,
     };
 
     Ok(())
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 9fc1ee0..9a31691 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -1,3 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/// A TVM diagnostics renderer which uses the Rust `codespan`
+/// library to produce error messages.
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -6,13 +27,8 @@ use codespan_reporting::files::SimpleFiles;
 use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
 use codespan_reporting::term::{self, ColorArg};
 
-use crate::ir::source_map::*;
 use super::*;
-
-enum StartOrEnd {
-    Start,
-    End,
-}
+use crate::ir::source_map::*;
 
 struct ByteRange<FileId> {
     file_id: FileId,
@@ -26,7 +42,7 @@ enum FileSpanToByteRange {
         /// Map character regions which are larger then 1-byte to length.
         lengths: HashMap<isize, isize>,
         source: String,
-    }
+    },
 }
 
 impl FileSpanToByteRange {
@@ -34,24 +50,11 @@ impl FileSpanToByteRange {
         let mut last_index = 0;
         let mut is_ascii = true;
         if source.is_ascii() {
-            let line_lengths =
-                source
-                    .lines()
-                    .map(|line| line.len())
-                    .collect();
+            let line_lengths = source.lines().map(|line| line.len()).collect();
             FileSpanToByteRange::AsciiSource(line_lengths)
         } else {
             panic!()
         }
-
-        // for (index, _) in source.char_indices() {
-        //     if last_index - 1 != last_index {
-        //         is_ascii = false;
-        //     } else {
-        //         panic!();
-        //     }
-        //     last_index = index;
-        // }
     }
 
     fn lookup(&self, span: &Span) -> ByteRange<String> {
@@ -61,22 +64,34 @@ impl FileSpanToByteRange {
 
         match self {
             AsciiSource(ref line_lengths) => {
-                let start_pos = (&line_lengths[0..(span.line - 1) as usize]).into_iter().sum::<usize>() + (span.column) as usize;
-                let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]).into_iter().sum::<usize>() + (span.end_column) as usize;
-                ByteRange { file_id: source_name, start_pos, end_pos }
-            },
-            _ => panic!()
+                let start_pos = (&line_lengths[0..(span.line - 1) as usize])
+                    .into_iter()
+                    .sum::<usize>()
+                    + (span.column) as usize;
+                let end_pos = (&line_lengths[0..(span.end_line - 1) as usize])
+                    .into_iter()
+                    .sum::<usize>()
+                    + (span.end_column) as usize;
+                ByteRange {
+                    file_id: source_name,
+                    start_pos,
+                    end_pos,
+                }
+            }
+            _ => panic!(),
         }
     }
 }
 
 struct SpanToByteRange {
-    map: HashMap<String, FileSpanToByteRange>
+    map: HashMap<String, FileSpanToByteRange>,
 }
 
 impl SpanToByteRange {
     fn new() -> SpanToByteRange {
-        SpanToByteRange { map: HashMap::new() }
+        SpanToByteRange {
+            map: HashMap::new(),
+        }
     }
 
     pub fn add_source(&mut self, source: Source) {
@@ -86,7 +101,8 @@ impl SpanToByteRange {
             panic!()
         } else {
             let source = source.source.as_str().expect("fpp").into();
-            self.map.insert(source_name, FileSpanToByteRange::new(source));
+            self.map
+                .insert(source_name, FileSpanToByteRange::new(source));
         }
     }
 
@@ -143,9 +159,10 @@ impl DiagnosticState {
         let diagnostic = CDiagnostic::new(severity)
             .with_message(message)
             .with_code("EXXX")
-            .with_labels(vec![
-                Label::primary(file_id, byte_range.start_pos..byte_range.end_pos)
-            ]);
+            .with_labels(vec![Label::primary(
+                file_id,
+                byte_range.start_pos..byte_range.end_pos,
+            )]);
 
         diagnostic
     }
@@ -161,11 +178,7 @@ fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
             Ok(source) => {
                 state.add_source(source);
                 let diagnostic = state.to_diagnostic(diagnostic);
-                term::emit(
-                    &mut writer.lock(),
-                    &config,
-                    &state.files,
-                    &diagnostic).unwrap();
+                term::emit(&mut writer.lock(), &config, &state.files, &diagnostic).unwrap();
             }
         }
     }
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index df9bc68..6d51580 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -25,8 +25,8 @@ pub mod function;
 pub mod module;
 pub mod op;
 pub mod relay;
-pub mod span;
 pub mod source_map;
+pub mod span;
 pub mod tir;
 pub mod ty;
 
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 11d6c49..443915f 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -32,7 +32,6 @@ use super::expr::GlobalVar;
 use super::function::BaseFunc;
 use super::source_map::SourceMap;
 
-
 // TODO(@jroesch): define type
 type TypeData = ObjectRef;
 type GlobalTypeVar = ObjectRef;
@@ -133,7 +132,9 @@ impl IRModule {
         parse_module(file_name.into(), source.into())
     }
 
-    pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> std::result::Result<IRModule, Error> {
+    pub fn parse_file<P: 'static + AsRef<Path>>(
+        file_path: P,
+    ) -> std::result::Result<IRModule, Error> {
         let file_path = file_path.as_ref();
         let file_path_as_str = file_path.to_str().unwrap().to_string();
         let source = std::fs::read_to_string(file_path)?;
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 4b09128..530b120 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -27,8 +27,8 @@ use crate::runtime::{object::*, String as TString};
 use super::attrs::Attrs;
 use super::expr::BaseExprNode;
 use super::function::BaseFuncNode;
-use super::ty::{Type, TypeNode};
 use super::span::Span;
+use super::ty::{Type, TypeNode};
 
 use tvm_macros::Object;
 use tvm_rt::NDArray;
diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs
deleted file mode 100644
index 3166174..0000000
--- a/rust/tvm/src/ir/relay/visitor.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-use super::Expr;
-
-macro_rules! downcast_match {
-    ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
-        $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
-        { $default }
-    }
-}
-
-trait ExprVisitorMut {
-    fn visit(&mut self, expr: Expr) {
-        downcast_match!(expr; {
-            else => {
-                panic!()
-            }
-        });
-    }
-
-    fn visit(&mut self, expr: Expr);
-}
-
-// trait ExprTransformer {
-//     fn
-// }
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index ebe7e46..b0cc0d8 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -19,7 +19,7 @@
 
 use crate::runtime::map::Map;
 use crate::runtime::object::Object;
-use crate::runtime::string::{String as TString};
+use crate::runtime::string::String as TString;
 
 use super::span::{SourceName, Span};
 
@@ -39,12 +39,10 @@ pub struct SourceNode {
 
     /// The raw source. */
     pub source: TString,
-
-   // A mapping of line breaks into the raw source.
-   // std::vector<std::pair<int, int>> line_map;
+    // A mapping of line breaks into the raw source.
+    // std::vector<std::pair<int, int>> line_map;
 }
 
-
 //  class Source : public ObjectRef {
 //   public:
 //    TVM_DLL Source(SourceName src_name, std::string source);
@@ -53,7 +51,6 @@ pub struct SourceNode {
 //    TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
 //  };
 
-
 /// A mapping from a unique source name to source fragments.
 #[repr(C)]
 #[derive(Object)]
@@ -61,6 +58,6 @@ pub struct SourceNode {
 #[ref_name = "SourceMap"]
 pub struct SourceMapNode {
     pub base: Object,
-   /// The source mapping.
-   pub source_map: Map<SourceName, Source>,
+    /// The source mapping.
+    pub source_map: Map<SourceName, Source>,
 }
diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs
index c54fd51..afcbe9c 100644
--- a/rust/tvm/src/ir/span.rs
+++ b/rust/tvm/src/ir/span.rs
@@ -18,7 +18,7 @@
 * under the License.
 */
 
-use crate::runtime::{ObjectRef, Object, String as TString};
+use crate::runtime::{Object, ObjectRef, String as TString};
 use tvm_macros::Object;
 
 /// A source file name, contained in a Span.
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 5110eef..67e5cea 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -192,15 +192,4 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
   return ss.str();
 });
 
-
-
 }  // namespace tvm
-
-#ifdef RUST_COMPILER_EXT
-
-extern "C" {
-  int compiler_ext_initialize();
-  static int test = compiler_ext_initialize();
-}
-
-#endif


[incubator-tvm] 10/23: Fix calling

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit eeb86c63d693288f8d406aed6b1b0df6d28e4b07
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 21:37:34 2020 -0700

    Fix calling
---
 rust/tvm-rt/src/function.rs        | 28 +++++++++++++++++-----------
 rust/tvm/src/bin/tyck.rs           |  2 +-
 rust/tvm/src/ir/diagnostics/mod.rs |  4 +---
 3 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index bae06e9..c7aebdd 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -33,6 +33,7 @@ use std::{
 };
 
 use crate::errors::Error;
+use crate::object::ObjectPtr;
 
 pub use super::to_function::{ToFunction, Typed};
 pub use tvm_sys::{ffi, ArgValue, RetValue};
@@ -120,21 +121,26 @@ impl Function {
         let mut ret_val = ffi::TVMValue { v_int64: 0 };
         let mut ret_type_code = 0i32;
 
-        check_call!(ffi::TVMFuncCall(
-            self.handle,
-            values.as_mut_ptr() as *mut ffi::TVMValue,
-            type_codes.as_mut_ptr() as *mut c_int,
-            num_args as c_int,
-            &mut ret_val as *mut _,
-            &mut ret_type_code as *mut _
-        ));
+        let ret_code = unsafe {
+            ffi::TVMFuncCall(
+                self.handle,
+                values.as_mut_ptr() as *mut ffi::TVMValue,
+                type_codes.as_mut_ptr() as *mut c_int,
+                num_args as c_int,
+                &mut ret_val as *mut _,
+                &mut ret_type_code as *mut _
+            )
+        };
+
+        if ret_code != 0 {
+            return Err(Error::CallFailed(crate::get_last_error().into()));
+        }
 
         let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
         match rv {
             RetValue::ObjectHandle(object) => {
-                let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap();
-                // println!("after wrapped call: {}", optr.count());
-                crate::object::ObjectPtr::leak(optr);
+                let optr = ObjectPtr::from_raw(object as _).unwrap();
+                ObjectPtr::leak(optr);
             }
             _ => {}
         };
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index e0c7136..fbab027 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -18,7 +18,7 @@ fn main() -> Result<()> {
     codespan::init().expect("Rust based diagnostics");
     let opt = Opt::from_args();
     println!("{:?}", &opt);
-    let module = IRModule::parse_file(opt.input)?;
+    let module = IRModule::parse_file(opt.input);
 
     // for (k, v) in module.functions {
     //     println!("Function name: {:?}", v);
diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs
index fce214a..039d1ed 100644
--- a/rust/tvm/src/ir/diagnostics/mod.rs
+++ b/rust/tvm/src/ir/diagnostics/mod.rs
@@ -207,9 +207,7 @@ impl DiagnosticContext {
     }
 }
 
-// Override the global diagnostics renderer.
-// Params
-// ------
+/// Override the global diagnostics renderer.
 // render_func: Option[Callable[[DiagnosticContext], None]]
 //     If the render_func is None it will remove the current custom renderer
 //     and return to default behavior.


[incubator-tvm] 18/23: WIP

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 49246bff342eb59757cecc34a9a9465a2e3c063d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Wed Oct 21 14:09:37 2020 -0700

    WIP
---
 rust/tvm-macros/src/external.rs |  5 ++-
 rust/tvm-macros/src/lib.rs      |  1 +
 rust/tvm/src/ir/module.rs       | 67 +++++++++++++----------------------------
 3 files changed, 26 insertions(+), 47 deletions(-)

diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 802d7ae..de8ada3 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -17,6 +17,7 @@
  * under the License.
  */
 use proc_macro2::Span;
+use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
@@ -109,7 +110,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
             .iter()
             .map(|ty_param| match ty_param {
                 syn::GenericParam::Type(param) => param.clone(),
-                _ => panic!(),
+                _ => abort! { ty_param,
+                    "Only supports type parameters."
+                }
             })
             .collect();
 
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 603e1ce..ab75c92 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -35,6 +35,7 @@ pub fn macro_impl(input: TokenStream) -> TokenStream {
     TokenStream::from(object::macro_impl(input))
 }
 
+#[proc_macro_error]
 #[proc_macro]
 pub fn external(input: TokenStream) -> TokenStream {
     external::macro_impl(input)
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 443915f..8918bdc 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -63,6 +63,8 @@ external! {
     #[name("parser.ParseExpr")]
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
     // Module methods
+    #[name("ir.Module_Add")]
+    fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -73,55 +75,28 @@ external! {
     fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc;
     #[name("ir.Module_Lookup_str")]
     fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
+    #[name("ir.Module_GetGlobalTypeVars")]
+    fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+    #[name("ir.Module_ContainGlobalVar")]
+    fn module_get_global_var(name: TVMString) -> bool;
+    #[name("ir.Module_ContainGlobalTypeVar")]
+    fn module_get_global_type_var(name: TVMString) -> bool;
+    #[name("ir.Module_LookupDef")]
+    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    #[name("ir.Module_LookupDef_str")]
+    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    #[name("ir.Module_LookupTag")]
+    fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+    #[name("ir.Module_FromExpr")]
+    fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
+    #[name("ir.Module_Import")]
+    fn module_import(module: IRModule, path: TVMString);
+    #[name("ir.Module_ImportFromStd")]
+    fn module_import_from_std(module: IRModule, path: TVMString);
 }
 
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
-//     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
+// Note: we don't expose update here as update is going to be removed.
 
-// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
-//     .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
-//     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
-//   return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
-//   return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
-//   return mod->LookupTag(tag);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
-//     .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs,
-//                        tvm::Map<GlobalTypeVar, TypeData> type_defs) {
-//       return IRModule::FromExpr(e, funcs, type_defs);
-//     });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
-//   mod->Update(from);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
-//     .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
-//   mod->Import(path);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
-//   mod->ImportFromStd(path);
-// });
-
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-//     .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
-//       auto* node = static_cast<const IRModuleNode*>(ref.get());
-//       p->stream << "IRModuleNode( " << node->functions << ")";
-//     });
 
 impl IRModule {
     pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>


[incubator-tvm] 08/23: Clean up exporting to show off new diagnostics

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 20c6a28606c053fbf9adf1c36c85fd608e63e024
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 17:03:00 2020 -0700

    Clean up exporting to show off new diagnostics
---
 rust/compiler-ext/src/lib.rs   | 12 ++++++++++--
 rust/tvm-rt/src/array.rs       | 32 ++++++++++++++++++++++++++++++++
 rust/tvm/src/bin/tyck.rs       |  7 ++++++-
 rust/tvm/src/ir/diagnostics.rs | 10 +++++-----
 rust/tvm/src/ir/mod.rs         |  1 +
 rust/tvm/src/ir/module.rs      |  3 +++
 rust/tvm/src/ir/source_map.rs  | 26 +++++++++++++++-----------
 rust/tvm/src/lib.rs            | 24 ++++++++++++++++++++++++
 8 files changed, 96 insertions(+), 19 deletions(-)

diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 3e37d21..c136d06 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -22,14 +22,22 @@ use tvm;
 use tvm::runtime::function::register_override;
 
 fn test_fn() -> Result<(), tvm::Error> {
-    println!("Hello from Rust!");
+    println!("Hello Greg from Rust!");
     Ok(())
 }
 
+fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> {
+    println!("The message: {}", message);
+    Ok(())
+}
+
+tvm::export!(test_fn, test_fn2);
+
 #[no_mangle]
 fn compiler_ext_initialize() -> i32 {
     let _ = env_logger::try_init();
-    register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier");
+    tvm_export("rust_ext")
+        .expect("failed to initialize Rust compiler_ext");
     log::debug!("done!");
     return 0;
 }
diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs
index 5e19cef..032ca79 100644
--- a/rust/tvm-rt/src/array.rs
+++ b/rust/tvm-rt/src/array.rs
@@ -19,6 +19,7 @@
 
 use std::convert::{TryFrom, TryInto};
 use std::marker::PhantomData;
+use std::iter::{IntoIterator, Iterator};
 
 use crate::errors::Error;
 use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
@@ -81,6 +82,37 @@ impl<T: IsObjectRef> Array<T> {
     }
 }
 
+pub struct IntoIter<T: IsObjectRef> {
+    array: Array<T>,
+    pos: isize,
+    size: isize,
+}
+
+impl<T: IsObjectRef> Iterator for IntoIter<T> {
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.pos < self.size {
+            let item = self.array.get(self.pos)
+                .expect("should not fail");
+            self.pos += 1;
+            Some(item)
+        } else {
+            None
+        }
+    }
+}
+
+impl<T: IsObjectRef> IntoIterator for Array<T> {
+    type Item = T;
+    type IntoIter = IntoIter<T>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        let size = self.len() as isize;
+        IntoIter { array: self, pos: 0, size: size }
+    }
+}
+
 impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
     fn from(array: Array<T>) -> ArgValue<'static> {
         array.object.into()
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index b869012..e0c7136 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -18,6 +18,11 @@ fn main() -> Result<()> {
     codespan::init().expect("Rust based diagnostics");
     let opt = Opt::from_args();
     println!("{:?}", &opt);
-    let file = IRModule::parse_file(opt.input)?;
+    let module = IRModule::parse_file(opt.input)?;
+
+    // for (k, v) in module.functions {
+    //     println!("Function name: {:?}", v);
+    // }
+
     Ok(())
 }
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index b76e43f..4975a45 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -135,6 +135,7 @@ pub struct DiagnosticRendererNode {
     pub base: Object,
     // TODO(@jroesch): we can't easily exposed packed functions due to
     // memory layout
+    // missing field here
 }
 
 //     def render(self, ctx):
@@ -283,11 +284,10 @@ pub mod codespan {
     pub fn init() -> Result<()> {
         let mut files: SimpleFiles<String, String> = SimpleFiles::new();
         let render_fn = move |diag_ctx: DiagnosticContext| {
-            // let source_map = diag_ctx.module.source_map;
-            // for diagnostic in diag_ctx.diagnostics {
-
-            // }
-            panic!("render_fn");
+            let source_map = diag_ctx.module.source_map.clone();
+            for diagnostic in diag_ctx.diagnostics.clone() {
+                println!("Diagnostic: {}", diagnostic.message);
+            }
         };
 
         override_renderer(Some(render_fn))?;
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 401b6c2..df9bc68 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -26,6 +26,7 @@ pub mod module;
 pub mod op;
 pub mod relay;
 pub mod span;
+pub mod source_map;
 pub mod tir;
 pub mod ty;
 
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index e0444b3..5156e74 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -25,6 +25,7 @@ use crate::runtime::{external, Object, ObjectRef};
 
 use super::expr::GlobalVar;
 use super::function::BaseFunc;
+use super::source_map::SourceMap;
 
 use std::io::Result as IOResult;
 use std::path::Path;
@@ -43,6 +44,8 @@ pub struct IRModuleNode {
     pub base: Object,
     pub functions: Map<GlobalVar, BaseFunc>,
     pub type_definitions: Map<GlobalTypeVar, TypeData>,
+    pub source_map: SourceMap,
+    // TODO(@jroesch): this is missing some fields
 }
 
 external! {
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index e6c0371..56c0830 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -12,7 +12,7 @@
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
+ * KIND, either exprss or implied.  See the License for the
  * specific language governing permissions and limitations
  * under the License.
  */
@@ -20,23 +20,27 @@
 use crate::runtime::map::Map;
 use crate::runtime::object::Object;
 
+use super::span::{SourceName, Span};
+
+use tvm_macros::Object;
+
 /// A program source in any language.
 ///
 /// Could represent the source from an ML framework or a source of an IRModule.
 #[repr(C)]
 #[derive(Object)]
 #[type_key = "Source"]
-#[ref_key = "Source"]
-struct SourceNode {
+#[ref_name = "Source"]
+pub struct SourceNode {
     pub base: Object,
-    /*! \brief The source name. */
-   SourceName source_name;
+    /// The source name. */
+    pub source_name: SourceName,
 
-   /*! \brief The raw source. */
-   String source;
+    /// The raw source. */
+    source: String,
 
-   /*! \brief A mapping of line breaks into the raw source. */
-   std::vector<std::pair<int, int>> line_map;
+   // A mapping of line breaks into the raw source.
+   // std::vector<std::pair<int, int>> line_map;
 }
 
 
@@ -53,8 +57,8 @@ struct SourceNode {
 #[repr(C)]
 #[derive(Object)]
 #[type_key = "SourceMap"]
-#[ref_key = "SourceMap"]
-struct SourceMapNode {
+#[ref_name = "SourceMap"]
+pub struct SourceMapNode {
     pub base: Object,
    /// The source mapping.
    pub source_map: Map<SourceName, Source>,
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index 36c7503..d193f09 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -47,3 +47,27 @@ pub mod runtime;
 pub mod transform;
 
 pub use runtime::version;
+
+/// Generates `pub fn tvm_export(ns)`, registering each function under `ns` followed by its name (no separator added).
+#[macro_export]
+macro_rules! export {
+    ($($fn_name:expr),*) => {
+        pub fn tvm_export(ns: &str) -> ::std::result::Result<(), $crate::Error> {
+            // `ns` is a runtime value, so the name is built with `format!` (`concat!` needs literals).
+            $( $crate::runtime::function::register_override($fn_name, &format!("{}{}", ns, stringify!($fn_name)), true)?; )*
+            Ok(())
+        }
+    }
+}
+
+/// Generates `pub fn tvm_mod_export()`, calling each listed module's `tvm_export` with namespace `$ns`.
+#[macro_export]
+macro_rules! export_mod {
+    // Module names must be `ident`s: an `expr` fragment cannot be followed by `::tvm_export`.
+    ($ns:expr, $($mod_name:ident),*) => {
+        pub fn tvm_mod_export() -> ::std::result::Result<(), $crate::Error> {
+            $( $mod_name::tvm_export($ns)?; )*
+            Ok(())
+        }
+    }
+}


[incubator-tvm] 12/23: Rust Diagnostics work

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 6e1346748e08255f220c3e6cf72c59a8a3f6ef29
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 14:17:01 2020 -0700

    Rust Diagnostics work
---
 rust/tvm-rt/src/errors.rs               |  15 ++++
 rust/tvm-rt/src/function.rs             |   7 +-
 rust/tvm/src/bin/tyck.rs                |  13 ++--
 rust/tvm/src/ir/diagnostics/codespan.rs | 126 ++++++++++++++++++++++----------
 4 files changed, 117 insertions(+), 44 deletions(-)

diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index c884c56..3de9f3c 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,6 +68,21 @@ pub enum Error {
     Infallible(#[from] std::convert::Infallible),
     #[error("a panic occurred while executing a Rust packed function")]
     Panic,
+    #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+    DiagnosticError(String),
+    #[error("{0}")]
+    Raw(String),
+}
+
+impl Error {
+    /// Splits `raw` at its first `:`; a `DiagnosticError` prefix yields the text after the colon, anything else stays whole as `Raw`.
+    pub fn from_raw_tvm(raw: &str) -> Error {
+        let (kind, rest) = raw.split_at(raw.find(':').unwrap_or(0));
+        if kind == "DiagnosticError" {
+            return Error::DiagnosticError(rest[1..].to_string());
+        }
+        Error::Raw(raw.to_string())
+    }
 }
 
 impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index c7aebdd..173b60a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -133,7 +133,12 @@ impl Function {
         };
 
         if ret_code != 0 {
-            return Err(Error::CallFailed(crate::get_last_error().into()));
+            let raw_error = crate::get_last_error();
+            let error = match Error::from_raw_tvm(raw_error) {
+                Error::Raw(string) => Error::CallFailed(string),
+                e => e,
+            };
+            return Err(error);
         }
 
         let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index fbab027..13470e7 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -4,7 +4,8 @@ use anyhow::Result;
 use structopt::StructOpt;
 
 use tvm::ir::diagnostics::codespan;
-use tvm::ir::IRModule;
+use tvm::ir::{self, IRModule};
+use tvm::runtime::Error;
 
 #[derive(Debug, StructOpt)]
 #[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
@@ -18,11 +19,11 @@ fn main() -> Result<()> {
     codespan::init().expect("Rust based diagnostics");
     let opt = Opt::from_args();
     println!("{:?}", &opt);
-    let module = IRModule::parse_file(opt.input);
-
-    // for (k, v) in module.functions {
-    //     println!("Function name: {:?}", v);
-    // }
+    let _module = match IRModule::parse_file(opt.input) {
+        Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) },
+        Err(e) => { return Err(e.into()); },
+        Ok(module) => module
+    };
 
     Ok(())
 }
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 80a8784..9fc1ee0 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
 use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
 use codespan_reporting::files::SimpleFiles;
 use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+use codespan_reporting::term::{self, ColorArg};
 
 use crate::ir::source_map::*;
 use super::*;
@@ -13,8 +14,14 @@ enum StartOrEnd {
     End,
 }
 
+struct ByteRange<FileId> {
+    file_id: FileId,
+    start_pos: usize,
+    end_pos: usize,
+}
+
 enum FileSpanToByteRange {
-    AsciiSource,
+    AsciiSource(Vec<usize>),
     Utf8 {
         /// Map character regions which are larger then 1-byte to length.
         lengths: HashMap<isize, isize>,
@@ -27,7 +34,12 @@ impl FileSpanToByteRange {
         let mut last_index = 0;
         let mut is_ascii = true;
         if source.is_ascii() {
-            FileSpanToByteRange::AsciiSource
+            let line_lengths =
+                source
+                    .lines()
+                    .map(|line| line.len())
+                    .collect();
+            FileSpanToByteRange::AsciiSource(line_lengths)
         } else {
             panic!()
         }
@@ -41,6 +53,21 @@ impl FileSpanToByteRange {
         //     last_index = index;
         // }
     }
+
+    /// Convert `span` (1-based lines) into the byte range it covers in this file's source.
+    fn lookup(&self, span: &Span) -> ByteRange<String> {
+        let source_name: String = span.source_name.name.as_str().unwrap().into();
+        match self {
+            FileSpanToByteRange::AsciiSource(line_lengths) => {
+                // `lines()` dropped every `\n`, so add one byte per preceding line (assumes LF endings).
+                let (start_line, end_line) = ((span.line - 1) as usize, (span.end_line - 1) as usize);
+                let start_pos = line_lengths[..start_line].iter().sum::<usize>() + start_line + span.column as usize;
+                let end_pos = line_lengths[..end_line].iter().sum::<usize>() + end_line + span.end_column as usize;
+                ByteRange { file_id: source_name, start_pos, end_pos }
+            }
+            _ => unimplemented!("byte ranges are only computed for ASCII sources"),
+        }
+    }
 }
 
 struct SpanToByteRange {
@@ -62,41 +89,22 @@ impl SpanToByteRange {
             self.map.insert(source_name, FileSpanToByteRange::new(source));
         }
     }
-}
-
-struct ByteRange<FileId> {
-    file_id: FileId,
-    start_pos: usize,
-    end_pos: usize,
-}
-
-
-pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
-    let severity = match diag.level {
-        DiagnosticLevel::Error => Severity::Error,
-        DiagnosticLevel::Warning => Severity::Warning,
-        DiagnosticLevel::Note => Severity::Note,
-        DiagnosticLevel::Help => Severity::Help,
-        DiagnosticLevel::Bug => Severity::Bug,
-    };
-
-    let file_id = "foo".into(); // diag.span.source_name;
 
-    let message: String = diag.message.as_str().unwrap().into();
-    let inner_message: String = "expected `String`, found `Nat`".into();
-    let diagnostic = CDiagnostic::new(severity)
-        .with_message(message)
-        .with_code("EXXX")
-        .with_labels(vec![
-            Label::primary(file_id, 328..331).with_message(inner_message)
-        ]);
+    /// Map `span` to a byte range using the table `add_source` registered for its file.
+    pub fn lookup(&self, span: &Span) -> ByteRange<String> {
+        let source_name: String = span.source_name.name.as_str().expect("span source name could not be read as a string").into();
 
-    diagnostic
+        let file_span_to_bytes = self.map.get(&source_name)
+            .unwrap_or_else(|| panic!("no source registered for `{}`", source_name));
+        file_span_to_bytes.lookup(span)
+    }
 }
 
 struct DiagnosticState {
     files: SimpleFiles<String, String>,
     span_map: SpanToByteRange,
+    // TODO: unify with `SourceName`. Maps a source name to the file id `files` assigned to it.
+    source_to_id: HashMap<String, usize>,
 }
 
 impl DiagnosticState {
@@ -104,26 +112,70 @@ impl DiagnosticState {
         DiagnosticState {
             files: SimpleFiles::new(),
             span_map: SpanToByteRange::new(),
+            source_to_id: HashMap::new(),
         }
     }
+
+    /// Register `source` in the span map and as a new `files` entry (even if its name was seen before).
+    fn add_source(&mut self, source: Source) {
+        let text: String = source.source.as_str().unwrap().into();
+        let name: String = source.source_name.name.as_str().unwrap().into();
+        self.span_map.add_source(source);
+        self.source_to_id.insert(name.clone(), self.files.add(name, text));
+    }
+
+    /// Convert a TVM diagnostic into a codespan diagnostic labelled at its span.
+    /// `add_source` must already have registered the diagnostic's source (the lookups panic otherwise).
+    fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic<usize> {
+        // Each TVM diagnostic level has a direct codespan counterpart.
+        let severity = match diag.level {
+            DiagnosticLevel::Error => Severity::Error,
+            DiagnosticLevel::Warning => Severity::Warning,
+            DiagnosticLevel::Note => Severity::Note,
+            DiagnosticLevel::Help => Severity::Help,
+            DiagnosticLevel::Bug => Severity::Bug,
+        };
+
+        let name: String = diag.span.source_name.name.as_str().unwrap().into();
+        let file_id: usize = *self.source_to_id.get(&name).unwrap();
+        let text: String = diag.message.as_str().unwrap().into();
+
+        // Point the primary label at the byte range the span map computes for this span.
+        let ByteRange { start_pos, end_pos, .. } = self.span_map.lookup(&diag.span);
+        let primary = Label::primary(file_id, start_pos..end_pos);
+
+        CDiagnostic::new(severity)
+            .with_message(text)
+            .with_code("EXXX")
+            .with_labels(vec![primary])
+    }
 }
 
 fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
     let source_map = diag_ctx.module.source_map.clone();
-        for diagnostic in diag_ctx.diagnostics.clone() {
-            match source_map.source_map.get(&diagnostic.span.source_name) {
-                Err(err) => panic!(),
-                Ok(source) => state.span_map.add_source(source),
+    // Render each diagnostic to stderr, registering its source first so its span can be resolved.
+    // NOTE(review): `ColorChoice::Always` emits color codes even when stderr is not a terminal.
+    let writer = StandardStream::stderr(ColorChoice::Always);
+    let config = codespan_reporting::term::Config::default();
+    for diagnostic in diag_ctx.diagnostics.clone() {
+        match source_map.source_map.get(&diagnostic.span.source_name) {
+            // Format the error into the message: `panic!(err)` would carry it as an opaque payload.
+            Err(err) => panic!("source for diagnostic not found in the module's source map: {}", err),
+            Ok(source) => {
+                state.add_source(source);
+                let diagnostic = state.to_diagnostic(diagnostic);
+                term::emit(&mut writer.lock(), &config, &state.files, &diagnostic)
+                    .expect("failed to write diagnostic to stderr");
             }
-            println!("Diagnostic: {}", diagnostic.message);
         }
+    }
 }
 
 pub fn init() -> Result<()> {
     let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
     let render_fn = move |diag_ctx: DiagnosticContext| {
-        // let mut guard = diag_state.lock().unwrap();
-        // renderer(&mut *guard, diag_ctx);
+        let mut guard = diag_state.lock().unwrap();
+        renderer(&mut *guard, diag_ctx);
     };
 
     override_renderer(Some(render_fn))?;


[incubator-tvm] 22/23: WIP

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit e8bb83d33e7ad44ff16c72d9de61c2de722d12a8
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Thu Oct 29 15:31:05 2020 -0700

    WIP
---
 rust/tvm/src/python.rs | 34 ++++++----------------------------
 1 file changed, 6 insertions(+), 28 deletions(-)

diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs
index 2b2d374..50ce7b0 100644
--- a/rust/tvm/src/python.rs
+++ b/rust/tvm/src/python.rs
@@ -20,20 +20,6 @@
 use pyo3::prelude::*;
 use once_cell::sync::OnceCell;
 
-// static TVM_PYTHON: OnceCell<Py<PyModule>> = OnceCell::new();
-
-// fn initialize() -> Py<PyModule> {
-//     TVM_PYTHON.get_or_init(|| {
-//         let gil = Python::acquire_gil();
-//         let py = gil.python();
-//         PyModule::new(py, "__tvm__rust__module__").map_err(|e| {
-//             // We can't display Python exceptions via std::fmt::Display,
-//             // so print the error here manually.
-//             e.print_and_set_sys_last_vars(py);
-//         }).expect("failed to initialize the Python interface").into()
-//     }).clone()
-// }
-
 /// Load the Python interpreter into the address space.
 ///
 /// This enables the ability for Rust code to call TVM
@@ -53,27 +39,19 @@ pub fn load() -> Result<String, ()> {
     })
 }
 
-fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
-    let imported_mod = py.import(to_import)?;
-    PyModule::from_code(py,
-        r#"
-import tvm
-from tvm import relay
-tvm.cleanup()
-"#, "blah", "my_mod")?;
-    // py_mod.add(to_import, imported_mod)?;
-    Ok(imported_mod)
-}
-
 pub fn import(mod_to_import: &str) -> PyResult<()> {
     let gil = Python::acquire_gil();
     let py = gil.python();
-    // let main_mod = initialize();
-    // let main_mod = main_mod.as_ref(py);
     import_python(py, mod_to_import)?;
     Ok(())
 }
 
+/// Import the Python module named `to_import`; the returned module reference
+/// lives as long as the `py` GIL token (`'p`).
+fn import_python<'p>(py: Python<'p>, to_import: &str) -> PyResult<&'p PyModule> {
+    py.import(to_import)
+}
+
 fn load_python_tvm_(py: Python) -> PyResult<String> {
     let imported_mod = import_python(py, "tvm")?;
     let version: String = imported_mod.get("__version__")?.extract()?;


[incubator-tvm] 07/23: Fix Linux build

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit e0f980142c3c0ab795de316943079a651743d8d7
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 01:42:14 2020 -0700

    Fix Linux build
---
 cmake/modules/LLVM.cmake | 7 ++++++-
 rust/tvm-sys/Cargo.toml  | 2 +-
 rust/tvm-sys/build.rs    | 3 +--
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake
index 5f8ace1..ca4ecd6 100644
--- a/cmake/modules/LLVM.cmake
+++ b/cmake/modules/LLVM.cmake
@@ -16,7 +16,12 @@
 # under the License.
 
 # LLVM rules
-add_definitions(-DDMLC_USE_FOPEN64=0)
+# Due to LLVM debug symbols you can sometimes hit linking issues on certain
+# compiler/platform combinations if NDEBUG is not set (add_definitions sets it for all targets).
+#
+# See https://github.com/imageworks/OpenShadingLanguage/issues/1069
+# for more discussion.
+add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1)
 
 # Test if ${USE_LLVM} is not an explicit boolean false
 # It may be a boolean or a string
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
index c25a5bf..2952aa4 100644
--- a/rust/tvm-sys/Cargo.toml
+++ b/rust/tvm-sys/Cargo.toml
@@ -23,7 +23,7 @@ license = "Apache-2.0"
 edition = "2018"
 
 [features]
-default = ["bindings"]
+default = []
 bindings = []
 
 [dependencies]
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 2d86c4b..1590234 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -60,8 +60,7 @@ fn main() -> Result<()> {
     if cfg!(feature = "bindings") {
         println!("cargo:rerun-if-env-changed=TVM_HOME");
         println!("cargo:rustc-link-lib=dylib=tvm");
-        println!("cargo:rustc-link-lib=dylib=llvm-10");
-        println!("cargo:rustc-link-search={}/build", tvm_home);
+        println!("cargo:rustc-link-search=native={}/build", tvm_home);
     }
 
     // @see rust-bindgen#550 for `blacklist_type`


[incubator-tvm] 05/23: Borrow code from Egg

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit b2b59c229e9b8c2002d8c8cd520748df6b38e074
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 15:26:54 2020 -0700

    Borrow code from Egg
---
 rust/compiler-ext/src/lib.rs | 344 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 337 insertions(+), 7 deletions(-)

diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 31e1bb2..58bdd0c 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -1,7 +1,337 @@
-#[cfg(test)]
-mod tests {
-    #[test]
-    fn it_works() {
-        assert_eq!(2 + 2, 4);
-    }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ use std::os::raw::c_int;
+ use tvm::initialize;
+ use tvm::ir::{tir, PrimExpr};
+ use tvm::runtime::function::register_override;
+ use tvm::runtime::map::Map;
+ use tvm::runtime::object::{IsObject, IsObjectRef};
+ 
+ use ordered_float::NotNan;
+ 
+ mod interval;
+ mod math;
+ 
+ use math::{BoundsMap, Expr, RecExpr};
+ use tvm::ir::arith::ConstIntBound;
+ use tvm_rt::{ObjectRef, array::Array};
+ 
+ macro_rules! downcast_match { // try each `$t` in order; the first successful downcast selects its arm
+     ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
+         $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ // `$id` is rebound to the downcast value
+         { $default }
+     }
+ }
+ 
+ #[derive(Default)]
+ struct VarMap { // bookkeeping shared by `to_egg` and `from_egg`
+     vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>, // TIR variables and the egg symbols made from their names
+     objs: Vec<ObjectRef>, // objects that egg nodes refer to by index (`push_obj` / `get_obj`)
+ }
+ 
+ /// Translation tables between TVM values and egg nodes: TIR variables are interned
+ /// as egg symbols, and arbitrary objects are stored and referenced by index.
+ impl VarMap {
+     /// Returns the egg symbol for `var`, registering the variable on first sight.
+     /// FIXME: symbols are keyed on `name_hint`, so distinct TVM variables sharing a
+     /// name collapse into one symbol; this relies on variable names being unique.
+     fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
+         let sym = egg::Symbol::from(var.name_hint.as_str().unwrap());
+         let already_known = self.vars.iter().any(|(_, known)| *known == sym);
+         if !already_known {
+             self.vars.push((var, sym));
+         }
+         sym
+     }
+ 
+     /// Returns the TVM variable registered for `sym`; panics if there is none.
+     fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
+         match self.vars.iter().find(|(_, known)| *known == sym) {
+             Some((var, _)) => var.clone(),
+             None => panic!("Should have found a var"),
+         }
+     }
+ 
+     /// Stores `obj` and returns the index that `get_obj` uses to retrieve it.
+     fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
+         self.objs.push(obj.upcast());
+         self.objs.len() - 1
+     }
+ 
+     /// Retrieves the object stored at index `i`, downcast to `T`; panics on a type mismatch.
+     fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
+         self.objs[i].clone().downcast().expect("bad downcast")
+     }
+ }
+ 
+ /// Translate `prim` into an egg `RecExpr`, recording TIR variables and opaque objects in `vars`.
+ /// Nodes with no egg counterpart (and NaN float literals) are kept as `Expr::Object` handles.
+ fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
+     fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
+         macro_rules! r {
+             ($e:expr) => { build(vars, &$e, recexpr) };
+         }
+ 
+         let dt = recexpr.add(Expr::DataType(p.datatype));
+         let e = downcast_match!(p; {
+             tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
+             tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
+             tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
+ 
+             tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
+             tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
+             tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
+             tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
+ 
+             tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
+             tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
+ 
+             tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
+             tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
+ 
+             tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
+             tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
+             tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
+             tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
+             tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
+             tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
+ 
+             tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
+             tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
+             tir::Not => Expr::Not([dt, r!(p.value)]),
+ 
+             tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
+ 
+             tir::Let => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+                 Expr::Let([dt, sym, r!(p.value), r!(p.body)])
+             }
+             tir::Var => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
+                 Expr::Var([dt, sym])
+             }
+             tir::IntImm => {
+                 let int = recexpr.add(Expr::Int(p.value));
+                 Expr::IntImm([dt, int])
+             }
+             tir::FloatImm => match NotNan::new(p.value) {
+                 Ok(value) => Expr::FloatImm([dt, recexpr.add(Expr::Float(value))]),
+                 Err(_) => Expr::Object(vars.push_obj(p)), // NaN has no total order for egg: keep it opaque
+             },
+             tir::Cast => Expr::Cast([dt, r!(p.value)]),
+ 
+             tir::Call => {
+                 let op = vars.push_obj(p.op.clone());
+                 let mut arg_ids = vec![dt];
+                 for i in 0..p.args.len() {
+                     let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
+                     arg_ids.push(r!(arg));
+                 }
+                 Expr::Call(op, arg_ids)
+             },
+             tir::Load => {
+                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+                 Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
+             },
+             else => {
+                 log::warn!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap());
+                 Expr::Object(vars.push_obj(p.clone()))
+             }
+         });
+ 
+         recexpr.add(e)
+     }
+ 
+     let mut recexpr = Default::default();
+     build(vars, prim, &mut recexpr);
+     recexpr
+ }
+ 
+ /// Rebuild a TVM `PrimExpr` from an egg `RecExpr`, resolving symbols and `Expr::Object` indices through `vars`.
+ fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
+     fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
+         let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]); // child `i` becomes the prefix's last node
+         let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
+         let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
+             Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] {
+                 Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+                 n => panic!("Expected a symbol, got {:?}", n),
+             },
+             Expr::IntImm([dt, v]) => {
+                 let value = nodes[usize::from(*v)].to_int().unwrap();
+                 tir::IntImm::new(get_dt(dt), value).upcast()
+             }
+             Expr::FloatImm([dt, v]) => {
+                 let value = nodes[usize::from(*v)].to_float().unwrap();
+                 tir::FloatImm::new(get_dt(dt), value).upcast()
+             }
+             Expr::Let([dt, s, value, body]) => {
+                 let var = match &nodes[usize::from(*s)] {
+                     Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+                     n => panic!("Expected a symbol, got {:?}", n),
+                 };
+                 tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast()
+             }
+             Expr::Load([dt, s, index, predicate]) => {
+                 let var = match &nodes[usize::from(*s)] {
+                     Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+                     n => panic!("Expected a symbol, got {:?}", n),
+                 };
+                 tir::Load::new(get_dt(dt), var, go(index), go(predicate)).upcast()
+             }
+ 
+             Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
+ 
+             Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
+ 
+             Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
+ 
+             Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
+ 
+             Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
+             Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
+ 
+             Expr::Ramp([dt, a, b, c]) => {
+                 let len = &nodes[usize::from(*c)];
+                 let i = len
+                     .to_int()
+                     .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len));
+                 tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast()
+             }
+             Expr::Broadcast([dt, val, lanes]) => {
+                 let lanes = &nodes[usize::from(*lanes)];
+                 let lanes = lanes
+                     .to_int()
+                     .unwrap_or_else(|| panic!("Broadcast lanes must be an int, got {:?}", lanes));
+                 tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast()
+             }
+ 
+             Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
+             Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
+             Expr::Call(expr, args) => {
+                 let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
+                 let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
+                 tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast()
+             }
+ 
+             Expr::Object(i) => vars.get_obj(*i),
+             node => panic!("I don't know how to extract {:?}", node),
+         };
+         assert_ne!(prim.datatype.bits(), 0);
+         assert_ne!(prim.datatype.lanes(), 0);
+         prim
+     }
+     build(vars, recexpr.as_ref())
+ }
+ 
+ fn run( // simplify `input` with egg; given `expected`, also require egg to prove the two equal
+     input: PrimExpr,
+     expected: Option<PrimExpr>,
+     map: Map<PrimExpr, ConstIntBound>,
+ ) -> Result<PrimExpr, String> {
+     use egg::{CostFunction, Extractor};
+ 
+     let mut bounds = BoundsMap::default(); // per-variable (min, max), keyed by the variable's name
+     for (k, v) in map {
+         if let Ok(var) = k.downcast_clone::<tir::Var>() {
+             let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
+             bounds.insert(sym, (v.min_value, v.max_value));
+         } else {
+             println!("Non var in bounds map: {}", tvm::ir::as_text(k));
+         }
+     }
+ 
+     let mut vars = VarMap::default();
+     let expr = to_egg(&mut vars, &input);
+     let mut runner = math::default_runner();
+     runner.egraph.analysis.bounds = bounds;
+ 
+     let mut runner = runner.with_expr(&expr).run(&math::rules());
+     // Pick the cheapest term (per `math::CostFn`) in the root e-class after rewriting.
+     let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
+     let root = runner.egraph.find(runner.roots[0]);
+     let (cost, best) = extractor.find_best(root);
+     if let Some(expected) = expected {
+         let mut expected_vars = VarMap::default(); // fresh map: symbols still match `vars` by name
+         let expected_expr = to_egg(&mut expected_vars, &expected);
+         let expected_root = runner.egraph.add_expr(&expected_expr);
+         if expected_root != root { // different e-classes: equality was not proven
+             return Err(format!(
+                 "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
+                 expected_expr.pretty(40),
+                 best.pretty(40)
+             ));
+         }
+         let expected_cost = math::CostFn.cost_rec(&expected_expr);
+         if expected_cost != cost {
+             let msg = format!(
+                 "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
+                 expected_cost,
+                 expected_expr.pretty(40),
+                 cost,
+                 best.pretty(40)
+             );
+             if cost < expected_cost { // egg beat the expected form: report it but still succeed
+                 println!("egg wins: {}", msg)
+             } else {
+                 return Err(msg);
+             }
+         }
+     }
+     log::info!("  returning... {}", best.pretty(60));
+     Ok(from_egg(&vars, &best))
+ }
+ 
+ fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
+     log::debug!("map: {:?}", map); // exposed to TVM as "egg.simplify" (see `initialize!`)
+     run(prim, None, map).map_err(|message| tvm::Error::CallFailed(message))
+ }
+ 
+ /// Like `simplify`, but errors if egg cannot prove `prim` equal to `check`, or if `check` is cheaper than egg's result.
+ fn simplify_and_check(prim: PrimExpr, check: PrimExpr, map: Map<PrimExpr, ConstIntBound>)
+     -> Result<PrimExpr, tvm::Error>
+ {
+     log::debug!("check map: {:?}", map);
+     let outcome = run(prim, Some(check), map);
+     outcome.map_err(tvm::Error::CallFailed)
+ }
+ 
+ initialize!({ // NOTE(review): presumably run by `tvm::initialize` when this extension is loaded; confirm
+     let _ = env_logger::try_init(); // ignore the error if a logger is already installed
+     // NOTE this print prevents a segfault (on Linux) for now...
+     println!("Initializing simplifier... ");
+     register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
+     register_override(simplify_and_check, "egg.simplify_and_check", true)
+         .expect("failed to initialize simplifier");
+     log::debug!("done!");
+ });
+ 
\ No newline at end of file


[incubator-tvm] 20/23: WIP

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit b8dcc35801e7f0f3c696e0d030b3bd0931127785
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 22:18:58 2020 -0700

    WIP
---
 rust/tvm-macros/src/external.rs      |   2 +-
 rust/tvm-macros/src/lib.rs           |   3 +-
 rust/tvm-macros/src/object.rs        |  23 ++++++
 rust/tvm-rt/src/object/mod.rs        |   9 +--
 rust/tvm-rt/src/object/object_ptr.rs |  16 ++++
 rust/tvm-rt/src/string.rs            |   1 +
 rust/tvm-rt/src/value.rs             |   1 -
 rust/tvm-sys/src/datatype.rs         |   4 +
 rust/tvm/src/ir/module.rs            | 152 ++++++++++++++++++++++++++++++++---
 rust/tvm/src/ir/relay/mod.rs         |  36 +++------
 rust/tvm/src/ir/tir.rs               |  14 ++++
 11 files changed, 218 insertions(+), 43 deletions(-)

diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 44a242c..51a389b 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,7 +21,7 @@ use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
-use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type};
 
 struct ExternalItem {
     attrs: Vec<Attribute>,
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 32f2839..e563a57 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream {
     import_module::macro_impl(input)
 }
 
-#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
+#[proc_macro_error]
+#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
 pub fn macro_impl(input: TokenStream) -> TokenStream {
     // let input = proc_macro2::TokenStream::from(input);
     TokenStream::from(object::macro_impl(input))
diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs
index ff72d6a..7e6a934 100644
--- a/rust/tvm-macros/src/object.rs
+++ b/rust/tvm-macros/src/object.rs
@@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
         .map(attr_to_str)
         .expect("Failed to get type_key");
 
+    let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true);
+
     let ref_id = get_attr(&derive_input, "ref_name")
         .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
         .unwrap_or_else(|| {
@@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
 
     expanded.extend(base_tokens);
 
+    if derive {
+        let derives = quote! {
+            impl std::hash::Hash for #ref_id {
+                fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+                    self.0.hash(state)
+                }
+            }
+
+            impl std::cmp::PartialEq for #ref_id {
+                fn eq(&self, other: &Self) -> bool {
+                    self.0 == other.0
+                }
+            }
+
+            impl std::cmp::Eq for #ref_id {}
+        };
+
+
+        expanded.extend(derives);
+    }
+
     TokenStream::from(expanded)
 }
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index e48c017..7e6107d 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -90,12 +90,7 @@ external! {
     #[name("ir.DebugPrint")]
     pub fn debug_print(object: ObjectRef) -> CString;
     #[name("node.StructuralHash")]
-    fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
+    fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
     #[name("node.StructuralEqual")]
-    fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
+    fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
 }
-
-// external! {
-//     #[name("ir.TextPrinter")]
-//     fn as_text(object: ObjectRef) -> CString;
-// }
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 77254d2..a923506 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
     }
 }
 
+/// Hashes the pointee *structurally* through `node.StructuralHash` (with
+/// `map_free_vars = false`): distinct objects with the same structure hash
+/// identically. This matches the `PartialEq` impl below, which uses
+/// `node.StructuralEqual` with the same setting, so `Hash`/`Eq` agree.
+///
+/// # Panics
+///
+/// Panics if the FFI call fails, because `Hash::hash` cannot return an error.
+impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        let object = ObjectRef(Some(self.clone().upcast()));
+        let hash = super::structural_hash(object, false)
+            .expect("node.StructuralHash FFI call failed while hashing an ObjectPtr");
+        state.write_i64(hash)
+    }
+}
+
+/// Structural (not pointer-identity) equality through `node.StructuralEqual`
+/// with `assert_mode = false` and `map_free_vars = false`.
+///
+/// # Panics
+///
+/// Panics if the FFI call fails, because `PartialEq::eq` cannot return an error.
+impl<T: IsObject> PartialEq for ObjectPtr<T> {
+    fn eq(&self, other: &Self) -> bool {
+        let lhs = ObjectRef(Some(self.clone().upcast()));
+        let rhs = ObjectRef(Some(other.clone().upcast()));
+        super::structural_equal(lhs, rhs, false, false)
+            .expect("node.StructuralEqual FFI call failed while comparing ObjectPtrs")
+    }
+}
+
+impl<T: IsObject> Eq for ObjectPtr<T> {}
+
 #[cfg(test)]
 mod tests {
     use super::{Object, ObjectPtr};
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
index 6ff24be..e9a76d2 100644
--- a/rust/tvm-rt/src/string.rs
+++ b/rust/tvm-rt/src/string.rs
@@ -28,6 +28,7 @@ use tvm_macros::Object;
 #[derive(Object)]
 #[ref_name = "String"]
 #[type_key = "runtime.String"]
+#[no_derive]
 pub struct StringObj {
     base: Object,
     data: *const u8,
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index c49944d..b8cd190 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -22,7 +22,6 @@
 //! `RetValue` is the owned version of `TVMPODValue`.
 
 use std::convert::TryFrom;
-// use std::ffi::c_void;
 
 use crate::{ArgValue, Module, RetValue};
 use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
index 8050d93..5f7e0c3 100644
--- a/rust/tvm-sys/src/datatype.rs
+++ b/rust/tvm-sys/src/datatype.rs
@@ -83,6 +83,10 @@ impl DataType {
         DataType::new(DL_FLOAT_CODE, bits, lanes)
     }
 
+    /// Scalar 32-bit floating-point type (`float`, 32 bits, one lane).
+    pub const fn float32() -> DataType {
+        DataType::float(32, 1)
+    }
+
     pub const fn uint(bits: u8, lanes: u16) -> DataType {
         DataType::new(DL_UINT_CODE, bits, lanes)
     }
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 3b60b0c..db32ce2 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -17,6 +17,7 @@
  * under the License.
  */
 use std::io::Result as IOResult;
+use std::iter::FromIterator;
 use std::path::Path;
 
 use thiserror::Error;
@@ -33,8 +34,9 @@ use super::function::BaseFunc;
 use super::source_map::SourceMap;
 use super::{ty::GlobalTypeVar, relay};
 
-// TODO(@jroesch): define type
+use tvm_macros::Object;
 
+// TODO(@jroesch): define type
 type TypeData = ObjectRef;
-type GlobalTypeVar = ObjectRef;
+// `GlobalTypeVar` is imported from `super::ty` above; also declaring an alias
+// with the same name would define it twice in this module (E0255).
 
@@ -64,9 +66,11 @@ external! {
     fn parse_module(file_name: TVMString, source: TVMString) -> IRModule;
     #[name("parser.ParseExpr")]
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
+    #[name("ir.IRModule")]
+    fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
     // Module methods
     #[name("ir.Module_Add")]
-    fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+    fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -78,15 +82,15 @@ external! {
     #[name("ir.Module_Lookup_str")]
     fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
     #[name("ir.Module_GetGlobalTypeVars")]
-    fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+    fn module_get_global_type_vars(module: IRModule) -> Array<GlobalTypeVar>;
     #[name("ir.Module_ContainGlobalVar")]
-    fn module_contains_global_var(name: TVMString) -> bool;
+    fn module_contains_global_var(module: IRModule, name: TVMString) -> bool;
     #[name("ir.Module_ContainGlobalTypeVar")]
-    fn module_contains_global_type_var(name: TVMString) -> bool;
+    fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool;
     #[name("ir.Module_LookupDef")]
-    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData;
     #[name("ir.Module_LookupDef_str")]
-    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+    fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData;
     #[name("ir.Module_LookupTag")]
     fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
     #[name("ir.Module_FromExpr")]
@@ -99,8 +103,12 @@ external! {
 
 // Note: we don't expose update here as update is going to be removed.
 
-
 impl IRModule {
+    /// Build a module from `(GlobalVar, BaseFunc)` and `(GlobalTypeVar, TypeData)`
+    /// pairs.
+    ///
+    /// Both iterators are collected into TVM `Map`s and handed to the
+    /// `ir.IRModule` constructor.
+    pub fn new<F, T>(funcs: F, types: T) -> Result<IRModule>
+    where
+        F: IntoIterator<Item = (GlobalVar, BaseFunc)>,
+        T: IntoIterator<Item = (GlobalTypeVar, TypeData)>,
+    {
+        let func_map = Map::from_iter(funcs);
+        let type_map = Map::from_iter(types);
+        module_new(func_map, type_map)
+    }
+
     pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
     where
         N: Into<TVMString>,
@@ -119,6 +127,13 @@ impl IRModule {
         Ok(module)
     }
 
+    /// Add `func` under `var` through `ir.Module_Add`.
+    ///
+    /// The `update` flag is passed as `true`, presumably so an existing binding
+    /// for `var` may be replaced. Returns the module handle produced by the
+    /// FFI call.
+    pub fn add(&mut self, var: GlobalVar, func: BaseFunc) -> Result<IRModule> {
+        let module = self.clone();
+        module_add(module, var, func, true)
+    }
+
     pub fn add_def(
         &mut self,
         type_name: GlobalTypeVar,
@@ -146,10 +161,127 @@ impl IRModule {
     {
         module_lookup_str(self.clone(), name.into())
     }
+
+    /// List this module's global type variables (`ir.Module_GetGlobalTypeVars`).
+    pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
+        let module = self.clone();
+        module_get_global_type_vars(module)
+    }
+
+    /// Whether a global variable called `name` exists in this module
+    /// (`ir.Module_ContainGlobalVar`).
+    pub fn contains_global_var<S>(&self, name: S) -> Result<bool>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_contains_global_var(module, name.into())
+    }
+
+    /// Whether a global type variable called `name` exists in this module
+    /// (`ir.Module_ContainGlobalTypeVar`).
+    pub fn contains_global_type_var<S>(&self, name: S) -> Result<bool>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_contains_global_type_var(module, name.into())
+    }
+
+    /// Look up the type definition bound to `global` (`ir.Module_LookupDef`).
+    pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
+        let module = self.clone();
+        module_lookup_def(module, global)
+    }
+
+    /// Look up a type definition through `ir.Module_LookupDef_str`.
+    ///
+    /// NOTE(review): the `_str` FFI variant presumably expects the definition's
+    /// name as a string, but both this wrapper and its `external!` binding pass
+    /// a `GlobalTypeVar` — confirm the expected argument type before use.
+    pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
+        module_lookup_def_str(self.clone(), global)
+    }
+
+    /// Look up the Relay constructor with the given `tag` (`ir.Module_LookupTag`).
+    pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
+        let module = self.clone();
+        module_lookup_tag(module, tag)
+    }
+
+    /// Build a module from a Relay expression together with existing function
+    /// and type-definition maps (`ir.Module_FromExpr`).
+    pub fn from_expr(
+        expr: relay::Expr,
+        funcs: Map<GlobalVar, BaseFunc>,
+        types: Map<GlobalTypeVar, TypeData>,
+    ) -> Result<IRModule> {
+        module_from_expr(expr, funcs, types)
+    }
+
+    /// Import the file at `path` into this module (`ir.Module_Import`).
+    pub fn import<S>(&mut self, path: S) -> Result<()>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_import(module, path.into())
+    }
+
+    /// Import `path` from the standard library location via the
+    /// `module_import_from_std` binding.
+    pub fn import_from_std<S>(&mut self, path: S) -> Result<()>
+    where
+        S: Into<TVMString>,
+    {
+        let module = self.clone();
+        module_import_from_std(module, path.into())
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    // #[test]
-    // fn
+    use std::collections::HashMap;
+    use super::relay::*;
+    use super::*;
+    use super::super::span::Span;
+    use tvm_rt::IsObjectRef;
+
+    /// Builds an empty module and adds a one-parameter identity function to it.
+    #[test]
+    fn test_module_add() -> anyhow::Result<()> {
+        let funcs = HashMap::<GlobalVar, BaseFunc>::new();
+        let types = HashMap::<GlobalTypeVar, TypeData>::new();
+        let mut module = IRModule::new(funcs, types)?;
+        let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
+        let params = Array::from_vec(vec![x.clone()])?;
+        let func = relay::Function::simple(params, x.upcast()).upcast();
+        let _updated = module.add(GlobalVar::new("foo".into(), Span::null()), func)?;
+        // TODO: assert on the added function once lookup is verified:
+        // let lfunc = module.lookup_str("foo")?;
+        // let lfunc = lfunc.downcast::<relay::Function>()?;
+        // assert_eq!(lfunc.params.len(), 1);
+        Ok(())
+    }
+
+    // Placeholders are ignored rather than left to pass vacuously.
+    #[test]
+    #[ignore = "not yet implemented"]
+    fn test_module_add_def() {}
+
+    #[test]
+    #[ignore = "not yet implemented"]
+    fn test_get_global_var() {}
+
+    #[test]
+    #[ignore = "not yet implemented"]
+    fn test_get_global_vars() {}
+
+    #[test]
+    #[ignore = "not yet implemented"]
+    fn test_lookup() {}
+
+    #[test]
+    #[ignore = "not yet implemented"]
+    fn test_lookup_def() {}
+
+    // TODO: also cover get_global_type_vars, contains_global_var,
+    // contains_global_type_var, lookup_def_str, lookup_tag, from_expr,
+    // import and import_from_std.
 }
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 530b120..90b7a6a 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -16,11 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-pub mod attrs;
-
-use std::hash::Hash;
-
 use crate::runtime::array::Array;
 use crate::runtime::{object::*, String as TString};
 
@@ -29,11 +24,15 @@ use super::expr::BaseExprNode;
 use super::function::BaseFuncNode;
 use super::span::Span;
 use super::ty::{Type, TypeNode};
+use super::span::Span;
 
 use tvm_macros::Object;
 use tvm_rt::NDArray;
 
 pub use super::expr::{GlobalVar, GlobalVarNode};
+pub use crate::runtime::DataType;
+
+pub mod attrs;
 
 #[repr(C)]
 #[derive(Object)]
@@ -58,20 +57,6 @@ impl ExprNode {
     }
 }
 
-impl Hash for Expr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.as_ptr().unwrap().ptr.hash(state)
-    }
-}
-
-impl PartialEq for Expr {
-    fn eq(&self, other: &Self) -> bool {
-        self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr)
-    }
-}
-
-impl Eq for Expr {}
-
 #[repr(C)]
 #[derive(Object)]
 #[ref_name = "Id"]
@@ -140,11 +125,11 @@ pub struct VarNode {
 }
 
 impl Var {
-    pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var {
+    /// Create a Relay variable with a new `Id` named `name_hint` and the given
+    /// type annotation.
+    ///
+    /// NOTE(review): `_span` is currently ignored (the node is built without
+    /// it); confirm whether it should be stored on the expression base.
+    pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var {
         let node = VarNode {
             base: ExprNode::base::<VarNode>(),
             vid: Id::new(name_hint.into()),
             type_annotation,
         };
         Var(Some(ObjectPtr::new(node)))
     }
@@ -153,8 +138,9 @@ impl Var {
         &self.vid.0.as_ref().unwrap().name_hint
     }
 
-    pub fn to_expr(self) -> Expr {
-        unsafe { Expr(std::mem::transmute(self.0)) }
+    /// Create a variable annotated with a static tensor type.
+    ///
+    /// Each entry of `sh` becomes one dimension (converted to a `PrimExpr`
+    /// integer immediate); `dtype` is the element type. No source span is
+    /// recorded for either the tensor type or the variable.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the shape array cannot be constructed through the TVM FFI.
+    pub fn static_tensor(name_hint: String, sh: Vec<i32>, dtype: DataType) -> Var {
+        let shape = Array::from_vec(sh.into_iter().map(Into::into).collect())
+            .expect("failed to construct the shape array for a static tensor type");
+        let ty: Type = super::ty::TensorType::new(shape, dtype, Span::null()).upcast();
+        Self::new(name_hint, ty, Span::null())
     }
 }
 
@@ -510,6 +496,10 @@ impl Function {
         };
         Function(Some(ObjectPtr::new(node)))
     }
+
+    /// Create a function from `params` and `body` with a null return-type
+    /// annotation and an empty array for the final `Function::new` argument
+    /// (presumably the type parameters).
+    ///
+    /// # Panics
+    ///
+    /// Panics if the empty array cannot be constructed through the TVM FFI.
+    pub fn simple(params: Array<Var>, body: Expr) -> Function {
+        let empty = Array::from_vec(vec![])
+            .expect("failed to construct an empty array for Function::simple");
+        Self::new(params, body, Type::null(), empty)
+    }
 }
 
 #[cfg(test)]
diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs
index 22d4e02..f07e854 100644
--- a/rust/tvm/src/ir/tir.rs
+++ b/rust/tvm/src/ir/tir.rs
@@ -47,6 +47,20 @@ macro_rules! define_node {
 // TODO(@jroesch): should move up to expr.rs to mirror TVM.
 define_node!(IntImm, "IntImm", "IntImm";
              IntImmNode { value: i64 });
+
+impl From<i32> for IntImm {
+    /// Build a scalar `int32` immediate (32 bits, one lane).
+    fn from(value: i32) -> IntImm {
+        let dtype = DataType::int(32, 1);
+        IntImm::new(dtype, i64::from(value))
+    }
+}
+
+impl From<i32> for PrimExpr {
+    /// Build an `int32` immediate and erase it to a generic `PrimExpr`.
+    fn from(value: i32) -> PrimExpr {
+        use crate::runtime::IsObjectRef;
+        let imm: IntImm = value.into();
+        imm.upcast()
+    }
+}
+
 define_node!(Var, "Var", "tir.Var";
              VarNode { name_hint: TVMString });
 


[incubator-tvm] 01/23: Add initial boilerplate for Rust diagnostic interface.

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 1097cbf8a23708b7408dfbab7c419e363af57728
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 9 01:18:15 2020 -0700

    Add initial boilerplate for Rust diagnostic interface.
---
 python/tvm/ir/diagnostics/__init__.py |   2 +-
 rust/tvm/src/ir/diagnostics.rs        | 239 ++++++++++++++++++++++++++++++++++
 rust/tvm/src/ir/mod.rs                |   1 +
 3 files changed, 241 insertions(+), 1 deletion(-)

diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py
index 6503743..0ad2a7a 100644
--- a/python/tvm/ir/diagnostics/__init__.py
+++ b/python/tvm/ir/diagnostics/__init__.py
@@ -37,7 +37,7 @@ def get_renderer():
     """
     return _ffi_api.GetRenderer()
 
-
+@tvm.register_func("diagnostics.override_renderer")
 def override_renderer(render_func):
     """
     Sets a custom renderer for diagnostics.
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
new file mode 100644
index 0000000..799a10c
--- /dev/null
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! The diagnostic interface to TVM, used for reporting and rendering
+//! diagnostic information by the compiler. This module exposes
+//! three key abstractions: a Diagnostic, the DiagnosticContext,
+//! and the DiagnosticRenderer.
+
+use tvm_macros::{Object, external};
+use super::module::IRModule;
+use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString};
+use crate::runtime::object::{Object, ObjectRef};
+use crate::runtime::function::Result;
+use super::span::Span;
+
+/// Name of the source a diagnostic refers to; currently an untyped `ObjectRef`.
+type SourceName = ObjectRef;
+
+/// Severity of a diagnostic; controls how its message is printed.
+///
+/// The explicit discriminants presumably mirror TVM's C++ `DiagnosticLevel`
+/// values (the enum is `#[repr(C)]` because it is embedded in the
+/// `#[repr(C)]` `DiagnosticNode`), so keep them in sync.
+#[repr(C)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum DiagnosticLevel {
+    Bug = 10,
+    Error = 20,
+    Warning = 30,
+    Note = 40,
+    Help = 50,
+}
+
+/// A compiler diagnostic: a severity level, a source span and a message.
+///
+/// `#[derive(Object)]` also generates the `Diagnostic` reference type
+/// (`ref_name`), registered under the TVM type key `"Diagnostic"`.
+/// NOTE(review): as a `#[repr(C)]` node, the field order presumably has to
+/// match TVM's C++ `DiagnosticNode` layout; confirm before reordering.
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Diagnostic"]
+#[type_key = "Diagnostic"]
+pub struct DiagnosticNode {
+    /// Common TVM object header.
+    pub base: Object,
+    /// Severity of the diagnostic.
+    pub level: DiagnosticLevel,
+    /// The span at which to report the diagnostic.
+    pub span: Span,
+    /// The diagnostic message.
+    pub message: TString,
+}
+
+impl Diagnostic {
+    /// Create a diagnostic with the given level, span and message.
+    ///
+    /// Not yet implemented: panics via `todo!()`.
+    pub fn new(_level: DiagnosticLevel, _span: Span, _message: TString) -> Diagnostic {
+        todo!()
+    }
+
+    /// Start building a `Bug`-level diagnostic at `span`. Not yet implemented.
+    pub fn bug(_span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building an `Error`-level diagnostic at `span`. Not yet implemented.
+    pub fn error(_span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Warning`-level diagnostic at `span`. Not yet implemented.
+    pub fn warning(_span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Note`-level diagnostic at `span`. Not yet implemented.
+    pub fn note(_span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Help`-level diagnostic at `span`. Not yet implemented.
+    pub fn help(_span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+}
+
+/// Incrementally builds a diagnostic.
+///
+/// Ported from TVM's C++ `DiagnosticBuilder` (which wraps a
+/// `std::stringstream`); this Rust version does not yet hold a message buffer.
+pub struct DiagnosticBuilder {
+    /// The level.
+    pub level: DiagnosticLevel,
+
+    /// The source name.
+    pub source_name: SourceName,
+
+    /// The span of the diagnostic.
+    pub span: Span,
+}
+
+//   /*! \brief Display diagnostics in a given display format.
+//    *
+//    * A diagnostic renderer is responsible for converting the
+//    * raw diagnostics into consumable output.
+//    *
+//    * For example the terminal renderer will render a sequence
+//    * of compiler diagnostics to std::out and std::err in
+//    * a human readable form.
+//    */
+//   class DiagnosticRendererNode : public Object {
+//    public:
+//     TypedPackedFunc<void(DiagnosticContext ctx)> renderer;
+
+//     // override attr visitor
+//     void VisitAttrs(AttrVisitor* v) {}
+
+//     static constexpr const char* _type_key = "DiagnosticRenderer";
+//     TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
+//   };
+
+//   class DiagnosticRenderer : public ObjectRef {
+//    public:
+//     TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
+//     TVM_DLL DiagnosticRenderer()
+//         : DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}
+
+//     void Render(const DiagnosticContext& ctx);
+
+//     DiagnosticRendererNode* operator->() {
+//       CHECK(get() != nullptr);
+//       return static_cast<DiagnosticRendererNode*>(get_mutable());
+//     }
+
+//     TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
+//   };
+
+// @tvm._ffi.register_object("DiagnosticRenderer")
+// class DiagnosticRenderer(Object):
+//     """
+//     A diagnostic renderer, which given a diagnostic context produces a "rendered"
+//     form of the diagnostics for either human or computer consumption.
+//     """
+
+//     def __init__(self, render_func):
+//         self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func)
+
+//     def render(self, ctx):
+//         """
+//         Render the provided context.
+
+//         Params
+//         ------
+//         ctx: DiagnosticContext
+//             The diagnostic context to render.
+//         """
+//         return _ffi_api.DiagnosticRendererRender(self, ctx)
+/// Placeholder for TVM's `DiagnosticRenderer`: an untyped `ObjectRef` until a
+/// typed binding for `DiagnosticRendererNode` is written.
+pub type DiagnosticRenderer = ObjectRef;
+
+/// A diagnostic context for recording errors against a source file.
+///
+/// `#[derive(Object)]` also generates the `DiagnosticContext` reference type,
+/// registered under the TVM type key `"DiagnosticContext"`.
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "DiagnosticContext"]
+#[type_key = "DiagnosticContext"]
+pub struct DiagnosticContextNode {
+    /// Common TVM object header.
+    pub base: Object,
+
+    /// The Module to report against.
+    pub module: IRModule,
+
+    /// The set of diagnostics to report.
+    pub diagnostics: Array<Diagnostic>,
+
+    /// The renderer set for the context.
+    pub renderer: DiagnosticRenderer,
+}
+
+// FFI bindings to TVM's diagnostics API. Only plain `//` comments are allowed
+// inside `external!`: its parser expects exactly one attribute (`#[name]`).
+external! {
+    // Get the diagnostic renderer currently in use.
+    #[name("diagnostics.GetRenderer")]
+    fn get_renderer() -> DiagnosticRenderer;
+
+    #[name("diagnostics.DiagnosticRenderer")]
+    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
+
+    #[name("diagnostics.Emit")]
+    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
+
+    #[name("diagnostics.DiagnosticContextRender")]
+    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
+}
+
+/// A diagnostic context which records active errors
+/// and contains a renderer.
+impl DiagnosticContext {
+    /// Create a context that reports against `module` using `renderer`.
+    ///
+    /// Not yet implemented: panics via `todo!()`.
+    pub fn new(_module: IRModule, _renderer: DiagnosticRenderer) -> DiagnosticContext {
+        todo!()
+    }
+
+    /// Create a context for `module` with the default renderer.
+    ///
+    /// Not yet implemented: panics via `todo!()`.
+    pub fn default(_module: IRModule) -> DiagnosticContext {
+        todo!()
+    }
+
+    /// Record `diagnostic` in this context (`diagnostics.Emit`).
+    pub fn emit(&mut self, diagnostic: Diagnostic) -> Result<()> {
+        emit(self.clone(), diagnostic)
+    }
+
+    /// Render the recorded diagnostics (`diagnostics.DiagnosticContextRender`).
+    ///
+    /// The C++ side raises a DiagnosticError when errors were recorded; this
+    /// presumably surfaces as an `Err` here — confirm.
+    pub fn render(&mut self) -> Result<()> {
+        diagnostic_context_render(self.clone())
+    }
+
+    /// Record `diagnostic`, then immediately render everything recorded so far.
+    pub fn emit_fatal(&mut self, diagnostic: Diagnostic) -> Result<()> {
+        self.emit(diagnostic)?;
+        self.render()
+    }
+}
+
+/// Set a custom renderer for diagnostics.
+///
+/// Passing `None` is meant to remove the current custom renderer and restore
+/// the default behavior. Not yet implemented: panics via `todo!()`.
+fn override_renderer<F>(_render_func: Option<F>) -> Result<()>
+where
+    F: Fn(DiagnosticContext),
+{
+    // Python implementation being ported:
+    //
+    // if render_func:
+    //     def _render_factory():
+    //         return DiagnosticRenderer(render_func)
+    //
+    //     register_func("diagnostics.OverrideRenderer", _render_factory, override=True)
+    // else:
+    //     _ffi_api.ClearRenderer()
+    todo!()
+}
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 126d0fa..8450bd7 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -20,6 +20,7 @@
 pub mod arith;
 pub mod attrs;
 pub mod expr;
+pub mod diagnostics;
 pub mod function;
 pub mod module;
 pub mod op;


[incubator-tvm] 19/23: WIP

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit a9ee3cb34c020a4debe75fc9a194303f22d00892
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 11:48:34 2020 -0700

    WIP
---
 rust/tvm-macros/Cargo.toml      |  2 +-
 rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++++++++--------
 rust/tvm-macros/src/lib.rs      |  1 +
 rust/tvm-rt/src/object/mod.rs   |  2 +-
 rust/tvm/src/ir/module.rs       | 16 +++++++++++----
 5 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml
index 63b8472..8e97d3b 100644
--- a/rust/tvm-macros/Cargo.toml
+++ b/rust/tvm-macros/Cargo.toml
@@ -33,5 +33,5 @@ proc-macro = true
 goblin = "^0.2"
 proc-macro2 = "^1.0"
 quote = "^1.0"
-syn = { version = "1.0.17", features = ["full", "extra-traits"] }
+syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] }
 proc-macro-error = "^1.0"
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index de8ada3..44a242c 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,9 +21,28 @@ use proc_macro_error::abort;
 use quote::quote;
 use syn::parse::{Parse, ParseStream, Result};
 
-use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+
+/// One declaration inside an `external!` block: its outer attributes (the
+/// `#[name("...")]` binding), a visibility (possibly empty) and a bodiless
+/// signature terminated by `;`.
+struct ExternalItem {
+    attrs: Vec<Attribute>,
+    visibility: Visibility,
+    sig: Signature,
+}
+
+impl Parse for ExternalItem {
+    fn parse(input: ParseStream) -> Result<Self> {
+        let attrs = input.call(Attribute::parse_outer)?;
+        let visibility: Visibility = input.parse()?;
+        let sig: Signature = input.parse()?;
+        // The declaration has no body; consume the trailing semicolon.
+        input.parse::<Semi>()?;
+        Ok(ExternalItem { attrs, visibility, sig })
+    }
+}
 
 struct External {
+    visibility: Visibility,
     tvm_name: String,
     ident: Ident,
     generics: Generics,
@@ -33,7 +52,8 @@ struct External {
 
 impl Parse for External {
     fn parse(input: ParseStream) -> Result<Self> {
-        let method: TraitItemMethod = input.parse()?;
+        let method: ExternalItem = input.parse()?;
+        let visibility = method.visibility;
         assert_eq!(method.attrs.len(), 1);
         let sig = method.sig;
         let tvm_name = method.attrs[0].parse_meta()?;
@@ -48,8 +68,7 @@ impl Parse for External {
             }
             _ => panic!(),
         };
-        assert_eq!(method.default, None);
-        assert!(method.semi_token != None);
+
         let ident = sig.ident;
         let generics = sig.generics;
         let inputs = sig
@@ -61,6 +80,7 @@ impl Parse for External {
         let ret_type = sig.output;
 
         Ok(External {
+            visibility,
             tvm_name,
             ident,
             generics,
@@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
     let mut items = Vec::new();
 
     for external in &ext_input.externs {
+        let visibility = &external.visibility;
         let name = &external.ident;
         let global_name = format!("global_{}", external.ident);
         let global_name = Ident::new(&global_name, Span::call_site());
@@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
                         let ty: Type = *pat_type.ty.clone();
                         (ident, ty)
                     }
-                    _ => panic!(),
+                    _ =>  abort! { pat_type,
+                        "Only supports type parameters."
+                    }
                 },
-                _ => panic!(),
+                pat => abort! {
+                        pat, "invalid pattern type for function";
+
+                        note = "{:?} is not allowed here", pat;
+                    }
             })
             .unzip();
 
         let ret_type = match &external.ret_type {
             ReturnType::Type(_, rtype) => *rtype.clone(),
-            _ => panic!(),
+            ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
         };
 
         let global = quote! {
@@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
         items.push(global);
 
         let wrapper = quote! {
-            pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
+            #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
                 let func_ref: #tvm_rt_crate::Function = #global_name.clone();
                 let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
                 let res: #ret_type = func_ref(#(#args),*)?;
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index ab75c92..32f2839 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -18,6 +18,7 @@
  */
 
 use proc_macro::TokenStream;
+use proc_macro_error::proc_macro_error;
 
 mod external;
 mod import_module;
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 46e0342..e48c017 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -88,7 +88,7 @@ pub trait IsObjectRef:
 
 external! {
     #[name("ir.DebugPrint")]
-    fn debug_print(object: ObjectRef) -> CString;
+    pub fn debug_print(object: ObjectRef) -> CString;
     #[name("node.StructuralHash")]
     fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
     #[name("node.StructuralEqual")]
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 8918bdc..3b60b0c 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -31,8 +31,10 @@ use crate::runtime::{external, Object, ObjectRef};
 use super::expr::GlobalVar;
 use super::function::BaseFunc;
 use super::source_map::SourceMap;
+use super::{ty::GlobalTypeVar, relay};
 
 // TODO(@jroesch): define type
+
 type TypeData = ObjectRef;
 type GlobalTypeVar = ObjectRef;
 
@@ -64,7 +66,7 @@ external! {
     fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
     // Module methods
     #[name("ir.Module_Add")]
-    fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+    fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
     #[name("ir.Module_AddDef")]
     fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
     #[name("ir.Module_GetGlobalVar")]
@@ -78,15 +80,15 @@ external! {
     #[name("ir.Module_GetGlobalTypeVars")]
     fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
     #[name("ir.Module_ContainGlobalVar")]
-    fn module_get_global_var(name: TVMString) -> bool;
+    fn module_contains_global_var(name: TVMString) -> bool;
     #[name("ir.Module_ContainGlobalTypeVar")]
-    fn module_get_global_type_var(name: TVMString) -> bool;
+    fn module_contains_global_type_var(name: TVMString) -> bool;
     #[name("ir.Module_LookupDef")]
     fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
     #[name("ir.Module_LookupDef_str")]
     fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
     #[name("ir.Module_LookupTag")]
-    fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+    fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
     #[name("ir.Module_FromExpr")]
     fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
     #[name("ir.Module_Import")]
@@ -145,3 +147,9 @@ impl IRModule {
         module_lookup_str(self.clone(), name.into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    // #[test]
+    // fn
+}


[incubator-tvm] 17/23: Fix some CR

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 8e295b79c2e693d7190751532824b463dc9373e5
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Mon Oct 19 19:52:20 2020 -0700

    Fix some CR
---
 rust/tvm/src/ir/diagnostics/codespan.rs | 6 ++++--
 rust/tvm/src/lib.rs                     | 4 ++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 9a31691..54fd336 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -17,8 +17,10 @@
  * under the License.
  */
 
-/// A TVM diagnostics renderer which uses the Rust `codespan`
-/// library to produce error messages.
+/// A TVM diagnostics renderer which uses the Rust `codespan` library
+/// to produce error messages.
+///
+///
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index ec80ece..7e0682b 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -24,7 +24,7 @@
 //! One particular use case is that given optimized deep learning model artifacts,
 //! (compiled with TVM) which include a shared library
 //! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
-//! in Rust idomatically to create a TVM Graph Runtime and
+//! in Rust idiomatically to create a TVM Graph Runtime and
 //! run the model for some inputs and get the
 //! desired predictions *all in Rust*.
 //!
@@ -53,7 +53,7 @@ macro_rules! export {
     ($($fn_name:expr),*) => {
         pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
             $(
-                let name = String::from(ns) + ::std::stringify!($fn_name);
+                let name = String::fromwe(ns) + ::std::stringify!($fn_name);
                 tvm::runtime::function::register_override($fn_name, name, true)?;
             )*
             Ok(())


[incubator-tvm] 09/23: Improve Rust bindings

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 4cd1bbc79a71873676927c22528f5911e3f4072d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 20:24:43 2020 -0700

    Improve Rust bindings
---
 rust/tvm/src/ir/diagnostics/codespan.rs            | 131 +++++++++++++++++++++
 .../src/ir/{diagnostics.rs => diagnostics/mod.rs}  |  69 +----------
 rust/tvm/src/ir/source_map.rs                      |   3 +-
 rust/tvm/test.rly                                  |   3 +-
 src/ir/diagnostic.cc                               |   1 +
 5 files changed, 138 insertions(+), 69 deletions(-)

diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
new file mode 100644
index 0000000..80a8784
--- /dev/null
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -0,0 +1,131 @@
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+use codespan_reporting::files::SimpleFiles;
+use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+use crate::ir::source_map::*;
+use super::*;
+
+enum StartOrEnd {
+    Start,
+    End,
+}
+
+enum FileSpanToByteRange {
+    AsciiSource,
+    Utf8 {
+        /// Map character regions which are larger then 1-byte to length.
+        lengths: HashMap<isize, isize>,
+        source: String,
+    }
+}
+
+impl FileSpanToByteRange {
+    fn new(source: String) -> FileSpanToByteRange {
+        let mut last_index = 0;
+        let mut is_ascii = true;
+        if source.is_ascii() {
+            FileSpanToByteRange::AsciiSource
+        } else {
+            panic!()
+        }
+
+        // for (index, _) in source.char_indices() {
+        //     if last_index - 1 != last_index {
+        //         is_ascii = false;
+        //     } else {
+        //         panic!();
+        //     }
+        //     last_index = index;
+        // }
+    }
+}
+
+struct SpanToByteRange {
+    map: HashMap<String, FileSpanToByteRange>
+}
+
+impl SpanToByteRange {
+    fn new() -> SpanToByteRange {
+        SpanToByteRange { map: HashMap::new() }
+    }
+
+    pub fn add_source(&mut self, source: Source) {
+        let source_name: String = source.source_name.name.as_str().expect("foo").into();
+
+        if self.map.contains_key(&source_name) {
+            panic!()
+        } else {
+            let source = source.source.as_str().expect("fpp").into();
+            self.map.insert(source_name, FileSpanToByteRange::new(source));
+        }
+    }
+}
+
+struct ByteRange<FileId> {
+    file_id: FileId,
+    start_pos: usize,
+    end_pos: usize,
+}
+
+
+pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+    let severity = match diag.level {
+        DiagnosticLevel::Error => Severity::Error,
+        DiagnosticLevel::Warning => Severity::Warning,
+        DiagnosticLevel::Note => Severity::Note,
+        DiagnosticLevel::Help => Severity::Help,
+        DiagnosticLevel::Bug => Severity::Bug,
+    };
+
+    let file_id = "foo".into(); // diag.span.source_name;
+
+    let message: String = diag.message.as_str().unwrap().into();
+    let inner_message: String = "expected `String`, found `Nat`".into();
+    let diagnostic = CDiagnostic::new(severity)
+        .with_message(message)
+        .with_code("EXXX")
+        .with_labels(vec![
+            Label::primary(file_id, 328..331).with_message(inner_message)
+        ]);
+
+    diagnostic
+}
+
+struct DiagnosticState {
+    files: SimpleFiles<String, String>,
+    span_map: SpanToByteRange,
+}
+
+impl DiagnosticState {
+    fn new() -> DiagnosticState {
+        DiagnosticState {
+            files: SimpleFiles::new(),
+            span_map: SpanToByteRange::new(),
+        }
+    }
+}
+
+fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
+    let source_map = diag_ctx.module.source_map.clone();
+        for diagnostic in diag_ctx.diagnostics.clone() {
+            match source_map.source_map.get(&diagnostic.span.source_name) {
+                Err(err) => panic!(),
+                Ok(source) => state.span_map.add_source(source),
+            }
+            println!("Diagnostic: {}", diagnostic.message);
+        }
+}
+
+pub fn init() -> Result<()> {
+    let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
+    let render_fn = move |diag_ctx: DiagnosticContext| {
+        // let mut guard = diag_state.lock().unwrap();
+        // renderer(&mut *guard, diag_ctx);
+    };
+
+    override_renderer(Some(render_fn))?;
+    Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics/mod.rs
similarity index 76%
rename from rust/tvm/src/ir/diagnostics.rs
rename to rust/tvm/src/ir/diagnostics/mod.rs
index 4975a45..fce214a 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics/mod.rs
@@ -18,7 +18,7 @@
  */
 
 use super::module::IRModule;
-use super::span::Span;
+use super::span::*;
 use crate::runtime::function::Result;
 use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
 use crate::runtime::{
@@ -32,7 +32,7 @@ use crate::runtime::{
 /// and the DiagnosticRenderer.
 use tvm_macros::{external, Object};
 
-type SourceName = ObjectRef;
+pub mod codespan;
 
 // Get the the diagnostic renderer.
 external! {
@@ -229,68 +229,3 @@ where
         }
     }
 }
-
-pub mod codespan {
-    use super::*;
-
-    use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
-    use codespan_reporting::files::SimpleFiles;
-    use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
-
-    enum StartOrEnd {
-        Start,
-        End,
-    }
-
-    // struct SpanToBytes {
-    //     inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
-    // }
-
-    struct ByteRange<FileId> {
-        file_id: FileId,
-        start_pos: usize,
-        end_pos: usize,
-    }
-
-    // impl SpanToBytes {
-    //     fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange<FileId> {
-
-    //     }
-    // }
-
-    pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
-        let severity = match diag.level {
-            DiagnosticLevel::Error => Severity::Error,
-            DiagnosticLevel::Warning => Severity::Warning,
-            DiagnosticLevel::Note => Severity::Note,
-            DiagnosticLevel::Help => Severity::Help,
-            DiagnosticLevel::Bug => Severity::Bug,
-        };
-
-        let file_id = "foo".into(); // diag.span.source_name;
-
-        let message: String = diag.message.as_str().unwrap().into();
-        let inner_message: String = "expected `String`, found `Nat`".into();
-        let diagnostic = CDiagnostic::new(severity)
-            .with_message(message)
-            .with_code("EXXX")
-            .with_labels(vec![
-                Label::primary(file_id, 328..331).with_message(inner_message)
-            ]);
-
-        diagnostic
-    }
-
-    pub fn init() -> Result<()> {
-        let mut files: SimpleFiles<String, String> = SimpleFiles::new();
-        let render_fn = move |diag_ctx: DiagnosticContext| {
-            let source_map = diag_ctx.module.source_map.clone();
-            for diagnostic in diag_ctx.diagnostics.clone() {
-                println!("Diagnostic: {}", diagnostic.message);
-            }
-        };
-
-        override_renderer(Some(render_fn))?;
-        Ok(())
-    }
-}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index 56c0830..ebe7e46 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -19,6 +19,7 @@
 
 use crate::runtime::map::Map;
 use crate::runtime::object::Object;
+use crate::runtime::string::{String as TString};
 
 use super::span::{SourceName, Span};
 
@@ -37,7 +38,7 @@ pub struct SourceNode {
     pub source_name: SourceName,
 
     /// The raw source. */
-    source: String,
+    pub source: TString,
 
    // A mapping of line breaks into the raw source.
    // std::vector<std::pair<int, int>> line_map;
diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly
index d8b7c69..e9407b0 100644
--- a/rust/tvm/test.rly
+++ b/rust/tvm/test.rly
@@ -1,2 +1,3 @@
 #[version = "0.0.5"]
-fn @main(%x: int32) -> float32 { %x }
+
+def @main(%x: int32) -> float32 { %x }
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index f9299e3..a4fee1e 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -113,6 +113,7 @@ TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender")
     });
 
 DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) {
+  CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function";
   auto n = make_object<DiagnosticContextNode>();
   n->module = module;
   n->renderer = renderer;


[incubator-tvm] 16/23: More cleanup

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 187435099c4d80be2c200791e7853f8518f421e1
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:45:12 2020 -0700

    More cleanup
---
 rust/compiler-ext/src/lib.rs | 22 ++++++++--------------
 rust/tvm/src/lib.rs          |  7 ++++---
 2 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 346f40f..5f83f7b 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -18,25 +18,19 @@
  */
 
 use env_logger;
-use tvm;
-use tvm::runtime::function::register_override;
+use tvm::export;
 
-fn test_fn() -> Result<(), tvm::Error> {
-    println!("Hello Greg from Rust!");
-    Ok(())
+fn diagnostics() -> Result<(), tvm::Error> {
+    tvm::ir::diagnostics::codespan::init()
 }
 
-fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> {
-    println!("The message: {}", message);
-    Ok(())
-}
-
-tvm::export!(test_fn, test_fn2);
+export!(diagnostics);
 
 #[no_mangle]
-fn compiler_ext_initialize() -> i32 {
+extern fn compiler_ext_initialize() -> i32 {
     let _ = env_logger::try_init();
-    tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext");
-    log::debug!("done!");
+    tvm_export("rust_ext")
+        .expect("failed to initialize the Rust compiler extensions.");
+    log::debug!("Loaded the Rust compiler extension.");
     return 0;
 }
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index d193f09..ec80ece 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -50,10 +50,11 @@ pub use runtime::version;
 
 #[macro_export]
 macro_rules! export {
-    ($($fn_names:expr),*) => {
+    ($($fn_name:expr),*) => {
         pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
             $(
-                register_override($fn_name, concat!($ns, stringfy!($fn_name)), true)?;
+                let name = String::from(ns) + ::std::stringify!($fn_name);
+                tvm::runtime::function::register_override($fn_name, name, true)?;
             )*
             Ok(())
         }
@@ -65,7 +66,7 @@ macro_rules! export_mod {
     ($ns:expr, $($mod_name:expr),*) => {
         pub fn tvm_mod_export() -> Result<(), tvm::Error> {
             $(
-                $mod_names::tvm_export($ns)?;
+                $mod_name::tvm_export($ns)?;
             )*
             Ok(())
         }


[incubator-tvm] 06/23: Update CMake and delete old API

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit db245535a4ea5b8bed3cdddd26a94825d2fcc9b5
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 01:03:03 2020 -0700

    Update CMake and delete old API
---
 CMakeLists.txt                  |   1 +
 cmake/modules/RustExt.cmake     |  25 ++-
 include/tvm/parser/source_map.h |   2 -
 rust/compiler-ext/Cargo.toml    |   5 +-
 rust/compiler-ext/src/lib.rs    | 334 ++--------------------------------------
 rust/tvm-rt/Cargo.toml          |  15 +-
 rust/tvm-sys/Cargo.toml         |   1 +
 rust/tvm-sys/build.rs           |   1 +
 rust/tvm/Cargo.toml             |  22 ++-
 rust/tvm/src/bin/tyck.rs        |   1 -
 rust/tvm/src/ir/diagnostics.rs  |  42 +++--
 rust/tvm/src/ir/mod.rs          |   2 +-
 rust/tvm/src/ir/relay/mod.rs    |   3 +-
 rust/tvm/src/ir/source_map.rs   |  61 ++++++++
 rust/tvm/src/ir/span.rs         |  95 +++++++++---
 src/ir/expr.cc                  |  11 ++
 src/parser/source_map.cc        |  11 --
 17 files changed, 237 insertions(+), 395 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f82754..58bb2f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake)
 include(cmake/modules/contrib/TensorRT.cmake)
 include(cmake/modules/Git.cmake)
 include(cmake/modules/LibInfo.cmake)
+include(cmake/modules/RustExt.cmake)
 
 include(CheckCXXCompilerFlag)
 if(NOT MSVC)
diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake
index 45e46bd..2ad726e9 100644
--- a/cmake/modules/RustExt.cmake
+++ b/cmake/modules/RustExt.cmake
@@ -1,7 +1,14 @@
-if(USE_RUST_EXT)
-    set(RUST_SRC_DIR "rust")
-    set(CARGO_OUT_DIR "rust/target"
-    set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib")
+if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF)
+    set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust")
+    set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target")
+
+    if(USE_RUST_EXT STREQUAL "STATIC")
+        set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a")
+    elseif(USE_RUST_EXT STREQUAL "DYNAMIC")
+        set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so")
+    else()
+        message(FATAL_ERROR "invalid setting for RUST_EXT")
+    endif()
 
     add_custom_command(
         OUTPUT "${COMPILER_EXT_PATH}"
@@ -9,5 +16,11 @@ if(USE_RUST_EXT)
         MAIN_DEPENDENCY "${RUST_SRC_DIR}"
         WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")
 
-    target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
-endif(USE_RUST_EXT)
+    add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}")
+
+    # TODO(@jroesch, @tkonolige): move this to CMake target
+    # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
+    list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH})
+
+    add_definitions(-DRUST_COMPILER_EXT=1)
+endif()
diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h
index 424af5c..a160c22 100644
--- a/include/tvm/parser/source_map.h
+++ b/include/tvm/parser/source_map.h
@@ -103,8 +103,6 @@ class SourceMap : public ObjectRef {
 
   TVM_DLL SourceMap() : SourceMap(Map<SourceName, Source>()) {}
 
-  TVM_DLL static SourceMap Global();
-
   void Add(const Source& source);
 
   SourceMapNode* operator->() {
diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml
index 76d10eb..3b13bc5 100644
--- a/rust/compiler-ext/Cargo.toml
+++ b/rust/compiler-ext/Cargo.toml
@@ -6,8 +6,11 @@ edition = "2018"
 # TODO(@jroesch): would be cool to figure out how to statically link instead.
 
 [lib]
-crate-type = ["cdylib"]
+crate-type = ["staticlib", "cdylib"]
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+tvm = { path = "../tvm", default-features = false, features = ["static-linking"] }
+log = "*"
+env_logger = "*"
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 58bdd0c..3e37d21 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -17,321 +17,19 @@
  * under the License.
  */
 
- use std::os::raw::c_int;
- use tvm::initialize;
- use tvm::ir::{tir, PrimExpr};
- use tvm::runtime::function::register_override;
- use tvm::runtime::map::Map;
- use tvm::runtime::object::{IsObject, IsObjectRef};
- 
- use ordered_float::NotNan;
- 
- mod interval;
- mod math;
- 
- use math::{BoundsMap, Expr, RecExpr};
- use tvm::ir::arith::ConstIntBound;
- use tvm_rt::{ObjectRef, array::Array};
- 
- macro_rules! downcast_match {
-     ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
-         $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
-         { $default }
-     }
- }
- 
- #[derive(Default)]
- struct VarMap {
-     vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>,
-     objs: Vec<ObjectRef>,
- }
- 
- impl VarMap {
-     // FIXME this should eventually do the right thing for TVM variables
-     // right now it depends on them having unique names
-     fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
-         let sym = egg::Symbol::from(var.name_hint.as_str().unwrap());
-         for (_, sym2) in &self.vars {
-             if sym == *sym2 {
-                 return sym;
-             }
-         }
- 
-         self.vars.push((var, sym));
-         sym
-     }
- 
-     fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
-         for (v, sym2) in &self.vars {
-             if sym == *sym2 {
-                 return v.clone();
-             }
-         }
-         panic!("Should have found a var")
-     }
- 
-     fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
-         let i = self.objs.len();
-         self.objs.push(obj.upcast());
-         i
-     }
- 
-     fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
-         self.objs[i].clone().downcast().expect("bad downcast")
-     }
- }
- 
- fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
-     fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
-         macro_rules! r {
-             ($e:expr) => {
-                 build(vars, &$e, recexpr)
-             };
-         }
- 
-         let dt = recexpr.add(Expr::DataType(p.datatype));
-         let e = downcast_match!(p; {
-             tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
-             tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
-             tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
- 
-             tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
-             tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
-             tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
-             tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
- 
-             tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
-             tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
- 
-             tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
-             tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
- 
-             tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
-             tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
-             tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
-             tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
-             tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
-             tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
- 
-             tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
-             tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
-             tir::Not => Expr::Not([dt, r!(p.value)]),
- 
-             tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
- 
-             tir::Let => {
-                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
-                 Expr::Let([dt, sym, r!(p.value), r!(p.body)])
-             }
-             tir::Var => {
-                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
-                 Expr::Var([dt, sym])
-             }
-             tir::IntImm => {
-                 let int = recexpr.add(Expr::Int(p.value));
-                 Expr::IntImm([dt, int])
-             }
-             tir::FloatImm => {
-                 let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap()));
-                 Expr::FloatImm([dt, float])
-             }
-             tir::Cast => Expr::Cast([dt, r!(p.value)]),
- 
-             tir::Call => {
-                 let op = vars.push_obj(p.op.clone());
-                 let mut arg_ids = vec![dt];
-                 for i in 0..p.args.len() {
-                     let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
-                     arg_ids.push(r!(arg));
-                 }
-                 Expr::Call(op, arg_ids)
-             },
-             tir::Load => {
-                 let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
-                 Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
-             },
-             else => {
-                 println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap());
-                 Expr::Object(vars.push_obj(p.clone()))
-             }
-         });
- 
-         recexpr.add(e)
-     }
- 
-     let mut recexpr = Default::default();
-     build(vars, prim, &mut recexpr);
-     recexpr
- }
- 
- fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
-     fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
-         let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]);
-         let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
-         let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
-             Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] {
-                 Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
-                 n => panic!("Expected a symbol, got {:?}", n),
-             },
-             Expr::IntImm([dt, v]) => {
-                 let value = nodes[usize::from(*v)].to_int().unwrap();
-                 tir::IntImm::new(get_dt(dt), value).upcast()
-             }
-             Expr::FloatImm([dt, v]) => {
-                 let value = nodes[usize::from(*v)].to_float().unwrap();
-                 tir::FloatImm::new(get_dt(dt), value).upcast()
-             }
-             Expr::Let([dt, s, value, body]) => {
-                 let var = match &nodes[usize::from(*s)] {
-                     Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
-                     n => panic!("Expected a symbol, got {:?}", n),
-                 };
-                 tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast()
-             }
-             Expr::Load([dt, s, value, body]) => {
-                 let var = match &nodes[usize::from(*s)] {
-                     Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
-                     n => panic!("Expected a symbol, got {:?}", n),
-                 };
-                 tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast()
-             }
- 
-             Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
- 
-             Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
- 
-             Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
- 
-             Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
- 
-             Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
-             Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
- 
-             Expr::Ramp([dt, a, b, c]) => {
-                 let len = &nodes[usize::from(*c)];
-                 let i = len
-                     .to_int()
-                     .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len));
-                 tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast()
-             }
-             Expr::Broadcast([dt, val, lanes]) => {
-                 let lanes = &nodes[usize::from(*lanes)];
-                 let lanes = lanes
-                     .to_int()
-                     .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes));
-                 println!("dt: {}", get_dt(dt));
-                 tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast()
-             }
- 
-             Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
-             Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
-             Expr::Call(expr, args) => {
-                 let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
-                 let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
-                 tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast()
-             }
- 
-             Expr::Object(i) => vars.get_obj(*i),
-             node => panic!("I don't know how to extract {:?}", node),
-         };
-         assert_ne!(prim.datatype.bits(), 0);
-         assert_ne!(prim.datatype.lanes(), 0);
-         prim
-     }
-     build(vars, recexpr.as_ref())
- }
- 
- fn run(
-     input: PrimExpr,
-     expected: Option<PrimExpr>,
-     map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, String> {
-     use egg::{CostFunction, Extractor};
- 
-     let mut bounds = BoundsMap::default();
-     for (k, v) in map {
-         if let Ok(var) = k.downcast_clone::<tir::Var>() {
-             let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
-             bounds.insert(sym, (v.min_value, v.max_value));
-         } else {
-             println!("Non var in bounds map: {}", tvm::ir::as_text(k));
-         }
-     }
- 
-     let mut vars = VarMap::default();
-     let expr = to_egg(&mut vars, &input);
-     let mut runner = math::default_runner();
-     runner.egraph.analysis.bounds = bounds;
- 
-     let mut runner = runner.with_expr(&expr).run(&math::rules());
-     // runner.print_report();
-     let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
-     let root = runner.egraph.find(runner.roots[0]);
-     let (cost, best) = extractor.find_best(root);
-     if let Some(expected) = expected {
-         let mut expected_vars = VarMap::default();
-         let expected_expr = to_egg(&mut expected_vars, &expected);
-         let expected_root = runner.egraph.add_expr(&expected_expr);
-         if expected_root != root {
-             return Err(format!(
-                 "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
-                 expected_expr.pretty(40),
-                 best.pretty(40)
-             ));
-         }
-         let expected_cost = math::CostFn.cost_rec(&expected_expr);
-         if expected_cost != cost {
-             let msg = format!(
-                 "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
-                 expected_cost,
-                 expected_expr.pretty(40),
-                 cost,
-                 best.pretty(40)
-             );
-             if cost < expected_cost {
-                 println!("egg wins: {}", msg)
-             } else {
-                 return Err(msg);
-             }
-         }
-     }
-     log::info!("  returning... {}", best.pretty(60));
-     Ok(from_egg(&vars, &best))
- }
- 
- fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
-     log::debug!("map: {:?}", map);
-     run(prim, None, map).map_err(tvm::Error::CallFailed)
- }
- 
- fn simplify_and_check(
-     prim: PrimExpr,
-     check: PrimExpr,
-     map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, tvm::Error> {
-     log::debug!("check map: {:?}", map);
-     run(prim, Some(check), map).map_err(tvm::Error::CallFailed)
- }
- 
- initialize!({
-     let _ = env_logger::try_init();
-     // NOTE this print prevents a segfault (on Linux) for now...
-     println!("Initializing simplifier... ");
-     register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
-     register_override(simplify_and_check, "egg.simplify_and_check", true)
-         .expect("failed to initialize simplifier");
-     log::debug!("done!");
- });
- 
\ No newline at end of file
+use env_logger;
+use tvm;
+use tvm::runtime::function::register_override;
+
+fn test_fn() -> Result<(), tvm::Error> {
+    println!("Hello from Rust!");
+    Ok(())
+}
+
+#[no_mangle]
+fn compiler_ext_initialize() -> i32 {
+    let _ = env_logger::try_init();
+    register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier");
+    log::debug!("done!");
+    return 0;
+}
diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml
index acece5a..9660943 100644
--- a/rust/tvm-rt/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -28,19 +28,26 @@ categories = ["api-bindings", "science"]
 authors = ["TVM Contributors"]
 edition = "2018"
 
+[features]
+default = ["dynamic-linking"]
+dynamic-linking = ["tvm-sys/bindings"]
+static-linking = []
+blas = ["ndarray/blas"]
+
 [dependencies]
 thiserror = "^1.0"
 ndarray = "0.12"
 num-traits = "0.2"
-tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
 tvm-macros = { version = "0.1", path = "../tvm-macros" }
 paste = "0.1"
 mashup = "0.1"
 once_cell = "^1.3.1"
 memoffset = "0.5.6"
 
+[dependencies.tvm-sys]
+version = "0.1"
+default-features = false
+path = "../tvm-sys/"
+
 [dev-dependencies]
 anyhow = "^1.0"
-
-[features]
-blas = ["ndarray/blas"]
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
index 4e3fc98..c25a5bf 100644
--- a/rust/tvm-sys/Cargo.toml
+++ b/rust/tvm-sys/Cargo.toml
@@ -23,6 +23,7 @@ license = "Apache-2.0"
 edition = "2018"
 
 [features]
+default = ["bindings"]
 bindings = []
 
 [dependencies]
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 05806c0..2d86c4b 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -60,6 +60,7 @@ fn main() -> Result<()> {
     if cfg!(feature = "bindings") {
         println!("cargo:rerun-if-env-changed=TVM_HOME");
         println!("cargo:rustc-link-lib=dylib=tvm");
+        println!("cargo:rustc-link-lib=dylib=llvm-10");
         println!("cargo:rustc-link-search={}/build", tvm_home);
     }
 
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 71a4b93..153a195 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -28,14 +28,24 @@ categories = ["api-bindings", "science"]
 authors = ["TVM Contributors"]
 edition = "2018"
 
+[features]
+default = ["python", "dynamic-linking"]
+dynamic-linking = ["tvm-rt/dynamic-linking"]
+static-linking = ["tvm-rt/static-linking"]
+blas = ["ndarray/blas"]
+python = ["pyo3"]
+
+[dependencies.tvm-rt]
+version = "0.1"
+default-features = false
+path = "../tvm-rt/"
+
 [dependencies]
 thiserror = "^1.0"
 anyhow = "^1.0"
 lazy_static = "1.1"
 ndarray = "0.12"
 num-traits = "0.2"
-tvm-rt = { version = "0.1", path = "../tvm-rt/" }
-tvm-sys = { version = "0.1", path = "../tvm-sys/" }
 tvm-macros = { version = "*", path = "../tvm-macros/" }
 paste = "0.1"
 mashup = "0.1"
@@ -44,8 +54,6 @@ pyo3 = { version = "0.11.1", optional = true }
 codespan-reporting = "0.9.5"
 structopt = { version = "0.3" }
 
-[features]
-default = ["python"]
-
-blas = ["ndarray/blas"]
-python = ["pyo3"]
+[[bin]]
+name = "tyck"
+required-features = ["dynamic-linking"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index 9300412..b869012 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -6,7 +6,6 @@ use structopt::StructOpt;
 use tvm::ir::diagnostics::codespan;
 use tvm::ir::IRModule;
 
-
 #[derive(Debug, StructOpt)]
 #[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
 struct Opt {
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index d306185..b76e43f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -17,17 +17,20 @@
  * under the License.
  */
 
+use super::module::IRModule;
+use super::span::Span;
+use crate::runtime::function::Result;
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
+use crate::runtime::{
+    array::Array,
+    function::{self, Function, ToFunction},
+    string::String as TString,
+};
 /// The diagnostic interface to TVM, used for reporting and rendering
 /// diagnostic information by the compiler. This module exposes
 /// three key abstractions: a Diagnostic, the DiagnosticContext,
 /// and the DiagnosticRenderer.
-
-use tvm_macros::{Object, external};
-use super::module::IRModule;
-use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
-use crate::runtime::function::Result;
-use super::span::Span;
+use tvm_macros::{external, Object};
 
 type SourceName = ObjectRef;
 
@@ -134,7 +137,6 @@ pub struct DiagnosticRendererNode {
     // memory layout
 }
 
-
 //     def render(self, ctx):
 //         """
 //         Render the provided context.
@@ -169,7 +171,8 @@ pub struct DiagnosticContextNode {
 /// and contains a renderer.
 impl DiagnosticContext {
     pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
-    where F: Fn(DiagnosticContext) -> () + 'static
+    where
+        F: Fn(DiagnosticContext) -> () + 'static,
     {
         let renderer = diagnostic_renderer(render_func.to_function()).unwrap();
         let node = DiagnosticContextNode {
@@ -210,21 +213,16 @@ impl DiagnosticContext {
 //     If the render_func is None it will remove the current custom renderer
 //     and return to default behavior.
 fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> () + 'static
+where
+    F: Fn(DiagnosticContext) -> () + 'static,
 {
-
     match opt_func {
         None => clear_renderer(),
         Some(func) => {
             let func = func.to_function();
-            let render_factory = move || {
-                diagnostic_renderer(func.clone()).unwrap()
-            };
+            let render_factory = move || diagnostic_renderer(func.clone()).unwrap();
 
-            function::register_override(
-                render_factory,
-                "diagnostics.OverrideRenderer",
-                true)?;
+            function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?;
 
             Ok(())
         }
@@ -243,9 +241,9 @@ pub mod codespan {
         End,
     }
 
-    struct SpanToBytes {
-        inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
-    }
+    // struct SpanToBytes {
+    //     inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
+    // }
 
     struct ByteRange<FileId> {
         file_id: FileId,
@@ -276,7 +274,7 @@ pub mod codespan {
             .with_message(message)
             .with_code("EXXX")
             .with_labels(vec![
-                Label::primary(file_id, 328..331).with_message(inner_message),
+                Label::primary(file_id, 328..331).with_message(inner_message)
             ]);
 
         diagnostic
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 8450bd7..401b6c2 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -19,8 +19,8 @@
 
 pub mod arith;
 pub mod attrs;
-pub mod expr;
 pub mod diagnostics;
+pub mod expr;
 pub mod function;
 pub mod module;
 pub mod op;
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index e539221..4b09128 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -28,6 +28,7 @@ use super::attrs::Attrs;
 use super::expr::BaseExprNode;
 use super::function::BaseFuncNode;
 use super::ty::{Type, TypeNode};
+use super::span::Span;
 
 use tvm_macros::Object;
 use tvm_rt::NDArray;
@@ -51,7 +52,7 @@ impl ExprNode {
             span: ObjectRef::null(),
             checked_type: Type::from(TypeNode {
                 base: Object::base_object::<TypeNode>(),
-                span: ObjectRef::null(),
+                span: Span::empty(),
             }),
         }
     }
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index e69de29..e6c0371 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::runtime::map::Map;
+use crate::runtime::object::Object;
+
+/// A program source in any language.
+///
+/// Could represent the source from an ML framework or a source of an IRModule.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Source"]
+#[ref_key = "Source"]
+struct SourceNode {
+    pub base: Object,
+    /*! \brief The source name. */
+   SourceName source_name;
+
+   /*! \brief The raw source. */
+   String source;
+
+   /*! \brief A mapping of line breaks into the raw source. */
+   std::vector<std::pair<int, int>> line_map;
+}
+
+
+//  class Source : public ObjectRef {
+//   public:
+//    TVM_DLL Source(SourceName src_name, std::string source);
+//    TVM_DLL tvm::String GetLine(int line);
+
+//    TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
+//  };
+
+
+/// A mapping from a unique source name to source fragments.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceMap"]
+#[ref_key = "SourceMap"]
+struct SourceMapNode {
+    pub base: Object,
+   /// The source mapping.
+   pub source_map: Map<SourceName, Source>,
+}
diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs
index d2e19a2..c54fd51 100644
--- a/rust/tvm/src/ir/span.rs
+++ b/rust/tvm/src/ir/span.rs
@@ -1,22 +1,75 @@
 /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-use crate::runtime::ObjectRef;
-
-pub type Span = ObjectRef;
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+use crate::runtime::{ObjectRef, Object, String as TString};
+use tvm_macros::Object;
+
+/// A source file name, contained in a Span.
+
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceName"]
+#[ref_name = "SourceName"]
+pub struct SourceNameNode {
+    pub base: Object,
+    pub name: TString,
+}
+
+//  /*!
+//   * \brief The source name of a file span.
+//   * \sa SourceNameNode, Span
+//   */
+//  class SourceName : public ObjectRef {
+//   public:
+//    /*!
+//     * \brief Get an SourceName for a given operator name.
+//     *  Will raise an error if the source name has not been registered.
+//     * \param name Name of the operator.
+//     * \return SourceName valid throughout program lifetime.
+//     */
+//    TVM_DLL static SourceName Get(const String& name);
+
+//    TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
+//  };
+
+/// Span information for diagnostic purposes.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Span"]
+#[ref_name = "Span"]
+pub struct SpanNode {
+    pub base: Object,
+    /// The source name.
+    pub source_name: SourceName,
+    /// The line number.
+    pub line: i32,
+    /// The column offset.
+    pub column: i32,
+    /// The end line number.
+    pub end_line: i32,
+    /// The end column number.
+    pub end_column: i32,
+}
+
+impl Span {
+    pub fn empty() -> Span {
+        todo!()
+    }
+}
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 67e5cea..5110eef 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -192,4 +192,15 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
   return ss.str();
 });
 
+
+
 }  // namespace tvm
+
+#ifdef RUST_COMPILER_EXT
+
+extern "C" {
+  int compiler_ext_initialize();
+  static int test = compiler_ext_initialize();
+}
+
+#endif
diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc
index 7ac978c..7340f69 100644
--- a/src/parser/source_map.cc
+++ b/src/parser/source_map.cc
@@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) {
   return line_text;
 }
 
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-//     .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
-//       auto* node = static_cast<const SourceNameNode*>(ref.get());
-//       p->stream << "SourceName(" << node->name << ", " << node << ")";
-//     });
-
 TVM_REGISTER_NODE_TYPE(SourceMapNode);
 
 SourceMap::SourceMap(Map<SourceName, Source> source_map) {
@@ -91,11 +85,6 @@ SourceMap::SourceMap(Map<SourceName, Source> source_map) {
   data_ = std::move(n);
 }
 
-// TODO(@jroesch): fix this
-static SourceMap global_source_map = SourceMap(Map<SourceName, Source>());
-
-SourceMap SourceMap::Global() { return global_source_map; }
-
 void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); }
 
 TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) {


[incubator-tvm] 03/23: WIP

Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit cb3785680cb343eb295ea8e0585c87ca42db4323
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 11:04:37 2020 -0700

    WIP
---
 rust/tvm/src/ir/diagnostics.rs | 78 ++++++++++++++++++++----------------------
 1 file changed, 37 insertions(+), 41 deletions(-)

diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index e434d3f..d306185 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -24,7 +24,7 @@
 
 use tvm_macros::{Object, external};
 use super::module::IRModule;
-use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString};
+use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString};
 use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
 use crate::runtime::function::Result;
 use super::span::Span;
@@ -121,42 +121,19 @@ pub struct DiagnosticBuilder {
 //    * of compiler diagnostics to std::out and std::err in
 //    * a human readable form.
 //    */
-//   class DiagnosticRendererNode : public Object {
-//    public:
-//     TypedPackedFunc<void(DiagnosticContext ctx)> renderer;
-
-//     // override attr visitor
-//     void VisitAttrs(AttrVisitor* v) {}
-
-//     static constexpr const char* _type_key = "DiagnosticRenderer";
-//     TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
-//   };
-
-//   class DiagnosticRenderer : public ObjectRef {
-//    public:
-//     TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
-//     TVM_DLL DiagnosticRenderer()
-//         : DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}
-
-//     void Render(const DiagnosticContext& ctx);
-
-//     DiagnosticRendererNode* operator->() {
-//       CHECK(get() != nullptr);
-//       return static_cast<DiagnosticRendererNode*>(get_mutable());
-//     }
-
-//     TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
-//   };
-
-// @tvm._ffi.register_object("DiagnosticRenderer")
-// class DiagnosticRenderer(Object):
-//     """
-//     A diagnostic renderer, which given a diagnostic context produces a "rendered"
-//     form of the diagnostics for either human or computer consumption.
-//     """
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "DiagnosticRenderer"]
+#[type_key = "DiagnosticRenderer"]
+/// A diagnostic renderer, which given a diagnostic context produces a "rendered"
+/// form of the diagnostics for either human or computer consumption.
+pub struct DiagnosticRendererNode {
+    /// The base type.
+    pub base: Object,
+    // TODO(@jroesch): we can't easily exposed packed functions due to
+    // memory layout
+}
 
-//     def __init__(self, render_func):
-//         self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func)
 
 //     def render(self, ctx):
 //         """
@@ -168,7 +145,6 @@ pub struct DiagnosticBuilder {
 //             The diagnostic context to render.
 //         """
 //         return _ffi_api.DiagnosticRendererRender(self, ctx
-pub type DiagnosticRenderer = ObjectRef;
 
 #[repr(C)]
 #[derive(Object)]
@@ -227,8 +203,7 @@ impl DiagnosticContext {
     }
 }
 
-// Sets a custom renderer for diagnostics.
-
+// Override the global diagnostics renderer.
 // Params
 // ------
 // render_func: Option[Callable[[DiagnosticContext], None]]
@@ -263,6 +238,27 @@ pub mod codespan {
     use codespan_reporting::files::SimpleFiles;
     use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
 
+    enum StartOrEnd {
+        Start,
+        End,
+    }
+
+    struct SpanToBytes {
+        inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
+    }
+
+    struct ByteRange<FileId> {
+        file_id: FileId,
+        start_pos: usize,
+        end_pos: usize,
+    }
+
+    // impl SpanToBytes {
+    //     fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange<FileId> {
+
+    //     }
+    // }
+
     pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
         let severity = match diag.level {
             DiagnosticLevel::Error => Severity::Error,
@@ -290,9 +286,9 @@ pub mod codespan {
         let mut files: SimpleFiles<String, String> = SimpleFiles::new();
         let render_fn = move |diag_ctx: DiagnosticContext| {
             // let source_map = diag_ctx.module.source_map;
-            for diagnostic in diag_ctx.diagnostics {
+            // for diagnostic in diag_ctx.diagnostics {
 
-            }
+            // }
             panic!("render_fn");
         };