You are viewing a plain-text version of this content. The canonical link for it was provided as a hyperlink, which is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by jr...@apache.org on 2020/11/05 21:40:56 UTC
[incubator-tvm] branch cargo-build created (now c933926)
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a change to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git.
at c933926 Post-rebase
This branch includes the following new commits:
new 1097cbf Add initial boilerplate for Rust diagnostic interface.
new 77ba309 Codespan example almost working
new cb37856 WIP
new 131e40a Hacking on Rust inside of TVM
new b2b59c2 Borrow code from Egg
new db24553 Update CMake and delete old API
new e0f9801 Fix Linux build
new 20c6a28 Clean up exporting to show off new diagnostics
new 4cd1bbc Improve Rust bindings
new eeb86c6 Fix calling
new 4261461 Fix
new 6e13467 Rust Diagnostics work
new 0cabfdc Remove type checker
new 04a9779 Format and cleanup
new 6828374 Fix the extension code
new 1874350 More cleanup
new 8e295b7 Fix some CR
new 49246bf WIP
new a9ee3cb WIP
new b8dcc35 WIP
new 478326c Debug segfault
new e8bb83d WIP
new c933926 Post-rebase
The 23 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
[incubator-tvm] 11/23: Fix
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 4261461974f2fb548d6b48e19fab07d34dd03a51
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Fri Oct 16 02:11:53 2020 -0700
Fix
---
rust/tvm/src/ir/module.rs | 25 +++++++++++++++++--------
1 file changed, 17 insertions(+), 8 deletions(-)
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 5156e74..11d6c49 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -16,6 +16,11 @@
* specific language governing permissions and limitations
* under the License.
*/
+use std::io::Result as IOResult;
+use std::path::Path;
+
+use thiserror::Error;
+use tvm_macros::Object;
use crate::runtime::array::Array;
use crate::runtime::function::Result;
@@ -27,15 +32,19 @@ use super::expr::GlobalVar;
use super::function::BaseFunc;
use super::source_map::SourceMap;
-use std::io::Result as IOResult;
-use std::path::Path;
-
-use tvm_macros::Object;
// TODO(@jroesch): define type
type TypeData = ObjectRef;
type GlobalTypeVar = ObjectRef;
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error("{0}")]
+ IO(#[from] std::io::Error),
+ #[error("{0}")]
+ TVM(#[from] crate::runtime::Error),
+}
+
#[repr(C)]
#[derive(Object)]
#[ref_name = "IRModule"]
@@ -116,19 +125,19 @@ external! {
// });
impl IRModule {
- pub fn parse<N, S>(file_name: N, source: S) -> IRModule
+ pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
where
N: Into<TVMString>,
S: Into<TVMString>,
{
- parse_module(file_name.into(), source.into()).expect("failed to call parser")
+ parse_module(file_name.into(), source.into())
}
- pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> IOResult<IRModule> {
+ pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> std::result::Result<IRModule, Error> {
let file_path = file_path.as_ref();
let file_path_as_str = file_path.to_str().unwrap().to_string();
let source = std::fs::read_to_string(file_path)?;
- let module = IRModule::parse(file_path_as_str, source);
+ let module = IRModule::parse(file_path_as_str, source)?;
Ok(module)
}
[incubator-tvm] 23/23: Post-rebase
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit c9339267bf7ae641a7f9c3678c84cf1ff600959e
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sat Oct 31 15:39:05 2020 -0700
Post-rebase
---
rust/tvm/src/ir/module.rs | 3 ---
1 file changed, 3 deletions(-)
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index db32ce2..869c5e6 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -34,11 +34,8 @@ use super::function::BaseFunc;
use super::source_map::SourceMap;
use super::{ty::GlobalTypeVar, relay};
-use tvm_macros::Object;
-
// TODO(@jroesch): define type
type TypeData = ObjectRef;
-type GlobalTypeVar = ObjectRef;
#[derive(Error, Debug)]
pub enum Error {
[incubator-tvm] 15/23: Fix the extension code
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 68283745b9c283db611579205a3b925eb09e9faa
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:32:13 2020 -0700
Fix the extension code
---
src/contrib/rust_extension.cc | 31 +++++++++++++++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/src/contrib/rust_extension.cc b/src/contrib/rust_extension.cc
new file mode 100644
index 0000000..075cbc6
--- /dev/null
+++ b/src/contrib/rust_extension.cc
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/contrib/rust_extension.cc
+ * \brief Expose Rust extensions initialization.
+ */
+#ifdef RUST_COMPILER_EXT
+
+extern "C" {
+ int compiler_ext_initialize();
+ static int test = compiler_ext_initialize();
+}
+
+#endif
[incubator-tvm] 13/23: Remove type checker
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 0cabfdcee309b12d8907fc3abe2ba8e8718ecac6
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 15:56:10 2020 -0700
Remove type checker
---
tests/python/relay/test_type_infer2.py | 419 ---------------------------------
1 file changed, 419 deletions(-)
diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py
deleted file mode 100644
index 6758d96..0000000
--- a/tests/python/relay/test_type_infer2.py
+++ /dev/null
@@ -1,419 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Test that type checker correcly computes types
- for expressions.
-"""
-import pytest
-import tvm
-
-from tvm import IRModule, te, relay, parser
-from tvm.relay import op, transform, analysis
-from tvm.relay import Any
-
-
-def infer_mod(mod, annotate_spans=True):
- if annotate_spans:
- mod = relay.transform.AnnotateSpans()(mod)
-
- mod = transform.InferType()(mod)
- return mod
-
-
-def infer_expr(expr, annotate_spans=True):
- mod = IRModule.from_expr(expr)
- mod = infer_mod(mod, annotate_spans)
- mod = transform.InferType()(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def assert_has_type(expr, typ, mod=None):
- if not mod:
- mod = tvm.IRModule({})
-
- mod["main"] = expr
- mod = infer_mod(mod)
- checked_expr = mod["main"]
- checked_type = checked_expr.checked_type
- if checked_type != typ:
- raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
-
-
-def initialize_box_adt(mod):
- # initializes simple ADT for tests
- box = relay.GlobalTypeVar("box")
- tv = relay.TypeVar("tv")
- constructor = relay.Constructor("constructor", [tv], box)
- data = relay.TypeData(box, [tv], [constructor])
- mod[box] = data
- return box, constructor
-
-
-def test_monomorphic_let():
- "Program: let %x = 1; %x"
- # TODO(@jroesch): this seems whack.
- sb = relay.ScopeBuilder()
- x = relay.var("x", dtype="float64", shape=())
- x = sb.let("x", relay.const(1.0, "float64"))
- sb.ret(x)
- xchecked = infer_expr(sb.get())
- assert xchecked.checked_type == relay.scalar_type("float64")
-
-
-def test_single_op():
- "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
- x = relay.var("x", shape=[])
- func = relay.Function([x], op.log(x))
- ttype = relay.TensorType([], dtype="float32")
- assert_has_type(func, relay.FuncType([ttype], ttype))
-
-
-def test_add_broadcast_op():
- """
- Program:
- fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
- -> Tensor[(5, 10, 4), float32] {
- %x + %y
- }
- """
- x = relay.var("x", shape=(10, 4))
- y = relay.var("y", shape=(5, 10, 1))
- z = x + y
- func = relay.Function([x, y], z)
- t1 = relay.TensorType((10, 4), "float32")
- t2 = relay.TensorType((5, 10, 1), "float32")
- t3 = relay.TensorType((5, 10, 4), "float32")
- expected_ty = relay.FuncType([t1, t2], t3)
- assert_has_type(func, expected_ty)
-
-
-def test_dual_op():
- """Program:
- fn (%x : Tensor[(10, 10), float32]) {
- let %t1 = log(x);
- let %t2 = add(%t1, %x);
- %t1
- }
- """
- tp = relay.TensorType((10, 10), "float32")
- x = relay.var("x", tp)
- sb = relay.ScopeBuilder()
- t1 = sb.let("t1", relay.log(x))
- t2 = sb.let("t2", relay.add(t1, x))
- sb.ret(t2)
- f = relay.Function([x], sb.get())
- fchecked = infer_expr(f)
- assert fchecked.checked_type == relay.FuncType([tp], tp)
-
-
-def test_decl():
- """Program:
- def @f(%x : Tensor[(10, 10), float32]) {
- log(%x)
- }
- """
- tp = relay.TensorType((10, 10))
- x = relay.var("x", tp)
- f = relay.Function([x], relay.log(x))
- fchecked = infer_expr(f)
- assert fchecked.checked_type == relay.FuncType([tp], tp)
-
-
-def test_recursion():
- """
- Program:
- def @f(%n: int32, %data: float32) -> float32 {
- if (%n == 0) {
- %data
- } else {
- @f(%n - 1, log(%data))
- }
- }
- """
- sb = relay.ScopeBuilder()
- f = relay.GlobalVar("f")
- ti32 = relay.scalar_type("int32")
- tf32 = relay.scalar_type("float32")
- n = relay.var("n", ti32)
- data = relay.var("data", tf32)
-
- with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
- sb.ret(data)
- with sb.else_scope():
- sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
- mod = tvm.IRModule()
- mod[f] = relay.Function([n, data], sb.get())
- mod = infer_mod(mod)
- assert "@f(%1, %2)" in mod.astext()
- assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32)
-
-
-def test_incomplete_call():
- tt = relay.scalar_type("int32")
- x = relay.var("x", tt)
- f = relay.var("f")
- func = relay.Function([x, f], relay.Call(f, [x]), tt)
-
- ft = infer_expr(func)
- f_type = relay.FuncType([tt], tt)
- assert ft.checked_type == relay.FuncType([tt, f_type], tt)
-
-
-def test_higher_order_argument():
- a = relay.TypeVar("a")
- x = relay.Var("x", a)
- id_func = relay.Function([x], x, a, [a])
-
- b = relay.TypeVar("b")
- f = relay.Var("f", relay.FuncType([b], b))
- y = relay.Var("y", b)
- ho_func = relay.Function([f, y], f(y), b, [b])
-
- # id func should be an acceptable argument to the higher-order
- # function even though id_func takes a type parameter
- ho_call = ho_func(id_func, relay.const(0, "int32"))
-
- hc = infer_expr(ho_call)
- expected = relay.scalar_type("int32")
- assert hc.checked_type == expected
-
-
-def test_higher_order_return():
- a = relay.TypeVar("a")
- x = relay.Var("x", a)
- id_func = relay.Function([x], x, a, [a])
-
- b = relay.TypeVar("b")
- nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
-
- ft = infer_expr(nested_id)
- assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
-
-
-def test_higher_order_nested():
- a = relay.TypeVar("a")
- x = relay.Var("x", a)
- id_func = relay.Function([x], x, a, [a])
-
- choice_t = relay.FuncType([], relay.scalar_type("bool"))
- f = relay.Var("f", choice_t)
-
- b = relay.TypeVar("b")
- z = relay.Var("z")
- top = relay.Function(
- [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
- )
-
- expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
- ft = infer_expr(top)
- assert ft.checked_type == expected
-
-
-def test_tuple():
- tp = relay.TensorType((10,))
- x = relay.var("x", tp)
- res = relay.Tuple([x, x])
- assert infer_expr(res).checked_type == relay.TupleType([tp, tp])
-
-
-def test_ref():
- x = relay.var("x", "float32")
- y = relay.var("y", "float32")
- r = relay.RefCreate(x)
- st = relay.scalar_type("float32")
- assert infer_expr(r).checked_type == relay.RefType(st)
- g = relay.RefRead(r)
- assert infer_expr(g).checked_type == st
- w = relay.RefWrite(r, y)
- assert infer_expr(w).checked_type == relay.TupleType([])
-
-
-def test_free_expr():
- x = relay.var("x", "float32")
- y = relay.add(x, x)
- yy = infer_expr(y, annotate_spans=False)
- assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
- assert yy.checked_type == relay.scalar_type("float32")
- assert x.vid.same_as(yy.args[0].vid)
-
-
-def test_type_args():
- x = relay.var("x", shape=(10, 10))
- y = relay.var("y", shape=(1, 10))
- z = relay.add(x, y)
- ty_z = infer_expr(z)
- ty_args = ty_z.type_args
- assert len(ty_args) == 2
- assert ty_args[0].dtype == "float32"
- assert ty_args[1].dtype == "float32"
- sh1 = ty_args[0].shape
- sh2 = ty_args[1].shape
- assert sh1[0].value == 10
- assert sh1[1].value == 10
- assert sh2[0].value == 1
- assert sh2[1].value == 10
-
-
-def test_global_var_recursion():
- mod = tvm.IRModule({})
- gv = relay.GlobalVar("main")
- x = relay.var("x", shape=[])
- tt = relay.scalar_type("float32")
-
- func = relay.Function([x], relay.Call(gv, [x]), tt)
- mod[gv] = func
- mod = infer_mod(mod)
- func_ty = mod["main"].checked_type
-
- assert func_ty == relay.FuncType([tt], tt)
-
-
-def test_equal():
- i = relay.var("i", shape=[], dtype="int32")
- eq = op.equal(i, relay.const(0, dtype="int32"))
- func = relay.Function([i], eq)
- ft = infer_expr(func)
- expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool"))
- assert ft.checked_type == expected
-
- assert ft.checked_type == relay.FuncType(
- [relay.scalar_type("int32")], relay.scalar_type("bool")
- )
-
-
-def test_constructor_type():
- mod = tvm.IRModule()
- box, constructor = initialize_box_adt(mod)
-
- a = relay.TypeVar("a")
- x = relay.Var("x", a)
- func = relay.Function([x], constructor(x), box(a), [a])
- mod["main"] = func
- mod = infer_mod(mod)
- func_ty = mod["main"].checked_type
- box = mod.get_global_type_var("box")
- expected = relay.FuncType([a], box(a), [a])
- assert func_ty == expected
-
-
-def test_constructor_call():
- mod = tvm.IRModule()
- box, constructor = initialize_box_adt(mod)
-
- box_unit = constructor(relay.Tuple([]))
- box_constant = constructor(relay.const(0, "float32"))
-
- func = relay.Function([], relay.Tuple([box_unit, box_constant]))
- mod["main"] = func
- mod = infer_mod(mod)
- ret_type = mod["main"].checked_type.ret_type.fields
- # NB(@jroesch): when we annotate spans the ast fragments before
- # annotation the previous fragments will no longer be directly equal.
- box = mod.get_global_type_var("box")
- expected1 = box(relay.TupleType([]))
- expected2 = box(relay.TensorType((), "float32"))
- assert ret_type[0] == expected1
- assert ret_type[1] == expected2
-
-
-def test_adt_match():
- mod = tvm.IRModule()
- box, constructor = initialize_box_adt(mod)
-
- v = relay.Var("v", relay.TensorType((), "float32"))
- match = relay.Match(
- constructor(relay.const(0, "float32")),
- [
- relay.Clause(
- relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
- ),
- # redundant but shouldn't matter to typechecking
- relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
- ],
- )
-
- func = relay.Function([], match)
- mod["main"] = func
- mod = infer_mod(mod)
- actual = mod["main"].checked_type.ret_type
- assert actual == relay.TupleType([])
-
-
-def test_adt_match_type_annotations():
- mod = tvm.IRModule()
- box, constructor = initialize_box_adt(mod)
-
- # the only type annotation is inside the match pattern var
- # but that should be enough info
- tt = relay.TensorType((2, 2), "float32")
- x = relay.Var("x")
- mv = relay.Var("mv", tt)
- match = relay.Match(
- constructor(x),
- [
- relay.Clause(
- relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
- )
- ],
- )
-
- mod["main"] = relay.Function([x], match)
- mod = infer_mod(mod)
- ft = mod["main"].checked_type
- assert ft == relay.FuncType([tt], relay.TupleType([]))
-
-
-def test_let_polymorphism():
- id = relay.Var("id")
- xt = relay.TypeVar("xt")
- x = relay.Var("x", xt)
- body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
- body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
- body = infer_expr(body)
- int32 = relay.TensorType((), "int32")
- tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
-
-
-def test_if():
- choice_t = relay.FuncType([], relay.scalar_type("bool"))
- f = relay.Var("f", choice_t)
- true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32"))
- false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32"))
- top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
- ft = infer_expr(top)
- tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
-
-
-def test_type_arg_infer():
- code = """
-#[version = "0.0.5"]
-def @id[A](%x: A) -> A {
- %x
-}
-def @main(%f: float32) -> float32 {
- @id(%f)
-}
-"""
- mod = tvm.parser.fromtext(code)
- mod = transform.InferType()(mod)
- tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
-
-
-if __name__ == "__main__":
- import sys
-
- pytest.main(sys.argv)
[incubator-tvm] 21/23: Debug segfault
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 478326c8d7c2ccce10b77aa28e4b22bafbfdf877
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Sun Oct 25 17:26:47 2020 -0700
Debug segfault
---
python/tvm/__init__.py | 2 ++
python/tvm/relay/__init__.py | 3 +-
python/tvm/relay/analysis/__init__.py | 2 +-
python/tvm/relay/analysis/analysis.py | 6 ++--
python/tvm/relay/analysis/annotated_regions.py | 2 +-
python/tvm/relay/analysis/call_graph.py | 4 +--
python/tvm/relay/analysis/sparse_dense.py | 15 ++++-----
python/tvm/relay/build_module.py | 6 +++-
python/tvm/relay/op/op.py | 40 ++++++++++++------------
python/tvm/relay/transform/__init__.py | 2 +-
python/tvm/relay/transform/memory_alloc.py | 7 ++---
python/tvm/relay/transform/transform.py | 5 +--
python/tvm/topi/cuda/__init__.py | 2 --
python/tvm/topi/cuda/sparse.py | 3 +-
rust/tvm/Cargo.toml | 2 +-
rust/tvm/src/ir/relay/mod.rs | 1 -
rust/tvm/src/python.rs | 43 +++++++++++++++++++++++---
17 files changed, 91 insertions(+), 54 deletions(-)
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 569e8f0..60f81f4 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -67,6 +67,8 @@ from . import support
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
+def cleanup():
+ _ffi.base._LIB = None
def tvm_wrap_excepthook(exception_hook):
"""Wrap given excepthook with TVM additional work."""
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index cd96ecc..7e6ed4f 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -60,8 +60,7 @@ from . import qnn
from .scope_builder import ScopeBuilder
# Load Memory Passes
-from .transform import memory_alloc
-from .transform import memory_plan
+from .transform import memory_alloc, memory_plan
# Required to traverse large programs
setrecursionlimit(10000)
diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index b4ea7f3..4ea4de7 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -26,7 +26,7 @@ from .annotated_regions import AnnotatedRegionSet
from . import call_graph
from .call_graph import CallGraph
-# Feature
+# # Feature
from . import feature
from . import sparse_dense
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 7e49461..48e9ce0 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -20,9 +20,9 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
-from tvm.ir import IRModule
-from tvm.relay import transform, build_module
-from tvm.runtime.ndarray import cpu
+from ...ir import IRModule
+from ...relay import transform, build_module
+from ...runtime.ndarray import cpu
from . import _ffi_api
from .feature import Feature
diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py
index 437b97b..a18ccb9 100644
--- a/python/tvm/relay/analysis/annotated_regions.py
+++ b/python/tvm/relay/analysis/annotated_regions.py
@@ -17,7 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Regions used in Relay."""
-from tvm.runtime import Object
+from ...runtime import Object
from . import _ffi_api
diff --git a/python/tvm/relay/analysis/call_graph.py b/python/tvm/relay/analysis/call_graph.py
index 966659a..fd9704d 100644
--- a/python/tvm/relay/analysis/call_graph.py
+++ b/python/tvm/relay/analysis/call_graph.py
@@ -17,8 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""
-from tvm.ir import IRModule
-from tvm.runtime import Object
+from ...ir import IRModule
+from ...runtime import Object
from ..expr import GlobalVar
from . import _ffi_api
diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py
index d521748..51fab34 100644
--- a/python/tvm/relay/analysis/sparse_dense.py
+++ b/python/tvm/relay/analysis/sparse_dense.py
@@ -22,8 +22,8 @@ to block sparse model
"""
from collections import namedtuple
import numpy as np
-import scipy.sparse as sp
-import tvm
+
+from ... import nd, runtime
from . import _ffi_api
@@ -73,6 +73,7 @@ def process_params(expr, params, block_size, sparsity_threshold):
ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
return names of qualified dense weight and the shape in BSR format
"""
+ import scipy.sparse as sp
memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
weight_names = _search_dense_op_weight(expr)
for name in weight_names:
@@ -89,11 +90,11 @@ def process_params(expr, params, block_size, sparsity_threshold):
+ list(sparse_weight.indices.shape)
+ list(sparse_weight.indptr.shape)
)
- params[name + ".data"] = tvm.nd.array(sparse_weight.data)
- params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
- params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+ params[name + ".data"] = nd.array(sparse_weight.data)
+ params[name + ".indices"] = nd.array(sparse_weight.indices)
+ params[name + ".indptr"] = nd.array(sparse_weight.indptr)
ret = SparseAnalysisResult(
- weight_name=tvm.runtime.convert(memo.weight_name),
- weight_shape=tvm.runtime.convert(memo.weight_shape),
+ weight_name=runtime.convert(memo.weight_name),
+ weight_shape=runtime.convert(memo.weight_shape),
)
return ret
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 35bd8e6..e93d654 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -24,7 +24,7 @@ import numpy as np
from tvm.ir import IRModule
from tvm.tir import expr as tvm_expr
-from .. import nd as _nd, autotvm
+from .. import nd as _nd, autotvm, register_func
from ..target import Target
from ..contrib import graph_runtime as _graph_rt
from . import _build_module
@@ -186,6 +186,10 @@ class BuildModule(object):
ret[key] = value.data
return ret
+@register_func("tvm.relay.build")
+def build1(mod, target=None, target_host=None, params=None, mod_name="default"):
+ import pdb; pdb.set_trace()
+ return build(mod, target, target_host, params, mod_name)
def build(mod, target=None, target_host=None, params=None, mod_name="default"):
"""Helper function that builds a Relay function to run on TVM graph
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index 755659a..f780dc9 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -16,10 +16,8 @@
# under the License.
# pylint: disable=unused-argument,invalid-name
"""The base node types for the Relay language."""
-import tvm._ffi
-import tvm.ir
-from tvm.driver import lower, build
-
+from ... import _ffi, ir
+from ...driver import lower, build
from ...target import get_native_generic_func, GenericFunc
from ...runtime import Object
from . import _make
@@ -38,7 +36,7 @@ def get(op_name):
op : Op
The op of the corresponding name
"""
- return tvm.ir.Op.get(op_name)
+ return ir.Op.get(op_name)
class OpPattern(object):
@@ -65,7 +63,7 @@ class OpPattern(object):
OPAQUE = 8
-@tvm._ffi.register_object("relay.OpImplementation")
+@_ffi.register_object("relay.OpImplementation")
class OpImplementation(Object):
"""Operator implementation"""
@@ -112,12 +110,12 @@ class OpImplementation(Object):
return _OpImplementationSchedule(self, attrs, outs, target)
-@tvm._ffi.register_object("relay.OpSpecialization")
+@_ffi.register_object("relay.OpSpecialization")
class OpSpecialization(Object):
"""Operator specialization"""
-@tvm._ffi.register_object("relay.OpStrategy")
+@_ffi.register_object("relay.OpStrategy")
class OpStrategy(Object):
"""Operator strategy"""
@@ -184,7 +182,7 @@ def register_compute(op_name, compute=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FTVMCompute", compute, level)
+ return ir.register_op_attr(op_name, "FTVMCompute", compute, level)
def register_strategy(op_name, fstrategy=None, level=10):
@@ -205,7 +203,7 @@ def register_strategy(op_name, fstrategy=None, level=10):
if not isinstance(fstrategy, GenericFunc):
assert hasattr(fstrategy, "generic_func_node")
fstrategy = fstrategy.generic_func_node
- return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
+ return ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level)
def register_schedule(op_name, schedule, level=10):
@@ -286,7 +284,7 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
+ return ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level)
def register_convert_op_layout(op_name, convert_layout=None, level=10):
@@ -303,7 +301,7 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
+ return ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level)
def register_legalize(op_name, legal_op=None, level=10):
@@ -320,7 +318,7 @@ def register_legalize(op_name, legal_op=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
+ return ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level)
def register_pattern(op_name, pattern, level=10):
@@ -337,7 +335,7 @@ def register_pattern(op_name, pattern, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "TOpPattern", pattern, level)
+ return ir.register_op_attr(op_name, "TOpPattern", pattern, level)
def register_gradient(op_name, fgradient=None, level=10):
@@ -354,7 +352,7 @@ def register_gradient(op_name, fgradient=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
+ return ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
@@ -376,7 +374,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
The priority level
"""
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
- return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
+ return ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)
def register_external_compiler(op_name, fexternal=None, level=10):
@@ -395,15 +393,15 @@ def register_external_compiler(op_name, fexternal=None, level=10):
level : int
The priority level
"""
- return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
+ return ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level)
-@tvm._ffi.register_func("relay.op.compiler._lower")
+_ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
-@tvm._ffi.register_func("relay.op.compiler._build")
+_ffi.register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
@@ -420,7 +418,7 @@ def debug(expr, debug_func=None):
if debug_func:
name = "debugger_func{}".format(__DEBUG_COUNTER__)
- tvm._ffi.register_func(name, debug_func)
+ _ffi.register_func(name, debug_func)
__DEBUG_COUNTER__ += 1
else:
name = ""
@@ -428,4 +426,4 @@ def debug(expr, debug_func=None):
return _make.debug(expr, name)
-tvm._ffi._init_api("relay.op", __name__)
+_ffi._init_api("relay.op", __name__)
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 1d0ea17..9684e42 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -19,4 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
-from . import memory_alloc
+# from . import memory_alloc
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 66528c8..593a411 100644
--- a/python/tvm/relay/transform/memory_alloc.py
+++ b/python/tvm/relay/transform/memory_alloc.py
@@ -20,14 +20,13 @@ A pass for manifesting explicit memory allocations.
"""
import numpy as np
-from tvm.ir.transform import PassContext, module_pass
-from tvm.relay.transform import InferType
-from tvm import nd, container
+from ... import DataType, register_func, nd, container, cpu
+from ...ir.transform import PassContext, module_pass
+from . import InferType
from ..function import Function
from ..expr_functor import ExprVisitor, ExprMutator
from ..scope_builder import ScopeBuilder
from .. import op
-from ... import DataType, register_func
from .. import ty, expr
from ..backend import compile_engine
from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index f0f55f6..af1e718 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -23,11 +23,12 @@ import inspect
import functools
import warnings
+from ...ir import transform as tvm_transform
import tvm.ir
from tvm import te
from tvm.runtime import ndarray as _nd
-from tvm import relay
+# from tvm import relay
from . import _ffi_api
@@ -82,7 +83,7 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None
@tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(tvm.ir.transform.Pass):
+class FunctionPass():
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 3ff544f..47badb5 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -17,8 +17,6 @@
# pylint: disable=redefined-builtin, wildcard-import
"""CUDA specific declaration and schedules."""
-from __future__ import absolute_import as _abs
-
from .conv1d import *
from .conv1d_transpose_ncw import *
from .conv2d import *
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index d125423..c2b99af 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -17,7 +17,6 @@
"""Sparse operators"""
import numpy as np
-import scipy.sparse as sp
import tvm
from tvm import relay, te
@@ -326,6 +325,7 @@ def schedule_sparse_dense_padded(outs):
def pad_sparse_matrix(matrix, blocksize):
"""Pad rows of sparse matrix matrix so that they are a multiple of blocksize."""
+ import scipy.sparse as sp
assert isinstance(matrix, sp.bsr_matrix)
new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype)
bsr = matrix.blocksize[0]
@@ -362,6 +362,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
sparse_dense implementation for one that operates on a padded matrix. We
also padd the matrix.
"""
+ import scipy.sparse as sp
if (
isinstance(inputs[1], relay.Constant)
and isinstance(inputs[2], relay.Constant)
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 153a195..c1d8aa8 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -50,7 +50,7 @@ tvm-macros = { version = "*", path = "../tvm-macros/" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
-pyo3 = { version = "0.11.1", optional = true }
+pyo3 = { version = "^0.12", optional = true }
codespan-reporting = "0.9.5"
structopt = { version = "0.3" }
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 90b7a6a..a6ea684 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -24,7 +24,6 @@ use super::expr::BaseExprNode;
use super::function::BaseFuncNode;
use super::span::Span;
use super::ty::{Type, TypeNode};
-use super::span::Span;
use tvm_macros::Object;
use tvm_rt::NDArray;
diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs
index 89558af..2b2d374 100644
--- a/rust/tvm/src/python.rs
+++ b/rust/tvm/src/python.rs
@@ -18,6 +18,21 @@
*/
use pyo3::prelude::*;
+use once_cell::sync::OnceCell;
+
+// static TVM_PYTHON: OnceCell<Py<PyModule>> = OnceCell::new();
+
+// fn initialize() -> Py<PyModule> {
+// TVM_PYTHON.get_or_init(|| {
+// let gil = Python::acquire_gil();
+// let py = gil.python();
+// PyModule::new(py, "__tvm__rust__module__").map_err(|e| {
+// // We can't display Python exceptions via std::fmt::Display,
+// // so print the error here manually.
+// e.print_and_set_sys_last_vars(py);
+// }).expect("failed to initialize the Python interface").into()
+// }).clone()
+// }
/// Load the Python interpreter into the address space.
///
@@ -29,6 +44,8 @@ use pyo3::prelude::*;
pub fn load() -> Result<String, ()> {
let gil = Python::acquire_gil();
let py = gil.python();
+ // let main_mod = initialize();
+ //let main_mod = main_mod.as_ref(py);
load_python_tvm_(py).map_err(|e| {
// We can't display Python exceptions via std::fmt::Display,
// so print the error here manually.
@@ -36,12 +53,30 @@ pub fn load() -> Result<String, ()> {
})
}
-// const TVMC_CODE: &'static str = include_str!("tvmc.py");
+fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
+ let imported_mod = py.import(to_import)?;
+ PyModule::from_code(py,
+ r#"
+import tvm
+from tvm import relay
+tvm.cleanup()
+"#, "blah", "my_mod")?;
+ // py_mod.add(to_import, imported_mod)?;
+ Ok(imported_mod)
+}
+
+pub fn import(mod_to_import: &str) -> PyResult<()> {
+ let gil = Python::acquire_gil();
+ let py = gil.python();
+ // let main_mod = initialize();
+ // let main_mod = main_mod.as_ref(py);
+ import_python(py, mod_to_import)?;
+ Ok(())
+}
fn load_python_tvm_(py: Python) -> PyResult<String> {
- let sys = py.import("tvm")?;
- let version: String = sys.get("__version__")?.extract()?;
- // py.run(TVMC_CODE, None, None)?;
+ let imported_mod = import_python(py, "tvm")?;
+ let version: String = imported_mod.get("__version__")?.extract()?;
Ok(version)
}
[incubator-tvm] 04/23: Hacking on Rust inside of TVM
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 131e40afeb4bd89f0b378e0e7d8ad440d017d33a
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 15:25:56 2020 -0700
Hacking on Rust inside of TVM
---
CMakeLists.txt | 1 +
cmake/modules/RustExt.cmake | 13 +
rust/Cargo.toml | 1 +
rust/compiler-ext/Cargo.toml | 13 +
rust/compiler-ext/src/lib.rs | 7 +
rust/tvm/src/ir/source_map.rs | 0
rust/tvm/test.rly | 2 +
tests/python/relay/test_type_infer2.py | 419 +++++++++++++++++++++++++++++++++
8 files changed, 456 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index abf2b56..9f82754 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -79,6 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)
+tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions" OFF)
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake
new file mode 100644
index 0000000..45e46bd
--- /dev/null
+++ b/cmake/modules/RustExt.cmake
@@ -0,0 +1,13 @@
+if(USE_RUST_EXT)
+ set(RUST_SRC_DIR "rust")
+ set(CARGO_OUT_DIR "rust/target"
+ set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib")
+
+ add_custom_command(
+ OUTPUT "${COMPILER_EXT_PATH}"
+ COMMAND cargo build --release
+ MAIN_DEPENDENCY "${RUST_SRC_DIR}"
+ WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")
+
+ target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
+endif(USE_RUST_EXT)
diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index 28312a5..7c092d8 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -29,4 +29,5 @@ members = [
"tvm-graph-rt/tests/test_tvm_dso",
"tvm-graph-rt/tests/test_wasm32",
"tvm-graph-rt/tests/test_nn",
+ "compiler-ext",
]
diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml
new file mode 100644
index 0000000..76d10eb
--- /dev/null
+++ b/rust/compiler-ext/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "compiler-ext"
+version = "0.1.0"
+authors = ["Jared Roesch <jr...@octoml.ai>"]
+edition = "2018"
+# TODO(@jroesch): would be cool to figure out how to statically link instead.
+
+[lib]
+crate-type = ["cdylib"]
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
new file mode 100644
index 0000000..31e1bb2
--- /dev/null
+++ b/rust/compiler-ext/src/lib.rs
@@ -0,0 +1,7 @@
+#[cfg(test)]
+mod tests {
+ #[test]
+ fn it_works() {
+ assert_eq!(2 + 2, 4);
+ }
+}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
new file mode 100644
index 0000000..e69de29
diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly
new file mode 100644
index 0000000..d8b7c69
--- /dev/null
+++ b/rust/tvm/test.rly
@@ -0,0 +1,2 @@
+#[version = "0.0.5"]
+fn @main(%x: int32) -> float32 { %x }
diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py
new file mode 100644
index 0000000..6758d96
--- /dev/null
+++ b/tests/python/relay/test_type_infer2.py
@@ -0,0 +1,419 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test that type checker correcly computes types
+ for expressions.
+"""
+import pytest
+import tvm
+
+from tvm import IRModule, te, relay, parser
+from tvm.relay import op, transform, analysis
+from tvm.relay import Any
+
+
+def infer_mod(mod, annotate_spans=True):
+ if annotate_spans:
+ mod = relay.transform.AnnotateSpans()(mod)
+
+ mod = transform.InferType()(mod)
+ return mod
+
+
+def infer_expr(expr, annotate_spans=True):
+ mod = IRModule.from_expr(expr)
+ mod = infer_mod(mod, annotate_spans)
+ mod = transform.InferType()(mod)
+ entry = mod["main"]
+ return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def assert_has_type(expr, typ, mod=None):
+ if not mod:
+ mod = tvm.IRModule({})
+
+ mod["main"] = expr
+ mod = infer_mod(mod)
+ checked_expr = mod["main"]
+ checked_type = checked_expr.checked_type
+ if checked_type != typ:
+ raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ))
+
+
+def initialize_box_adt(mod):
+ # initializes simple ADT for tests
+ box = relay.GlobalTypeVar("box")
+ tv = relay.TypeVar("tv")
+ constructor = relay.Constructor("constructor", [tv], box)
+ data = relay.TypeData(box, [tv], [constructor])
+ mod[box] = data
+ return box, constructor
+
+
+def test_monomorphic_let():
+ "Program: let %x = 1; %x"
+ # TODO(@jroesch): `x` is bound twice below; the `relay.var` binding is immediately overwritten by `sb.let`.
+ sb = relay.ScopeBuilder()
+ x = relay.var("x", dtype="float64", shape=())
+ x = sb.let("x", relay.const(1.0, "float64"))
+ sb.ret(x)
+ xchecked = infer_expr(sb.get())
+ assert xchecked.checked_type == relay.scalar_type("float64")
+
+
+def test_single_op():
+ "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }"
+ x = relay.var("x", shape=[])
+ func = relay.Function([x], op.log(x))
+ ttype = relay.TensorType([], dtype="float32")
+ assert_has_type(func, relay.FuncType([ttype], ttype))
+
+
+def test_add_broadcast_op():
+ """
+ Program:
+ fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32])
+ -> Tensor[(5, 10, 4), float32] {
+ %x + %y
+ }
+ """
+ x = relay.var("x", shape=(10, 4))
+ y = relay.var("y", shape=(5, 10, 1))
+ z = x + y
+ func = relay.Function([x, y], z)
+ t1 = relay.TensorType((10, 4), "float32")
+ t2 = relay.TensorType((5, 10, 1), "float32")
+ t3 = relay.TensorType((5, 10, 4), "float32")
+ expected_ty = relay.FuncType([t1, t2], t3)
+ assert_has_type(func, expected_ty)
+
+
+def test_dual_op():
+ """Program:
+ fn (%x : Tensor[(10, 10), float32]) {
+ let %t1 = log(%x);
+ let %t2 = add(%t1, %x);
+ %t2
+ }
+ """
+ tp = relay.TensorType((10, 10), "float32")
+ x = relay.var("x", tp)
+ sb = relay.ScopeBuilder()
+ t1 = sb.let("t1", relay.log(x))
+ t2 = sb.let("t2", relay.add(t1, x))
+ sb.ret(t2)
+ f = relay.Function([x], sb.get())
+ fchecked = infer_expr(f)
+ assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_decl():
+ """Program:
+ def @f(%x : Tensor[(10, 10), float32]) {
+ log(%x)
+ }
+ """
+ tp = relay.TensorType((10, 10))
+ x = relay.var("x", tp)
+ f = relay.Function([x], relay.log(x))
+ fchecked = infer_expr(f)
+ assert fchecked.checked_type == relay.FuncType([tp], tp)
+
+
+def test_recursion():
+ """
+ Program:
+ def @f(%n: int32, %data: float32) -> float32 {
+ if (%n == 0) {
+ %data
+ } else {
+ @f(%n - 1, log(%data))
+ }
+ }
+ """
+ sb = relay.ScopeBuilder()
+ f = relay.GlobalVar("f")
+ ti32 = relay.scalar_type("int32")
+ tf32 = relay.scalar_type("float32")
+ n = relay.var("n", ti32)
+ data = relay.var("data", tf32)
+
+ with sb.if_scope(relay.equal(n, relay.const(0, ti32))):
+ sb.ret(data)
+ with sb.else_scope():
+ sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data)))
+ mod = tvm.IRModule()
+ mod[f] = relay.Function([n, data], sb.get())
+ mod = infer_mod(mod)
+ assert "@f(%1, %2)" in mod.astext()
+ assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32)
+
+
+def test_incomplete_call():
+ tt = relay.scalar_type("int32")
+ x = relay.var("x", tt)
+ f = relay.var("f")
+ func = relay.Function([x, f], relay.Call(f, [x]), tt)
+
+ ft = infer_expr(func)
+ f_type = relay.FuncType([tt], tt)
+ assert ft.checked_type == relay.FuncType([tt, f_type], tt)
+
+
+def test_higher_order_argument():
+ a = relay.TypeVar("a")
+ x = relay.Var("x", a)
+ id_func = relay.Function([x], x, a, [a])
+
+ b = relay.TypeVar("b")
+ f = relay.Var("f", relay.FuncType([b], b))
+ y = relay.Var("y", b)
+ ho_func = relay.Function([f, y], f(y), b, [b])
+
+ # id func should be an acceptable argument to the higher-order
+ # function even though id_func takes a type parameter
+ ho_call = ho_func(id_func, relay.const(0, "int32"))
+
+ hc = infer_expr(ho_call)
+ expected = relay.scalar_type("int32")
+ assert hc.checked_type == expected
+
+
+def test_higher_order_return():
+ a = relay.TypeVar("a")
+ x = relay.Var("x", a)
+ id_func = relay.Function([x], x, a, [a])
+
+ b = relay.TypeVar("b")
+ nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b])
+
+ ft = infer_expr(nested_id)
+ assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b])
+
+
+def test_higher_order_nested():
+ a = relay.TypeVar("a")
+ x = relay.Var("x", a)
+ id_func = relay.Function([x], x, a, [a])
+
+ choice_t = relay.FuncType([], relay.scalar_type("bool"))
+ f = relay.Var("f", choice_t)
+
+ b = relay.TypeVar("b")
+ z = relay.Var("z")
+ top = relay.Function(
+ [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b]
+ )
+
+ expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b])
+ ft = infer_expr(top)
+ assert ft.checked_type == expected
+
+
+def test_tuple():
+ tp = relay.TensorType((10,))
+ x = relay.var("x", tp)
+ res = relay.Tuple([x, x])
+ assert infer_expr(res).checked_type == relay.TupleType([tp, tp])
+
+
+def test_ref():
+ x = relay.var("x", "float32")
+ y = relay.var("y", "float32")
+ r = relay.RefCreate(x)
+ st = relay.scalar_type("float32")
+ assert infer_expr(r).checked_type == relay.RefType(st)
+ g = relay.RefRead(r)
+ assert infer_expr(g).checked_type == st
+ w = relay.RefWrite(r, y)
+ assert infer_expr(w).checked_type == relay.TupleType([])
+
+
+def test_free_expr():
+ x = relay.var("x", "float32")
+ y = relay.add(x, x)
+ yy = infer_expr(y, annotate_spans=False)
+ assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True)
+ assert yy.checked_type == relay.scalar_type("float32")
+ assert x.vid.same_as(yy.args[0].vid)
+
+
+def test_type_args():
+ x = relay.var("x", shape=(10, 10))
+ y = relay.var("y", shape=(1, 10))
+ z = relay.add(x, y)
+ ty_z = infer_expr(z)
+ ty_args = ty_z.type_args
+ assert len(ty_args) == 2
+ assert ty_args[0].dtype == "float32"
+ assert ty_args[1].dtype == "float32"
+ sh1 = ty_args[0].shape
+ sh2 = ty_args[1].shape
+ assert sh1[0].value == 10
+ assert sh1[1].value == 10
+ assert sh2[0].value == 1
+ assert sh2[1].value == 10
+
+
+def test_global_var_recursion():
+ mod = tvm.IRModule({})
+ gv = relay.GlobalVar("main")
+ x = relay.var("x", shape=[])
+ tt = relay.scalar_type("float32")
+
+ func = relay.Function([x], relay.Call(gv, [x]), tt)
+ mod[gv] = func
+ mod = infer_mod(mod)
+ func_ty = mod["main"].checked_type
+
+ assert func_ty == relay.FuncType([tt], tt)
+
+
+def test_equal():
+ i = relay.var("i", shape=[], dtype="int32")
+ eq = op.equal(i, relay.const(0, dtype="int32"))
+ func = relay.Function([i], eq)
+ ft = infer_expr(func)
+ expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool"))
+ assert ft.checked_type == expected
+
+ assert ft.checked_type == relay.FuncType(
+ [relay.scalar_type("int32")], relay.scalar_type("bool")
+ )
+
+
+def test_constructor_type():
+ mod = tvm.IRModule()
+ box, constructor = initialize_box_adt(mod)
+
+ a = relay.TypeVar("a")
+ x = relay.Var("x", a)
+ func = relay.Function([x], constructor(x), box(a), [a])
+ mod["main"] = func
+ mod = infer_mod(mod)
+ func_ty = mod["main"].checked_type
+ box = mod.get_global_type_var("box")
+ expected = relay.FuncType([a], box(a), [a])
+ assert func_ty == expected
+
+
+def test_constructor_call():
+ mod = tvm.IRModule()
+ box, constructor = initialize_box_adt(mod)
+
+ box_unit = constructor(relay.Tuple([]))
+ box_constant = constructor(relay.const(0, "float32"))
+
+ func = relay.Function([], relay.Tuple([box_unit, box_constant]))
+ mod["main"] = func
+ mod = infer_mod(mod)
+ ret_type = mod["main"].checked_type.ret_type.fields
+ # NB(@jroesch): annotating spans rewrites the AST, so fragments built before
+ # annotation (e.g. `box`) no longer compare equal; re-fetch them from the module.
+ box = mod.get_global_type_var("box")
+ expected1 = box(relay.TupleType([]))
+ expected2 = box(relay.TensorType((), "float32"))
+ assert ret_type[0] == expected1
+ assert ret_type[1] == expected2
+
+
+def test_adt_match():
+ mod = tvm.IRModule()
+ box, constructor = initialize_box_adt(mod)
+
+ v = relay.Var("v", relay.TensorType((), "float32"))
+ match = relay.Match(
+ constructor(relay.const(0, "float32")),
+ [
+ relay.Clause(
+ relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([])
+ ),
+ # redundant but shouldn't matter to typechecking
+ relay.Clause(relay.PatternWildcard(), relay.Tuple([])),
+ ],
+ )
+
+ func = relay.Function([], match)
+ mod["main"] = func
+ mod = infer_mod(mod)
+ actual = mod["main"].checked_type.ret_type
+ assert actual == relay.TupleType([])
+
+
+def test_adt_match_type_annotations():
+ mod = tvm.IRModule()
+ box, constructor = initialize_box_adt(mod)
+
+ # the only type annotation is inside the match pattern var
+ # but that should be enough info
+ tt = relay.TensorType((2, 2), "float32")
+ x = relay.Var("x")
+ mv = relay.Var("mv", tt)
+ match = relay.Match(
+ constructor(x),
+ [
+ relay.Clause(
+ relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([])
+ )
+ ],
+ )
+
+ mod["main"] = relay.Function([x], match)
+ mod = infer_mod(mod)
+ ft = mod["main"].checked_type
+ assert ft == relay.FuncType([tt], relay.TupleType([]))
+
+
+def test_let_polymorphism():
+ id = relay.Var("id")
+ xt = relay.TypeVar("xt")
+ x = relay.Var("x", xt)
+ body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))])
+ body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
+ body = infer_expr(body)
+ int32 = relay.TensorType((), "int32")
+ tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])]))
+
+
+def test_if():
+ choice_t = relay.FuncType([], relay.scalar_type("bool"))
+ f = relay.Var("f", choice_t)
+ true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32"))
+ false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32"))
+ top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch))
+ ft = infer_expr(top)
+ tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32"))
+
+
+def test_type_arg_infer():
+ code = """
+#[version = "0.0.5"]
+def @id[A](%x: A) -> A {
+ %x
+}
+def @main(%f: float32) -> float32 {
+ @id(%f)
+}
+"""
+ mod = tvm.parser.fromtext(code)
+ mod = transform.InferType()(mod)
+ tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")])
+
+
+if __name__ == "__main__":
+ import sys
+
+ pytest.main(sys.argv)
[incubator-tvm] 02/23: Codespan example almost working
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 77ba30993a7883c142b05e511e8a5a7a91116b2f
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 9 23:59:51 2020 -0700
Codespan example almost working
---
rust/tvm-sys/src/packed_func.rs | 1 +
rust/tvm/Cargo.toml | 2 +
rust/tvm/src/bin/tyck.rs | 24 ++++++++
rust/tvm/src/ir/diagnostics.rs | 121 +++++++++++++++++++++++++++++----------
rust/tvm/src/ir/relay/visitor.rs | 24 ++++++++
5 files changed, 143 insertions(+), 29 deletions(-)
diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs
index f7b289c..7b8d529 100644
--- a/rust/tvm-sys/src/packed_func.rs
+++ b/rust/tvm-sys/src/packed_func.rs
@@ -101,6 +101,7 @@ macro_rules! TVMPODValue {
TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle),
TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle),
TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
+ TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)),
TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 55fc179..71a4b93 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -41,6 +41,8 @@ paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
pyo3 = { version = "0.11.1", optional = true }
+codespan-reporting = "0.9.5"
+structopt = { version = "0.3" }
[features]
default = ["python"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
new file mode 100644
index 0000000..9300412
--- /dev/null
+++ b/rust/tvm/src/bin/tyck.rs
@@ -0,0 +1,24 @@
+use std::path::PathBuf;
+
+use anyhow::Result;
+use structopt::StructOpt;
+
+use tvm::ir::diagnostics::codespan;
+use tvm::ir::IRModule;
+
+
+#[derive(Debug, StructOpt)]
+#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
+struct Opt {
+ /// Input file
+ #[structopt(parse(from_os_str))]
+ input: PathBuf,
+}
+
+fn main() -> Result<()> {
+ codespan::init().expect("Rust based diagnostics");
+ let opt = Opt::from_args();
+ println!("{:?}", &opt);
+ let file = IRModule::parse_file(opt.input)?;
+ Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index 799a10c..e434d3f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -24,13 +24,31 @@
use tvm_macros::{Object, external};
use super::module::IRModule;
-use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectRef};
+use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString};
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
use crate::runtime::function::Result;
use super::span::Span;
type SourceName = ObjectRef;
+// Bindings to the TVM diagnostics FFI functions.
+external! {
+ #[name("node.ArrayGetItem")]
+ fn get_renderer() -> DiagnosticRenderer;
+
+ #[name("diagnostics.DiagnosticRenderer")]
+ fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
+
+ #[name("diagnostics.Emit")]
+ fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
+
+ #[name("diagnostics.DiagnosticContextRender")]
+ fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
+
+ #[name("diagnostics.ClearRenderer")]
+ fn clear_renderer() -> ();
+}
+
/// The diagnostic level, controls the printing of the message.
#[repr(C)]
pub enum DiagnosticLevel {
@@ -171,26 +189,20 @@ pub struct DiagnosticContextNode {
pub renderer: DiagnosticRenderer,
}
-// Get the the diagnostic renderer.
-external! {
- #[name("node.ArrayGetItem")]
- fn get_renderer() -> DiagnosticRenderer;
-
- #[name("diagnostics.DiagnosticRenderer")]
- fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
-
- #[name("diagnostics.Emit")]
- fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
-
- #[name("diagnostics.DiagnosticContextRender")]
- fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
-}
-
/// A diagnostic context which records active errors
/// and contains a renderer.
impl DiagnosticContext {
- pub fn new(module: IRModule, renderer: DiagnosticRenderer) {
- todo!()
+ pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
+ where F: Fn(DiagnosticContext) -> () + 'static
+ {
+ let renderer = diagnostic_renderer(render_func.to_function()).unwrap();
+ let node = DiagnosticContextNode {
+ base: Object::base_object::<DiagnosticContextNode>(),
+ module,
+ diagnostics: Array::from_vec(vec![]).unwrap(),
+ renderer,
+ };
+ DiagnosticContext(Some(ObjectPtr::new(node)))
}
pub fn default(module: IRModule) -> DiagnosticContext {
@@ -223,17 +235,68 @@ impl DiagnosticContext {
// If the render_func is None it will remove the current custom renderer
// and return to default behavior.
fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> ()
+where F: Fn(DiagnosticContext) -> () + 'static
{
- todo!()
- // fn ()
- // diagnostic_renderer(func)
- // if render_func:
- // def _render_factory():
- // return DiagnosticRenderer(render_func)
+ match opt_func {
+ None => clear_renderer(),
+ Some(func) => {
+ let func = func.to_function();
+ let render_factory = move || {
+ diagnostic_renderer(func.clone()).unwrap()
+ };
+
+ function::register_override(
+ render_factory,
+ "diagnostics.OverrideRenderer",
+ true)?;
+
+ Ok(())
+ }
+ }
+}
+
+pub mod codespan {
+ use super::*;
+
+ use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+ use codespan_reporting::files::SimpleFiles;
+ use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+ pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+ let severity = match diag.level {
+ DiagnosticLevel::Error => Severity::Error,
+ DiagnosticLevel::Warning => Severity::Warning,
+ DiagnosticLevel::Note => Severity::Note,
+ DiagnosticLevel::Help => Severity::Help,
+ DiagnosticLevel::Bug => Severity::Bug,
+ };
+
+ let file_id = "foo".into(); // diag.span.source_name;
+
+ let message: String = diag.message.as_str().unwrap().into();
+ let inner_message: String = "expected `String`, found `Nat`".into();
+ let diagnostic = CDiagnostic::new(severity)
+ .with_message(message)
+ .with_code("EXXX")
+ .with_labels(vec![
+ Label::primary(file_id, 328..331).with_message(inner_message),
+ ]);
+
+ diagnostic
+ }
+
+ pub fn init() -> Result<()> {
+ let mut files: SimpleFiles<String, String> = SimpleFiles::new();
+ let render_fn = move |diag_ctx: DiagnosticContext| {
+ // let source_map = diag_ctx.module.source_map;
+ for diagnostic in diag_ctx.diagnostics {
+
+ }
+ panic!("render_fn");
+ };
- // register_func("diagnostics.OverrideRenderer", _render_factory, override=True)
- // else:
- // _ffi_api.ClearRenderer()
+ override_renderer(Some(render_fn))?;
+ Ok(())
+ }
}
diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs
new file mode 100644
index 0000000..3166174
--- /dev/null
+++ b/rust/tvm/src/ir/relay/visitor.rs
@@ -0,0 +1,24 @@
+use super::Expr;
+
+macro_rules! downcast_match {
+ ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
+ $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
+ { $default }
+ }
+}
+
+trait ExprVisitorMut {
+ fn visit(&mut self, expr: Expr) {
+ downcast_match!(expr; {
+ else => {
+ panic!()
+ }
+ });
+ }
+
+ fn visit(&mut self, expr: Expr);
+}
+
+// trait ExprTransformer {
+// fn
+// }
[incubator-tvm] 14/23: Format and cleanup
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 04a9779359380ae383405f5f72a66e79631ee2d5
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:31:37 2020 -0700
Format and cleanup
---
python/tvm/ir/diagnostics/__init__.py | 1 +
rust/compiler-ext/src/lib.rs | 3 +-
rust/tvm-rt/src/array.rs | 11 +++--
rust/tvm-rt/src/errors.rs | 4 +-
rust/tvm-rt/src/function.rs | 2 +-
rust/tvm/src/bin/tyck.rs | 8 +--
rust/tvm/src/ir/diagnostics/codespan.rs | 87 +++++++++++++++++++--------------
rust/tvm/src/ir/mod.rs | 2 +-
rust/tvm/src/ir/module.rs | 5 +-
rust/tvm/src/ir/relay/mod.rs | 2 +-
rust/tvm/src/ir/relay/visitor.rs | 24 ---------
rust/tvm/src/ir/source_map.rs | 13 ++---
rust/tvm/src/ir/span.rs | 2 +-
src/ir/expr.cc | 11 -----
14 files changed, 79 insertions(+), 96 deletions(-)
diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py
index 0ad2a7a..3a6402c 100644
--- a/python/tvm/ir/diagnostics/__init__.py
+++ b/python/tvm/ir/diagnostics/__init__.py
@@ -37,6 +37,7 @@ def get_renderer():
"""
return _ffi_api.GetRenderer()
+
@tvm.register_func("diagnostics.override_renderer")
def override_renderer(render_func):
"""
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index c136d06..346f40f 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -36,8 +36,7 @@ tvm::export!(test_fn, test_fn2);
#[no_mangle]
fn compiler_ext_initialize() -> i32 {
let _ = env_logger::try_init();
- tvm_export("rust_ext")
- .expect("failed to initialize Rust compiler_ext");
+ tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext");
log::debug!("done!");
return 0;
}
diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs
index 032ca79..66e32a7 100644
--- a/rust/tvm-rt/src/array.rs
+++ b/rust/tvm-rt/src/array.rs
@@ -18,8 +18,8 @@
*/
use std::convert::{TryFrom, TryInto};
-use std::marker::PhantomData;
use std::iter::{IntoIterator, Iterator};
+use std::marker::PhantomData;
use crate::errors::Error;
use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
@@ -93,8 +93,7 @@ impl<T: IsObjectRef> Iterator for IntoIter<T> {
fn next(&mut self) -> Option<Self::Item> {
if self.pos < self.size {
- let item = self.array.get(self.pos)
- .expect("should not fail");
+ let item = self.array.get(self.pos).expect("should not fail");
self.pos += 1;
Some(item)
} else {
@@ -109,7 +108,11 @@ impl<T: IsObjectRef> IntoIterator for Array<T> {
fn into_iter(self) -> Self::IntoIter {
let size = self.len() as isize;
- IntoIter { array: self, pos: 0, size: size }
+ IntoIter {
+ array: self,
+ pos: 0,
+ size: size,
+ }
}
}
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index 3de9f3c..31ce385 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,7 +68,9 @@ pub enum Error {
Infallible(#[from] std::convert::Infallible),
#[error("a panic occurred while executing a Rust packed function")]
Panic,
- #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+ #[error(
+ "one or more error diagnostics were emitted, please check diagnostic render for output."
+ )]
DiagnosticError(String),
#[error("{0}")]
Raw(String),
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index 173b60a..4c6f56e 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -128,7 +128,7 @@ impl Function {
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
&mut ret_val as *mut _,
- &mut ret_type_code as *mut _
+ &mut ret_type_code as *mut _,
)
};
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index 13470e7..e9c2663 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -20,9 +20,11 @@ fn main() -> Result<()> {
let opt = Opt::from_args();
println!("{:?}", &opt);
let _module = match IRModule::parse_file(opt.input) {
- Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) },
- Err(e) => { return Err(e.into()); },
- Ok(module) => module
+ Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()),
+ Err(e) => {
+ return Err(e.into());
+ }
+ Ok(module) => module,
};
Ok(())
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 9fc1ee0..9a31691 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -1,3 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/// A TVM diagnostics renderer which uses the Rust `codespan`
+/// library to produce error messages.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@@ -6,13 +27,8 @@ use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
use codespan_reporting::term::{self, ColorArg};
-use crate::ir::source_map::*;
use super::*;
-
-enum StartOrEnd {
- Start,
- End,
-}
+use crate::ir::source_map::*;
struct ByteRange<FileId> {
file_id: FileId,
@@ -26,7 +42,7 @@ enum FileSpanToByteRange {
/// Map character regions which are larger then 1-byte to length.
lengths: HashMap<isize, isize>,
source: String,
- }
+ },
}
impl FileSpanToByteRange {
@@ -34,24 +50,11 @@ impl FileSpanToByteRange {
let mut last_index = 0;
let mut is_ascii = true;
if source.is_ascii() {
- let line_lengths =
- source
- .lines()
- .map(|line| line.len())
- .collect();
+ let line_lengths = source.lines().map(|line| line.len()).collect();
FileSpanToByteRange::AsciiSource(line_lengths)
} else {
panic!()
}
-
- // for (index, _) in source.char_indices() {
- // if last_index - 1 != last_index {
- // is_ascii = false;
- // } else {
- // panic!();
- // }
- // last_index = index;
- // }
}
fn lookup(&self, span: &Span) -> ByteRange<String> {
@@ -61,22 +64,34 @@ impl FileSpanToByteRange {
match self {
AsciiSource(ref line_lengths) => {
- let start_pos = (&line_lengths[0..(span.line - 1) as usize]).into_iter().sum::<usize>() + (span.column) as usize;
- let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]).into_iter().sum::<usize>() + (span.end_column) as usize;
- ByteRange { file_id: source_name, start_pos, end_pos }
- },
- _ => panic!()
+ let start_pos = (&line_lengths[0..(span.line - 1) as usize])
+ .into_iter()
+ .sum::<usize>()
+ + (span.column) as usize;
+ let end_pos = (&line_lengths[0..(span.end_line - 1) as usize])
+ .into_iter()
+ .sum::<usize>()
+ + (span.end_column) as usize;
+ ByteRange {
+ file_id: source_name,
+ start_pos,
+ end_pos,
+ }
+ }
+ _ => panic!(),
}
}
}
struct SpanToByteRange {
- map: HashMap<String, FileSpanToByteRange>
+ map: HashMap<String, FileSpanToByteRange>,
}
impl SpanToByteRange {
fn new() -> SpanToByteRange {
- SpanToByteRange { map: HashMap::new() }
+ SpanToByteRange {
+ map: HashMap::new(),
+ }
}
pub fn add_source(&mut self, source: Source) {
@@ -86,7 +101,8 @@ impl SpanToByteRange {
panic!()
} else {
let source = source.source.as_str().expect("fpp").into();
- self.map.insert(source_name, FileSpanToByteRange::new(source));
+ self.map
+ .insert(source_name, FileSpanToByteRange::new(source));
}
}
@@ -143,9 +159,10 @@ impl DiagnosticState {
let diagnostic = CDiagnostic::new(severity)
.with_message(message)
.with_code("EXXX")
- .with_labels(vec![
- Label::primary(file_id, byte_range.start_pos..byte_range.end_pos)
- ]);
+ .with_labels(vec![Label::primary(
+ file_id,
+ byte_range.start_pos..byte_range.end_pos,
+ )]);
diagnostic
}
@@ -161,11 +178,7 @@ fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
Ok(source) => {
state.add_source(source);
let diagnostic = state.to_diagnostic(diagnostic);
- term::emit(
- &mut writer.lock(),
- &config,
- &state.files,
- &diagnostic).unwrap();
+ term::emit(&mut writer.lock(), &config, &state.files, &diagnostic).unwrap();
}
}
}
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index df9bc68..6d51580 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -25,8 +25,8 @@ pub mod function;
pub mod module;
pub mod op;
pub mod relay;
-pub mod span;
pub mod source_map;
+pub mod span;
pub mod tir;
pub mod ty;
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 11d6c49..443915f 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -32,7 +32,6 @@ use super::expr::GlobalVar;
use super::function::BaseFunc;
use super::source_map::SourceMap;
-
// TODO(@jroesch): define type
type TypeData = ObjectRef;
type GlobalTypeVar = ObjectRef;
@@ -133,7 +132,9 @@ impl IRModule {
parse_module(file_name.into(), source.into())
}
- pub fn parse_file<P: 'static + AsRef<Path>>(file_path: P) -> std::result::Result<IRModule, Error> {
+ pub fn parse_file<P: 'static + AsRef<Path>>(
+ file_path: P,
+ ) -> std::result::Result<IRModule, Error> {
let file_path = file_path.as_ref();
let file_path_as_str = file_path.to_str().unwrap().to_string();
let source = std::fs::read_to_string(file_path)?;
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 4b09128..530b120 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -27,8 +27,8 @@ use crate::runtime::{object::*, String as TString};
use super::attrs::Attrs;
use super::expr::BaseExprNode;
use super::function::BaseFuncNode;
-use super::ty::{Type, TypeNode};
use super::span::Span;
+use super::ty::{Type, TypeNode};
use tvm_macros::Object;
use tvm_rt::NDArray;
diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs
deleted file mode 100644
index 3166174..0000000
--- a/rust/tvm/src/ir/relay/visitor.rs
+++ /dev/null
@@ -1,24 +0,0 @@
-use super::Expr;
-
-macro_rules! downcast_match {
- ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
- $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
- { $default }
- }
-}
-
-trait ExprVisitorMut {
- fn visit(&mut self, expr: Expr) {
- downcast_match!(expr; {
- else => {
- panic!()
- }
- });
- }
-
- fn visit(&mut self, expr: Expr);
-}
-
-// trait ExprTransformer {
-// fn
-// }
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index ebe7e46..b0cc0d8 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -19,7 +19,7 @@
use crate::runtime::map::Map;
use crate::runtime::object::Object;
-use crate::runtime::string::{String as TString};
+use crate::runtime::string::String as TString;
use super::span::{SourceName, Span};
@@ -39,12 +39,10 @@ pub struct SourceNode {
/// The raw source. */
pub source: TString,
-
- // A mapping of line breaks into the raw source.
- // std::vector<std::pair<int, int>> line_map;
+ // A mapping of line breaks into the raw source.
+ // std::vector<std::pair<int, int>> line_map;
}
-
// class Source : public ObjectRef {
// public:
// TVM_DLL Source(SourceName src_name, std::string source);
@@ -53,7 +51,6 @@ pub struct SourceNode {
// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
// };
-
/// A mapping from a unique source name to source fragments.
#[repr(C)]
#[derive(Object)]
@@ -61,6 +58,6 @@ pub struct SourceNode {
#[ref_name = "SourceMap"]
pub struct SourceMapNode {
pub base: Object,
- /// The source mapping.
- pub source_map: Map<SourceName, Source>,
+ /// The source mapping.
+ pub source_map: Map<SourceName, Source>,
}
diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs
index c54fd51..afcbe9c 100644
--- a/rust/tvm/src/ir/span.rs
+++ b/rust/tvm/src/ir/span.rs
@@ -18,7 +18,7 @@
* under the License.
*/
-use crate::runtime::{ObjectRef, Object, String as TString};
+use crate::runtime::{Object, ObjectRef, String as TString};
use tvm_macros::Object;
/// A source file name, contained in a Span.
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 5110eef..67e5cea 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -192,15 +192,4 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
return ss.str();
});
-
-
} // namespace tvm
-
-#ifdef RUST_COMPILER_EXT
-
-extern "C" {
- int compiler_ext_initialize();
- static int test = compiler_ext_initialize();
-}
-
-#endif
[incubator-tvm] 10/23: Fix calling
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit eeb86c63d693288f8d406aed6b1b0df6d28e4b07
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 21:37:34 2020 -0700
Fix calling
---
rust/tvm-rt/src/function.rs | 28 +++++++++++++++++-----------
rust/tvm/src/bin/tyck.rs | 2 +-
rust/tvm/src/ir/diagnostics/mod.rs | 4 +---
3 files changed, 19 insertions(+), 15 deletions(-)
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index bae06e9..c7aebdd 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -33,6 +33,7 @@ use std::{
};
use crate::errors::Error;
+use crate::object::ObjectPtr;
pub use super::to_function::{ToFunction, Typed};
pub use tvm_sys::{ffi, ArgValue, RetValue};
@@ -120,21 +121,26 @@ impl Function {
let mut ret_val = ffi::TVMValue { v_int64: 0 };
let mut ret_type_code = 0i32;
- check_call!(ffi::TVMFuncCall(
- self.handle,
- values.as_mut_ptr() as *mut ffi::TVMValue,
- type_codes.as_mut_ptr() as *mut c_int,
- num_args as c_int,
- &mut ret_val as *mut _,
- &mut ret_type_code as *mut _
- ));
+ let ret_code = unsafe {
+ ffi::TVMFuncCall(
+ self.handle,
+ values.as_mut_ptr() as *mut ffi::TVMValue,
+ type_codes.as_mut_ptr() as *mut c_int,
+ num_args as c_int,
+ &mut ret_val as *mut _,
+ &mut ret_type_code as *mut _
+ )
+ };
+
+ if ret_code != 0 {
+ return Err(Error::CallFailed(crate::get_last_error().into()));
+ }
let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
match rv {
RetValue::ObjectHandle(object) => {
- let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap();
- // println!("after wrapped call: {}", optr.count());
- crate::object::ObjectPtr::leak(optr);
+ let optr = ObjectPtr::from_raw(object as _).unwrap();
+ ObjectPtr::leak(optr);
}
_ => {}
};
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index e0c7136..fbab027 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -18,7 +18,7 @@ fn main() -> Result<()> {
codespan::init().expect("Rust based diagnostics");
let opt = Opt::from_args();
println!("{:?}", &opt);
- let module = IRModule::parse_file(opt.input)?;
+ let module = IRModule::parse_file(opt.input);
// for (k, v) in module.functions {
// println!("Function name: {:?}", v);
diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs
index fce214a..039d1ed 100644
--- a/rust/tvm/src/ir/diagnostics/mod.rs
+++ b/rust/tvm/src/ir/diagnostics/mod.rs
@@ -207,9 +207,7 @@ impl DiagnosticContext {
}
}
-// Override the global diagnostics renderer.
-// Params
-// ------
+/// Override the global diagnostics renderer.
// render_func: Option[Callable[[DiagnosticContext], None]]
// If the render_func is None it will remove the current custom renderer
// and return to default behavior.
[incubator-tvm] 18/23: WIP
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 49246bff342eb59757cecc34a9a9465a2e3c063d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Wed Oct 21 14:09:37 2020 -0700
WIP
---
rust/tvm-macros/src/external.rs | 5 ++-
rust/tvm-macros/src/lib.rs | 1 +
rust/tvm/src/ir/module.rs | 67 +++++++++++++----------------------------
3 files changed, 26 insertions(+), 47 deletions(-)
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 802d7ae..de8ada3 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -17,6 +17,7 @@
* under the License.
*/
use proc_macro2::Span;
+use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
@@ -109,7 +110,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.iter()
.map(|ty_param| match ty_param {
syn::GenericParam::Type(param) => param.clone(),
- _ => panic!(),
+ _ => abort! { ty_param,
+ "Only supports type parameters."
+ }
})
.collect();
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 603e1ce..ab75c92 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -35,6 +35,7 @@ pub fn macro_impl(input: TokenStream) -> TokenStream {
TokenStream::from(object::macro_impl(input))
}
+#[proc_macro_error]
#[proc_macro]
pub fn external(input: TokenStream) -> TokenStream {
external::macro_impl(input)
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 443915f..8918bdc 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -63,6 +63,8 @@ external! {
#[name("parser.ParseExpr")]
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
// Module methods
+ #[name("ir.Module_Add")]
+ fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
#[name("ir.Module_AddDef")]
fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
#[name("ir.Module_GetGlobalVar")]
@@ -73,55 +75,28 @@ external! {
fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc;
#[name("ir.Module_Lookup_str")]
fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
+ #[name("ir.Module_GetGlobalTypeVars")]
+ fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+ #[name("ir.Module_ContainGlobalVar")]
+ fn module_get_global_var(name: TVMString) -> bool;
+ #[name("ir.Module_ContainGlobalTypeVar")]
+ fn module_get_global_type_var(name: TVMString) -> bool;
+ #[name("ir.Module_LookupDef")]
+ fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+ #[name("ir.Module_LookupDef_str")]
+ fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+ #[name("ir.Module_LookupTag")]
+ fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+ #[name("ir.Module_FromExpr")]
+ fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
+ #[name("ir.Module_Import")]
+ fn module_import(module: IRModule, path: TVMString);
+ #[name("ir.Module_ImportFromStd")]
+ fn module_import_from_std(module: IRModule, path: TVMString);
}
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
-// .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
+// Note: we don't expose update here as update is going to be removed.
-// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
-// .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
-// .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
-// return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
-// return mod->LookupTypeDef(var);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
-// return mod->LookupTag(tag);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
-// .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs,
-// tvm::Map<GlobalTypeVar, TypeData> type_defs) {
-// return IRModule::FromExpr(e, funcs, type_defs);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
-// mod->Update(from);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
-// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });
-
-// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
-// mod->Import(path);
-// });
-
-// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
-// mod->ImportFromStd(path);
-// });
-
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-// .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
-// auto* node = static_cast<const IRModuleNode*>(ref.get());
-// p->stream << "IRModuleNode( " << node->functions << ")";
-// });
impl IRModule {
pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
[incubator-tvm] 08/23: Clean up exporting to show off new diagnostics
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 20c6a28606c053fbf9adf1c36c85fd608e63e024
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 17:03:00 2020 -0700
Clean up exporting to show off new diagnostics
---
rust/compiler-ext/src/lib.rs | 12 ++++++++++--
rust/tvm-rt/src/array.rs | 32 ++++++++++++++++++++++++++++++++
rust/tvm/src/bin/tyck.rs | 7 ++++++-
rust/tvm/src/ir/diagnostics.rs | 10 +++++-----
rust/tvm/src/ir/mod.rs | 1 +
rust/tvm/src/ir/module.rs | 3 +++
rust/tvm/src/ir/source_map.rs | 26 +++++++++++++++-----------
rust/tvm/src/lib.rs | 24 ++++++++++++++++++++++++
8 files changed, 96 insertions(+), 19 deletions(-)
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 3e37d21..c136d06 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -22,14 +22,22 @@ use tvm;
use tvm::runtime::function::register_override;
fn test_fn() -> Result<(), tvm::Error> {
- println!("Hello from Rust!");
+ println!("Hello Greg from Rust!");
Ok(())
}
+fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> {
+ println!("The message: {}", message);
+ Ok(())
+}
+
+tvm::export!(test_fn, test_fn2);
+
#[no_mangle]
fn compiler_ext_initialize() -> i32 {
let _ = env_logger::try_init();
- register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier");
+ tvm_export("rust_ext")
+ .expect("failed to initialize Rust compiler_ext");
log::debug!("done!");
return 0;
}
diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs
index 5e19cef..032ca79 100644
--- a/rust/tvm-rt/src/array.rs
+++ b/rust/tvm-rt/src/array.rs
@@ -19,6 +19,7 @@
use std::convert::{TryFrom, TryInto};
use std::marker::PhantomData;
+use std::iter::{IntoIterator, Iterator};
use crate::errors::Error;
use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef};
@@ -81,6 +82,37 @@ impl<T: IsObjectRef> Array<T> {
}
}
+pub struct IntoIter<T: IsObjectRef> {
+ array: Array<T>,
+ pos: isize,
+ size: isize,
+}
+
+impl<T: IsObjectRef> Iterator for IntoIter<T> {
+ type Item = T;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.pos < self.size {
+ let item = self.array.get(self.pos)
+ .expect("should not fail");
+ self.pos += 1;
+ Some(item)
+ } else {
+ None
+ }
+ }
+}
+
+impl<T: IsObjectRef> IntoIterator for Array<T> {
+ type Item = T;
+ type IntoIter = IntoIter<T>;
+
+ fn into_iter(self) -> Self::IntoIter {
+ let size = self.len() as isize;
+ IntoIter { array: self, pos: 0, size: size }
+ }
+}
+
impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index b869012..e0c7136 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -18,6 +18,11 @@ fn main() -> Result<()> {
codespan::init().expect("Rust based diagnostics");
let opt = Opt::from_args();
println!("{:?}", &opt);
- let file = IRModule::parse_file(opt.input)?;
+ let module = IRModule::parse_file(opt.input)?;
+
+ // for (k, v) in module.functions {
+ // println!("Function name: {:?}", v);
+ // }
+
Ok(())
}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index b76e43f..4975a45 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -135,6 +135,7 @@ pub struct DiagnosticRendererNode {
pub base: Object,
// TODO(@jroesch): we can't easily exposed packed functions due to
// memory layout
+ // missing field here
}
// def render(self, ctx):
@@ -283,11 +284,10 @@ pub mod codespan {
pub fn init() -> Result<()> {
let mut files: SimpleFiles<String, String> = SimpleFiles::new();
let render_fn = move |diag_ctx: DiagnosticContext| {
- // let source_map = diag_ctx.module.source_map;
- // for diagnostic in diag_ctx.diagnostics {
-
- // }
- panic!("render_fn");
+ let source_map = diag_ctx.module.source_map.clone();
+ for diagnostic in diag_ctx.diagnostics.clone() {
+ println!("Diagnostic: {}", diagnostic.message);
+ }
};
override_renderer(Some(render_fn))?;
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 401b6c2..df9bc68 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -26,6 +26,7 @@ pub mod module;
pub mod op;
pub mod relay;
pub mod span;
+pub mod source_map;
pub mod tir;
pub mod ty;
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index e0444b3..5156e74 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -25,6 +25,7 @@ use crate::runtime::{external, Object, ObjectRef};
use super::expr::GlobalVar;
use super::function::BaseFunc;
+use super::source_map::SourceMap;
use std::io::Result as IOResult;
use std::path::Path;
@@ -43,6 +44,8 @@ pub struct IRModuleNode {
pub base: Object,
pub functions: Map<GlobalVar, BaseFunc>,
pub type_definitions: Map<GlobalTypeVar, TypeData>,
+ pub source_map: SourceMap,
+ // TODO(@jroesch): this is missing some fields
}
external! {
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index e6c0371..56c0830 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -12,7 +12,7 @@
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
+ * KIND, either exprss or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
@@ -20,23 +20,27 @@
use crate::runtime::map::Map;
use crate::runtime::object::Object;
+use super::span::{SourceName, Span};
+
+use tvm_macros::Object;
+
/// A program source in any language.
///
/// Could represent the source from an ML framework or a source of an IRModule.
#[repr(C)]
#[derive(Object)]
#[type_key = "Source"]
-#[ref_key = "Source"]
-struct SourceNode {
+#[ref_name = "Source"]
+pub struct SourceNode {
pub base: Object,
- /*! \brief The source name. */
- SourceName source_name;
+ /// The source name. */
+ pub source_name: SourceName,
- /*! \brief The raw source. */
- String source;
+ /// The raw source. */
+ source: String,
- /*! \brief A mapping of line breaks into the raw source. */
- std::vector<std::pair<int, int>> line_map;
+ // A mapping of line breaks into the raw source.
+ // std::vector<std::pair<int, int>> line_map;
}
@@ -53,8 +57,8 @@ struct SourceNode {
#[repr(C)]
#[derive(Object)]
#[type_key = "SourceMap"]
-#[ref_key = "SourceMap"]
-struct SourceMapNode {
+#[ref_name = "SourceMap"]
+pub struct SourceMapNode {
pub base: Object,
/// The source mapping.
pub source_map: Map<SourceName, Source>,
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index 36c7503..d193f09 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -47,3 +47,27 @@ pub mod runtime;
pub mod transform;
pub use runtime::version;
+
+#[macro_export]
+macro_rules! export {
+ ($($fn_names:expr),*) => {
+ pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
+ $(
+ register_override($fn_name, concat!($ns, stringfy!($fn_name)), true)?;
+ )*
+ Ok(())
+ }
+ }
+}
+
+#[macro_export]
+macro_rules! export_mod {
+ ($ns:expr, $($mod_name:expr),*) => {
+ pub fn tvm_mod_export() -> Result<(), tvm::Error> {
+ $(
+ $mod_names::tvm_export($ns)?;
+ )*
+ Ok(())
+ }
+ }
+}
[incubator-tvm] 12/23: Rust Diagnostics work
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 6e1346748e08255f220c3e6cf72c59a8a3f6ef29
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 14:17:01 2020 -0700
Rust Diagnostics work
---
rust/tvm-rt/src/errors.rs | 15 ++++
rust/tvm-rt/src/function.rs | 7 +-
rust/tvm/src/bin/tyck.rs | 13 ++--
rust/tvm/src/ir/diagnostics/codespan.rs | 126 ++++++++++++++++++++++----------
4 files changed, 117 insertions(+), 44 deletions(-)
diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs
index c884c56..3de9f3c 100644
--- a/rust/tvm-rt/src/errors.rs
+++ b/rust/tvm-rt/src/errors.rs
@@ -68,6 +68,21 @@ pub enum Error {
Infallible(#[from] std::convert::Infallible),
#[error("a panic occurred while executing a Rust packed function")]
Panic,
+ #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")]
+ DiagnosticError(String),
+ #[error("{0}")]
+ Raw(String),
+}
+
+impl Error {
+ pub fn from_raw_tvm(raw: &str) -> Error {
+ let err_header = raw.find(":").unwrap_or(0);
+ let (err_ty, err_content) = raw.split_at(err_header);
+ match err_ty {
+ "DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()),
+ _ => Error::Raw(raw.into()),
+ }
+ }
}
impl Error {
diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs
index c7aebdd..173b60a 100644
--- a/rust/tvm-rt/src/function.rs
+++ b/rust/tvm-rt/src/function.rs
@@ -133,7 +133,12 @@ impl Function {
};
if ret_code != 0 {
- return Err(Error::CallFailed(crate::get_last_error().into()));
+ let raw_error = crate::get_last_error();
+ let error = match Error::from_raw_tvm(raw_error) {
+ Error::Raw(string) => Error::CallFailed(string),
+ e => e,
+ };
+ return Err(error);
}
let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index fbab027..13470e7 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -4,7 +4,8 @@ use anyhow::Result;
use structopt::StructOpt;
use tvm::ir::diagnostics::codespan;
-use tvm::ir::IRModule;
+use tvm::ir::{self, IRModule};
+use tvm::runtime::Error;
#[derive(Debug, StructOpt)]
#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
@@ -18,11 +19,11 @@ fn main() -> Result<()> {
codespan::init().expect("Rust based diagnostics");
let opt = Opt::from_args();
println!("{:?}", &opt);
- let module = IRModule::parse_file(opt.input);
-
- // for (k, v) in module.functions {
- // println!("Function name: {:?}", v);
- // }
+ let _module = match IRModule::parse_file(opt.input) {
+ Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) },
+ Err(e) => { return Err(e.into()); },
+ Ok(module) => module
+ };
Ok(())
}
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 80a8784..9fc1ee0 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex};
use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+use codespan_reporting::term::{self, ColorArg};
use crate::ir::source_map::*;
use super::*;
@@ -13,8 +14,14 @@ enum StartOrEnd {
End,
}
+struct ByteRange<FileId> {
+ file_id: FileId,
+ start_pos: usize,
+ end_pos: usize,
+}
+
enum FileSpanToByteRange {
- AsciiSource,
+ AsciiSource(Vec<usize>),
Utf8 {
/// Map character regions which are larger then 1-byte to length.
lengths: HashMap<isize, isize>,
@@ -27,7 +34,12 @@ impl FileSpanToByteRange {
let mut last_index = 0;
let mut is_ascii = true;
if source.is_ascii() {
- FileSpanToByteRange::AsciiSource
+ let line_lengths =
+ source
+ .lines()
+ .map(|line| line.len())
+ .collect();
+ FileSpanToByteRange::AsciiSource(line_lengths)
} else {
panic!()
}
@@ -41,6 +53,21 @@ impl FileSpanToByteRange {
// last_index = index;
// }
}
+
+ fn lookup(&self, span: &Span) -> ByteRange<String> {
+ use FileSpanToByteRange::*;
+
+ let source_name: String = span.source_name.name.as_str().unwrap().into();
+
+ match self {
+ AsciiSource(ref line_lengths) => {
+ let start_pos = (&line_lengths[0..(span.line - 1) as usize]).into_iter().sum::<usize>() + (span.column) as usize;
+ let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]).into_iter().sum::<usize>() + (span.end_column) as usize;
+ ByteRange { file_id: source_name, start_pos, end_pos }
+ },
+ _ => panic!()
+ }
+ }
}
struct SpanToByteRange {
@@ -62,41 +89,22 @@ impl SpanToByteRange {
self.map.insert(source_name, FileSpanToByteRange::new(source));
}
}
-}
-
-struct ByteRange<FileId> {
- file_id: FileId,
- start_pos: usize,
- end_pos: usize,
-}
-
-
-pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
- let severity = match diag.level {
- DiagnosticLevel::Error => Severity::Error,
- DiagnosticLevel::Warning => Severity::Warning,
- DiagnosticLevel::Note => Severity::Note,
- DiagnosticLevel::Help => Severity::Help,
- DiagnosticLevel::Bug => Severity::Bug,
- };
-
- let file_id = "foo".into(); // diag.span.source_name;
- let message: String = diag.message.as_str().unwrap().into();
- let inner_message: String = "expected `String`, found `Nat`".into();
- let diagnostic = CDiagnostic::new(severity)
- .with_message(message)
- .with_code("EXXX")
- .with_labels(vec![
- Label::primary(file_id, 328..331).with_message(inner_message)
- ]);
+ pub fn lookup(&self, span: &Span) -> ByteRange<String> {
+ let source_name: String = span.source_name.name.as_str().expect("foo").into();
- diagnostic
+ match self.map.get(&source_name) {
+ Some(file_span_to_bytes) => file_span_to_bytes.lookup(span),
+ None => panic!(),
+ }
+ }
}
struct DiagnosticState {
files: SimpleFiles<String, String>,
span_map: SpanToByteRange,
+ // todo unify wih source name
+ source_to_id: HashMap<String, usize>,
}
impl DiagnosticState {
@@ -104,26 +112,70 @@ impl DiagnosticState {
DiagnosticState {
files: SimpleFiles::new(),
span_map: SpanToByteRange::new(),
+ source_to_id: HashMap::new(),
}
}
+
+ fn add_source(&mut self, source: Source) {
+ let source_str: String = source.source.as_str().unwrap().into();
+ let source_name: String = source.source_name.name.as_str().unwrap().into();
+ self.span_map.add_source(source);
+ let file_id = self.files.add(source_name.clone(), source_str);
+ self.source_to_id.insert(source_name, file_id);
+ }
+
+ fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic<usize> {
+ let severity = match diag.level {
+ DiagnosticLevel::Error => Severity::Error,
+ DiagnosticLevel::Warning => Severity::Warning,
+ DiagnosticLevel::Note => Severity::Note,
+ DiagnosticLevel::Help => Severity::Help,
+ DiagnosticLevel::Bug => Severity::Bug,
+ };
+
+ let source_name: String = diag.span.source_name.name.as_str().unwrap().into();
+ let file_id = *self.source_to_id.get(&source_name).unwrap();
+
+ let message: String = diag.message.as_str().unwrap().into();
+
+ let byte_range = self.span_map.lookup(&diag.span);
+
+ let diagnostic = CDiagnostic::new(severity)
+ .with_message(message)
+ .with_code("EXXX")
+ .with_labels(vec![
+ Label::primary(file_id, byte_range.start_pos..byte_range.end_pos)
+ ]);
+
+ diagnostic
+ }
}
fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
let source_map = diag_ctx.module.source_map.clone();
- for diagnostic in diag_ctx.diagnostics.clone() {
- match source_map.source_map.get(&diagnostic.span.source_name) {
- Err(err) => panic!(),
- Ok(source) => state.span_map.add_source(source),
+ let writer = StandardStream::stderr(ColorChoice::Always);
+ let config = codespan_reporting::term::Config::default();
+ for diagnostic in diag_ctx.diagnostics.clone() {
+ match source_map.source_map.get(&diagnostic.span.source_name) {
+ Err(err) => panic!("failed to find the source for a diagnostic's span: {:?}", err),
+ Ok(source) => {
+ state.add_source(source);
+ let diagnostic = state.to_diagnostic(diagnostic);
+ term::emit(
+ &mut writer.lock(),
+ &config,
+ &state.files,
+ &diagnostic).unwrap();
}
- println!("Diagnostic: {}", diagnostic.message);
}
+ }
}
pub fn init() -> Result<()> {
let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
let render_fn = move |diag_ctx: DiagnosticContext| {
- // let mut guard = diag_state.lock().unwrap();
- // renderer(&mut *guard, diag_ctx);
+ let mut guard = diag_state.lock().unwrap();
+ renderer(&mut *guard, diag_ctx);
};
override_renderer(Some(render_fn))?;
[incubator-tvm] 22/23: WIP
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit e8bb83d33e7ad44ff16c72d9de61c2de722d12a8
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Thu Oct 29 15:31:05 2020 -0700
WIP
---
rust/tvm/src/python.rs | 34 ++++++----------------------------
1 file changed, 6 insertions(+), 28 deletions(-)
diff --git a/rust/tvm/src/python.rs b/rust/tvm/src/python.rs
index 2b2d374..50ce7b0 100644
--- a/rust/tvm/src/python.rs
+++ b/rust/tvm/src/python.rs
@@ -20,20 +20,6 @@
use pyo3::prelude::*;
use once_cell::sync::OnceCell;
-// static TVM_PYTHON: OnceCell<Py<PyModule>> = OnceCell::new();
-
-// fn initialize() -> Py<PyModule> {
-// TVM_PYTHON.get_or_init(|| {
-// let gil = Python::acquire_gil();
-// let py = gil.python();
-// PyModule::new(py, "__tvm__rust__module__").map_err(|e| {
-// // We can't display Python exceptions via std::fmt::Display,
-// // so print the error here manually.
-// e.print_and_set_sys_last_vars(py);
-// }).expect("failed to initialize the Python interface").into()
-// }).clone()
-// }
-
/// Load the Python interpreter into the address space.
///
/// This enables the ability for Rust code to call TVM
@@ -53,27 +39,19 @@ pub fn load() -> Result<String, ()> {
})
}
-fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
- let imported_mod = py.import(to_import)?;
- PyModule::from_code(py,
- r#"
-import tvm
-from tvm import relay
-tvm.cleanup()
-"#, "blah", "my_mod")?;
- // py_mod.add(to_import, imported_mod)?;
- Ok(imported_mod)
-}
-
pub fn import(mod_to_import: &str) -> PyResult<()> {
let gil = Python::acquire_gil();
let py = gil.python();
- // let main_mod = initialize();
- // let main_mod = main_mod.as_ref(py);
import_python(py, mod_to_import)?;
Ok(())
}
+fn import_python<'p, 'b: 'p>(py: Python<'p>, to_import: &'b str) -> PyResult<&'p PyModule> {
+ let imported_mod = py.import(to_import)?;
+ Ok(imported_mod)
+}
+
+
fn load_python_tvm_(py: Python) -> PyResult<String> {
let imported_mod = import_python(py, "tvm")?;
let version: String = imported_mod.get("__version__")?.extract()?;
[incubator-tvm] 07/23: Fix Linux build
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit e0f980142c3c0ab795de316943079a651743d8d7
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 01:42:14 2020 -0700
Fix Linux build
---
cmake/modules/LLVM.cmake | 7 ++++++-
rust/tvm-sys/Cargo.toml | 2 +-
rust/tvm-sys/build.rs | 3 +--
3 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake
index 5f8ace1..ca4ecd6 100644
--- a/cmake/modules/LLVM.cmake
+++ b/cmake/modules/LLVM.cmake
@@ -16,7 +16,12 @@
# under the License.
# LLVM rules
-add_definitions(-DDMLC_USE_FOPEN64=0)
+# Due to LLVM debug symbols you can sometimes face linking issues on
+# certain compiler/platform combinations if you don't set NDEBUG.
+#
+# See https://github.com/imageworks/OpenShadingLanguage/issues/1069
+# for more discussion.
+add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1)
# Test if ${USE_LLVM} is not an explicit boolean false
# It may be a boolean or a string
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
index c25a5bf..2952aa4 100644
--- a/rust/tvm-sys/Cargo.toml
+++ b/rust/tvm-sys/Cargo.toml
@@ -23,7 +23,7 @@ license = "Apache-2.0"
edition = "2018"
[features]
-default = ["bindings"]
+default = []
bindings = []
[dependencies]
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 2d86c4b..1590234 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -60,8 +60,7 @@ fn main() -> Result<()> {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm");
- println!("cargo:rustc-link-lib=dylib=llvm-10");
- println!("cargo:rustc-link-search={}/build", tvm_home);
+ println!("cargo:rustc-link-search=native={}/build", tvm_home);
}
// @see rust-bindgen#550 for `blacklist_type`
[incubator-tvm] 05/23: Borrow code from Egg
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit b2b59c229e9b8c2002d8c8cd520748df6b38e074
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 15:26:54 2020 -0700
Borrow code from Egg
---
rust/compiler-ext/src/lib.rs | 344 ++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 337 insertions(+), 7 deletions(-)
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 31e1bb2..58bdd0c 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -1,7 +1,337 @@
-#[cfg(test)]
-mod tests {
- #[test]
- fn it_works() {
- assert_eq!(2 + 2, 4);
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ use std::os::raw::c_int;
+ use tvm::initialize;
+ use tvm::ir::{tir, PrimExpr};
+ use tvm::runtime::function::register_override;
+ use tvm::runtime::map::Map;
+ use tvm::runtime::object::{IsObject, IsObjectRef};
+
+ use ordered_float::NotNan;
+
+ mod interval;
+ mod math;
+
+ use math::{BoundsMap, Expr, RecExpr};
+ use tvm::ir::arith::ConstIntBound;
+ use tvm_rt::{ObjectRef, array::Array};
+
+ macro_rules! downcast_match {
+ ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
+ $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
+ { $default }
+ }
+ }
+
+ #[derive(Default)]
+ struct VarMap {
+ vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>,
+ objs: Vec<ObjectRef>,
+ }
+
+ impl VarMap {
+ // FIXME this should eventually do the right thing for TVM variables
+ // right now it depends on them having unique names
+ fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
+ let sym = egg::Symbol::from(var.name_hint.as_str().unwrap());
+ for (_, sym2) in &self.vars {
+ if sym == *sym2 {
+ return sym;
+ }
+ }
+
+ self.vars.push((var, sym));
+ sym
+ }
+
+ fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
+ for (v, sym2) in &self.vars {
+ if sym == *sym2 {
+ return v.clone();
+ }
+ }
+ panic!("Should have found a var")
+ }
+
+ fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
+ let i = self.objs.len();
+ self.objs.push(obj.upcast());
+ i
+ }
+
+ fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
+ self.objs[i].clone().downcast().expect("bad downcast")
+ }
+ }
+
+ fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
+ fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
+ macro_rules! r {
+ ($e:expr) => {
+ build(vars, &$e, recexpr)
+ };
+ }
+
+ let dt = recexpr.add(Expr::DataType(p.datatype));
+ let e = downcast_match!(p; {
+ tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
+ tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
+ tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
+
+ tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
+ tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
+ tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
+ tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
+
+ tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
+ tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
+
+ tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
+ tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
+
+ tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
+ tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
+ tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
+ tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
+ tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
+ tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
+
+ tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
+ tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
+ tir::Not => Expr::Not([dt, r!(p.value)]),
+
+ tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
+
+ tir::Let => {
+ let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+ Expr::Let([dt, sym, r!(p.value), r!(p.body)])
+ }
+ tir::Var => {
+ let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
+ Expr::Var([dt, sym])
+ }
+ tir::IntImm => {
+ let int = recexpr.add(Expr::Int(p.value));
+ Expr::IntImm([dt, int])
+ }
+ tir::FloatImm => {
+ let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap()));
+ Expr::FloatImm([dt, float])
+ }
+ tir::Cast => Expr::Cast([dt, r!(p.value)]),
+
+ tir::Call => {
+ let op = vars.push_obj(p.op.clone());
+ let mut arg_ids = vec![dt];
+ for i in 0..p.args.len() {
+ let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
+ arg_ids.push(r!(arg));
+ }
+ Expr::Call(op, arg_ids)
+ },
+ tir::Load => {
+ let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
+ Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
+ },
+ else => {
+ println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap());
+ Expr::Object(vars.push_obj(p.clone()))
+ }
+ });
+
+ recexpr.add(e)
+ }
+
+ let mut recexpr = Default::default();
+ build(vars, prim, &mut recexpr);
+ recexpr
+ }
+
+ fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
+ fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
+ let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]);
+ let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
+ let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
+ Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] {
+ Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+ n => panic!("Expected a symbol, got {:?}", n),
+ },
+ Expr::IntImm([dt, v]) => {
+ let value = nodes[usize::from(*v)].to_int().unwrap();
+ tir::IntImm::new(get_dt(dt), value).upcast()
+ }
+ Expr::FloatImm([dt, v]) => {
+ let value = nodes[usize::from(*v)].to_float().unwrap();
+ tir::FloatImm::new(get_dt(dt), value).upcast()
+ }
+ Expr::Let([dt, s, value, body]) => {
+ let var = match &nodes[usize::from(*s)] {
+ Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+ n => panic!("Expected a symbol, got {:?}", n),
+ };
+ tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast()
+ }
+ Expr::Load([dt, s, index, predicate]) => {
+ let var = match &nodes[usize::from(*s)] {
+ Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
+ n => panic!("Expected a symbol, got {:?}", n),
+ };
+ tir::Load::new(get_dt(dt), var, go(index), go(predicate)).upcast()
+ }
+ // Each node stores its dtype id first, followed by its operands in TIR field order (see to_egg).
+ Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
+
+ Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
+
+ Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
+
+ Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
+
+ Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
+ Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
+
+ Expr::Ramp([dt, a, b, c]) => {
+ let len = &nodes[usize::from(*c)];
+ let i = len
+ .to_int()
+ .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len));
+ tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast()
+ }
+ Expr::Broadcast([dt, val, lanes]) => {
+ let lanes = &nodes[usize::from(*lanes)];
+ let lanes = lanes
+ .to_int()
+ .unwrap_or_else(|| panic!("Broadcast lanes must be an int, got {:?}", lanes));
+ log::debug!("dt: {}", get_dt(dt));
+ tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast()
+ }
+
+ Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
+ Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
+ Expr::Call(expr, args) => {
+ let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
+ let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
+ tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast()
+ }
+ // Opaque TVM objects round-trip through the side table in `vars` (see VarMap::push_obj).
+ Expr::Object(i) => vars.get_obj(*i),
+ node => panic!("I don't know how to extract {:?}", node),
+ };
+ assert_ne!(prim.datatype.bits(), 0);
+ assert_ne!(prim.datatype.lanes(), 0);
+ prim
+ }
+ build(vars, recexpr.as_ref())
+ }
+
+ fn run(
+ input: PrimExpr,
+ expected: Option<PrimExpr>,
+ map: Map<PrimExpr, ConstIntBound>,
+ ) -> Result<PrimExpr, String> {
+ use egg::{CostFunction, Extractor};
+
+ let mut bounds = BoundsMap::default();
+ for (k, v) in map {
+ if let Ok(var) = k.downcast_clone::<tir::Var>() {
+ let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
+ bounds.insert(sym, (v.min_value, v.max_value));
+ } else {
+ println!("Non var in bounds map: {}", tvm::ir::as_text(k));
+ }
+ }
+
+ let mut vars = VarMap::default();
+ let expr = to_egg(&mut vars, &input);
+ let mut runner = math::default_runner();
+ runner.egraph.analysis.bounds = bounds;
+
+ let mut runner = runner.with_expr(&expr).run(&math::rules());
+ // runner.print_report();
+ let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
+ let root = runner.egraph.find(runner.roots[0]);
+ let (cost, best) = extractor.find_best(root);
+ if let Some(expected) = expected {
+ let mut expected_vars = VarMap::default();
+ let expected_expr = to_egg(&mut expected_vars, &expected);
+ let expected_root = runner.egraph.add_expr(&expected_expr);
+ if expected_root != root {
+ return Err(format!(
+ "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
+ expected_expr.pretty(40),
+ best.pretty(40)
+ ));
+ }
+ let expected_cost = math::CostFn.cost_rec(&expected_expr);
+ if expected_cost != cost {
+ let msg = format!(
+ "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
+ expected_cost,
+ expected_expr.pretty(40),
+ cost,
+ best.pretty(40)
+ );
+ if cost < expected_cost {
+ println!("egg wins: {}", msg)
+ } else {
+ return Err(msg);
+ }
+ }
+ }
+ log::info!(" returning... {}", best.pretty(60));
+ Ok(from_egg(&vars, &best))
+ }
+
+ fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
+ log::debug!("map: {:?}", map);
+ run(prim, None, map).map_err(tvm::Error::CallFailed)
+ }
+
+ fn simplify_and_check(
+ prim: PrimExpr,
+ check: PrimExpr,
+ map: Map<PrimExpr, ConstIntBound>,
+ ) -> Result<PrimExpr, tvm::Error> {
+ log::debug!("check map: {:?}", map);
+ run(prim, Some(check), map).map_err(tvm::Error::CallFailed)
+ }
+
+ initialize!({
+ let _ = env_logger::try_init();
+ // NOTE this print prevents a segfault (on Linux) for now...
+ println!("Initializing simplifier... ");
+ register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
+ register_override(simplify_and_check, "egg.simplify_and_check", true)
+ .expect("failed to initialize simplifier");
+ log::debug!("done!");
+ });
+
\ No newline at end of file
[incubator-tvm] 20/23: WIP
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit b8dcc35801e7f0f3c696e0d030b3bd0931127785
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 22:18:58 2020 -0700
WIP
---
rust/tvm-macros/src/external.rs | 2 +-
rust/tvm-macros/src/lib.rs | 3 +-
rust/tvm-macros/src/object.rs | 23 ++++++
rust/tvm-rt/src/object/mod.rs | 9 +--
rust/tvm-rt/src/object/object_ptr.rs | 16 ++++
rust/tvm-rt/src/string.rs | 1 +
rust/tvm-rt/src/value.rs | 1 -
rust/tvm-sys/src/datatype.rs | 4 +
rust/tvm/src/ir/module.rs | 152 ++++++++++++++++++++++++++++++++---
rust/tvm/src/ir/relay/mod.rs | 36 +++------
rust/tvm/src/ir/tir.rs | 14 ++++
11 files changed, 218 insertions(+), 43 deletions(-)
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 44a242c..51a389b 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,7 +21,7 @@ use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
-use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, Type};
struct ExternalItem {
attrs: Vec<Attribute>,
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index 32f2839..e563a57 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -30,7 +30,8 @@ pub fn import_module(input: TokenStream) -> TokenStream {
import_module::macro_impl(input)
}
-#[proc_macro_derive(Object, attributes(base, ref_name, type_key))]
+#[proc_macro_error]
+#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))]
pub fn macro_impl(input: TokenStream) -> TokenStream {
// let input = proc_macro2::TokenStream::from(input);
TokenStream::from(object::macro_impl(input))
diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs
index ff72d6a..7e6a934 100644
--- a/rust/tvm-macros/src/object.rs
+++ b/rust/tvm-macros/src/object.rs
@@ -36,6 +36,8 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
.map(attr_to_str)
.expect("Failed to get type_key");
+ let derive = get_attr(&derive_input, "no_derive").map(|_| false).unwrap_or(true);
+
let ref_id = get_attr(&derive_input, "ref_name")
.map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site()))
.unwrap_or_else(|| {
@@ -185,5 +187,26 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
expanded.extend(base_tokens);
+ if derive {
+ let derives = quote! {
+ impl std::hash::Hash for #ref_id {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ self.0.hash(state)
+ }
+ }
+
+ impl std::cmp::PartialEq for #ref_id {
+ fn eq(&self, other: &Self) -> bool {
+ self.0 == other.0
+ }
+ }
+
+ impl std::cmp::Eq for #ref_id {}
+ };
+
+
+ expanded.extend(derives);
+ }
+
TokenStream::from(expanded)
}
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index e48c017..7e6107d 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -90,12 +90,7 @@ external! {
#[name("ir.DebugPrint")]
pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
- fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
+ fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
#[name("node.StructuralEqual")]
- fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> ObjectRef;
+ fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
}
-
-// external! {
-// #[name("ir.TextPrinter")]
-// fn as_text(object: ObjectRef) -> CString;
-// }
diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs
index 77254d2..a923506 100644
--- a/rust/tvm-rt/src/object/object_ptr.rs
+++ b/rust/tvm-rt/src/object/object_ptr.rs
@@ -342,6 +342,22 @@ impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
}
}
+impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write_i64(super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap())
+ }
+}
+
+impl<T: IsObject> PartialEq for ObjectPtr<T> {
+ fn eq(&self, other: &Self) -> bool {
+ let lhs = ObjectRef(Some(self.clone().upcast()));
+ let rhs = ObjectRef(Some(other.clone().upcast()));
+ super::structural_equal(lhs, rhs, false, false).unwrap()
+ }
+}
+
+impl<T: IsObject> Eq for ObjectPtr<T> {}
+
#[cfg(test)]
mod tests {
use super::{Object, ObjectPtr};
diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs
index 6ff24be..e9a76d2 100644
--- a/rust/tvm-rt/src/string.rs
+++ b/rust/tvm-rt/src/string.rs
@@ -28,6 +28,7 @@ use tvm_macros::Object;
#[derive(Object)]
#[ref_name = "String"]
#[type_key = "runtime.String"]
+#[no_derive]
pub struct StringObj {
base: Object,
data: *const u8,
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
index c49944d..b8cd190 100644
--- a/rust/tvm-rt/src/value.rs
+++ b/rust/tvm-rt/src/value.rs
@@ -22,7 +22,6 @@
//! `RetValue` is the owned version of `TVMPODValue`.
use std::convert::TryFrom;
-// use std::ffi::c_void;
use crate::{ArgValue, Module, RetValue};
use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast};
diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs
index 8050d93..5f7e0c3 100644
--- a/rust/tvm-sys/src/datatype.rs
+++ b/rust/tvm-sys/src/datatype.rs
@@ -83,6 +83,10 @@ impl DataType {
DataType::new(DL_FLOAT_CODE, bits, lanes)
}
+ pub const fn float32() -> DataType {
+ Self::float(32, 1)
+ }
+
pub const fn uint(bits: u8, lanes: u16) -> DataType {
DataType::new(DL_UINT_CODE, bits, lanes)
}
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 3b60b0c..db32ce2 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -17,6 +17,7 @@
* under the License.
*/
use std::io::Result as IOResult;
+use std::iter::FromIterator;
use std::path::Path;
use thiserror::Error;
@@ -33,8 +34,9 @@ use super::function::BaseFunc;
use super::source_map::SourceMap;
use super::{ty::GlobalTypeVar, relay};
-// TODO(@jroesch): define type
+use tvm_macros::Object;
+// TODO(@jroesch): define type
type TypeData = ObjectRef;
type GlobalTypeVar = ObjectRef;
@@ -64,9 +66,11 @@ external! {
fn parse_module(file_name: TVMString, source: TVMString) -> IRModule;
#[name("parser.ParseExpr")]
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
+ #[name("ir.IRModule")]
+ fn module_new(funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
// Module methods
#[name("ir.Module_Add")]
- fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+ fn module_add(module: IRModule, type_name: GlobalVar, expr: BaseFunc, update: bool) -> IRModule;
#[name("ir.Module_AddDef")]
fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
#[name("ir.Module_GetGlobalVar")]
@@ -78,15 +82,15 @@ external! {
#[name("ir.Module_Lookup_str")]
fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc;
#[name("ir.Module_GetGlobalTypeVars")]
- fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
+ fn module_get_global_type_vars(module: IRModule) -> Array<GlobalTypeVar>;
#[name("ir.Module_ContainGlobalVar")]
- fn module_contains_global_var(name: TVMString) -> bool;
+ fn module_contains_global_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_ContainGlobalTypeVar")]
- fn module_contains_global_type_var(name: TVMString) -> bool;
+ fn module_contains_global_type_var(module: IRModule, name: TVMString) -> bool;
#[name("ir.Module_LookupDef")]
- fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+ fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeData;
#[name("ir.Module_LookupDef_str")]
- fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
+ fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeData;
#[name("ir.Module_LookupTag")]
fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
#[name("ir.Module_FromExpr")]
@@ -99,8 +103,12 @@ external! {
// Note: we don't expose update here as update is going to be removed.
-
impl IRModule {
+ pub fn new<F, T>(funcs: F, types: T) -> Result<IRModule>
+ where F: IntoIterator<Item=(GlobalVar, BaseFunc)>, T: IntoIterator<Item=(GlobalTypeVar, TypeData)> {
+ module_new(Map::from_iter(funcs), Map::from_iter(types))
+ }
+
pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>
where
N: Into<TVMString>,
@@ -119,6 +127,13 @@ impl IRModule {
Ok(module)
}
+ pub fn add(
+ &mut self,
+ var: GlobalVar,
+ func: BaseFunc) -> Result<IRModule> {
+ module_add(self.clone(), var, func, true)
+ }
+
pub fn add_def(
&mut self,
type_name: GlobalTypeVar,
@@ -146,10 +161,127 @@ impl IRModule {
{
module_lookup_str(self.clone(), name.into())
}
+
+ pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
+ module_get_global_type_vars(self.clone())
+ }
+
+ pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
+ module_contains_global_var(self.clone(), name.into())
+ }
+
+ pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
+ module_contains_global_type_var(self.clone(), name.into())
+ }
+
+ pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
+ module_lookup_def(self.clone(), global)
+ }
+
+ pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
+ module_lookup_def_str(self.clone(), global)
+ }
+
+ pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
+ module_lookup_tag(self.clone(), tag)
+ }
+
+ pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
+ module_from_expr(expr, funcs, types)
+ }
+
+ pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
+ module_import(self.clone(), path.into())
+ }
+
+ pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
+ module_import_from_std(self.clone(), path.into())
+ }
}
#[cfg(test)]
mod tests {
- // #[test]
- // fn
+ use std::collections::HashMap;
+ use super::relay::*;
+ use super::*;
+ use super::super::span::Span;
+ use tvm_rt::IsObjectRef;
+
+ #[test]
+ fn test_module_add() -> anyhow::Result<()> {
+ let funcs = HashMap::<GlobalVar, BaseFunc>::new();
+ let types = HashMap::<GlobalTypeVar, TypeData>::new();
+ let mut module = IRModule::new(funcs, types)?;
+ let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
+ let params = Array::from_vec(vec![x.clone()])?;
+ let func = relay::Function::simple(params, x.upcast()).upcast();
+ let module = module.add(GlobalVar::new("foo".into(), Span::null()), func)?;
+ // let lfunc = module.lookup_str("foo")?;
+ // let lfunc = lfunc.downcast::<relay::Function>()?;
+ // assert_eq!(lfunc.params.len(), 1);
+ Ok(())
+ }
+
+ #[test]
+ fn test_module_add_def() {
+
+ }
+
+ #[test]
+ fn test_get_global_var() {
+
+ }
+
+ #[test]
+ fn test_get_global_vars() {
+
+ }
+
+ #[test]
+ fn test_lookup() {
+
+ }
+
+
+ // pub fn get_global_type_vars(&self) -> Result<Array<GlobalTypeVar>> {
+ // module_get_global_type_vars(self.clone())
+ // }
+
+ // pub fn contains_global_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
+ // module_contains_global_var(self.clone(), name.into())
+ // }
+
+ // pub fn contains_global_type_var<S: Into<TVMString>>(&self, name: S) -> Result<bool> {
+ // module_contains_global_type_var(self.clone(), name.into())
+ // }
+
+ #[test]
+ fn test_lookup_def() {
+
+ }
+ // pub fn lookup_def(&self, global: GlobalTypeVar) -> Result<TypeData> {
+ // module_lookup_def(self.clone(), global)
+ // }
+
+ // pub fn lookup_def_str(&self, global: GlobalTypeVar) -> Result<TypeData> {
+ // module_lookup_def_str(self.clone(), global)
+ // }
+
+ // pub fn lookup_tag(&self, tag: i32) -> Result<relay::Constructor> {
+ // module_lookup_tag(self.clone(), tag)
+ // }
+
+ // pub fn from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> Result<IRModule> {
+ // module_from_expr(expr, funcs, types)
+ // }
+
+
+ // pub fn import<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
+ // module_import(self.clone(), path.into())
+ // }
+
+
+ // pub fn import_from_std<S: Into<TVMString>>(&mut self, path: S) -> Result<()> {
+ // module_import_from_std(self.clone(), path.into())
+ // }
}
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index 530b120..90b7a6a 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -16,11 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-pub mod attrs;
-
-use std::hash::Hash;
-
use crate::runtime::array::Array;
use crate::runtime::{object::*, String as TString};
@@ -29,11 +24,15 @@ use super::expr::BaseExprNode;
use super::function::BaseFuncNode;
use super::span::Span;
use super::ty::{Type, TypeNode};
+use super::span::Span;
use tvm_macros::Object;
use tvm_rt::NDArray;
pub use super::expr::{GlobalVar, GlobalVarNode};
+pub use crate::runtime::DataType;
+
+pub mod attrs;
#[repr(C)]
#[derive(Object)]
@@ -58,20 +57,6 @@ impl ExprNode {
}
}
-impl Hash for Expr {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.as_ptr().unwrap().ptr.hash(state)
- }
-}
-
-impl PartialEq for Expr {
- fn eq(&self, other: &Self) -> bool {
- self.as_ptr().unwrap().ptr.eq(&other.as_ptr().unwrap().ptr)
- }
-}
-
-impl Eq for Expr {}
-
#[repr(C)]
#[derive(Object)]
#[ref_name = "Id"]
@@ -140,11 +125,11 @@ pub struct VarNode {
}
impl Var {
- pub fn new(name_hint: String, type_annotation: Type, _span: ObjectRef) -> Var {
+ pub fn new(name_hint: String, type_annotation: Type, _span: Span) -> Var {
let node = VarNode {
base: ExprNode::base::<VarNode>(),
vid: Id::new(name_hint.into()),
- type_annotation,
+ type_annotation: type_annotation,
};
Var(Some(ObjectPtr::new(node)))
}
@@ -153,8 +138,9 @@ impl Var {
&self.vid.0.as_ref().unwrap().name_hint
}
- pub fn to_expr(self) -> Expr {
- unsafe { Expr(std::mem::transmute(self.0)) }
+    /// Create a variable named `name_hint` whose type annotation is a
+    /// `TensorType` with shape `sh` and element type `dtype`; both spans are
+    /// `Span::null()`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `Array::from_vec` fails for the converted shape.
+    pub fn static_tensor(name_hint: String, sh: Vec<i32>, dtype: DataType) -> Var {
+        let shape = Array::from_vec(sh.into_iter().map(Into::into).collect()).unwrap();
+        let tensor_type = super::ty::TensorType::new(shape, dtype, Span::null());
+        Var::new(name_hint, tensor_type.upcast(), Span::null())
}
}
@@ -510,6 +496,10 @@ impl Function {
};
Function(Some(ObjectPtr::new(node)))
}
+
+    /// Build a `Function` from `params` and `body`, passing `Type::null()` and
+    /// an empty array for the remaining `Function::new` arguments.
+    /// NOTE(review): presumably "no declared return type, no type params" — confirm.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `Array::from_vec` fails for the empty array.
+    pub fn simple(params: Array<Var>, body: Expr) -> Function {
+        let empty = Array::from_vec(Vec::new()).unwrap();
+        Function::new(params, body, Type::null(), empty)
+    }
}
#[cfg(test)]
diff --git a/rust/tvm/src/ir/tir.rs b/rust/tvm/src/ir/tir.rs
index 22d4e02..f07e854 100644
--- a/rust/tvm/src/ir/tir.rs
+++ b/rust/tvm/src/ir/tir.rs
@@ -47,6 +47,20 @@ macro_rules! define_node {
// TODO(@jroesch): should move up to expr.rs to mirror TVM.
define_node!(IntImm, "IntImm", "IntImm";
IntImmNode { value: i64 });
+
+impl From<i32> for IntImm {
+    /// Wrap `value` as an `IntImm` with dtype `DataType::int(32, 1)`,
+    /// widening it losslessly to the node's `i64` storage.
+    fn from(value: i32) -> IntImm {
+        let dtype = DataType::int(32, 1);
+        let widened = i64::from(value);
+        IntImm::new(dtype, widened)
+    }
+}
+
+impl From<i32> for PrimExpr {
+    /// Lift a 32-bit integer into a `PrimExpr` by way of `IntImm`.
+    fn from(value: i32) -> PrimExpr {
+        use crate::runtime::IsObjectRef;
+        let imm: IntImm = value.into();
+        imm.upcast()
+    }
+}
+
define_node!(Var, "Var", "tir.Var";
VarNode { name_hint: TVMString });
[incubator-tvm] 01/23: Add initial boilerplate for Rust diagnostic
interface.
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 1097cbf8a23708b7408dfbab7c419e363af57728
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 9 01:18:15 2020 -0700
Add initial boilerplate for Rust diagnostic interface.
---
python/tvm/ir/diagnostics/__init__.py | 2 +-
rust/tvm/src/ir/diagnostics.rs | 239 ++++++++++++++++++++++++++++++++++
rust/tvm/src/ir/mod.rs | 1 +
3 files changed, 241 insertions(+), 1 deletion(-)
diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py
index 6503743..0ad2a7a 100644
--- a/python/tvm/ir/diagnostics/__init__.py
+++ b/python/tvm/ir/diagnostics/__init__.py
@@ -37,7 +37,7 @@ def get_renderer():
"""
return _ffi_api.GetRenderer()
-
+@tvm.register_func("diagnostics.override_renderer")
def override_renderer(render_func):
"""
Sets a custom renderer for diagnostics.
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
new file mode 100644
index 0000000..799a10c
--- /dev/null
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! The diagnostic interface to TVM, used by the compiler to report and
+//! render diagnostic information. This module exposes three key
+//! abstractions: a Diagnostic, the DiagnosticContext, and the
+//! DiagnosticRenderer.
+
+use tvm_macros::{Object, external};
+use super::module::IRModule;
+use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString};
+use crate::runtime::object::{Object, ObjectRef};
+use crate::runtime::function::Result;
+use super::span::Span;
+
+type SourceName = ObjectRef;
+
+/// The diagnostic level, which controls how the message is printed.
+///
+/// The explicit discriminants are part of the layout (`#[repr(C)]`) used by
+/// `DiagnosticNode`. NOTE(review): presumably they mirror the C++
+/// `DiagnosticLevel` values; do not renumber without confirming.
+#[repr(C)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum DiagnosticLevel {
+    Bug = 10,
+    Error = 20,
+    Warning = 30,
+    Note = 40,
+    Help = 50,
+}
+
+/// A compiler diagnostic.
+///
+/// This is the node type behind the `Diagnostic` object reference
+/// (`ref_name`), registered under the runtime type key `"Diagnostic"`.
+/// NOTE(review): `#[repr(C)]` presumably mirrors the C++ `DiagnosticNode`
+/// layout — confirm the field order matches.
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "Diagnostic"]
+#[type_key = "Diagnostic"]
+pub struct DiagnosticNode {
+ /// The base object header.
+ pub base: Object,
+ /// The level.
+ pub level: DiagnosticLevel,
+ /// The span at which to report an error.
+ pub span: Span,
+ /// The diagnostic message.
+ pub message: TString,
+}
+
+impl Diagnostic {
+    /// Create a diagnostic of severity `level` reporting `message` at `span`.
+    pub fn new(level: DiagnosticLevel, span: Span, message: TString) -> Diagnostic {
+        let node = DiagnosticNode {
+            base: Object::base::<DiagnosticNode>(),
+            level,
+            span,
+            message,
+        };
+        Diagnostic(Some(crate::runtime::object::ObjectPtr::new(node)))
+    }
+
+    /// Start building a `Bug`-level diagnostic at `span`.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn bug(span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building an `Error`-level diagnostic at `span`.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn error(span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Warning`-level diagnostic at `span`.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn warning(span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Note`-level diagnostic at `span`.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn note(span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+
+    /// Start building a `Help`-level diagnostic at `span`.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn help(span: Span) -> DiagnosticBuilder {
+        todo!()
+    }
+}
+
+/// Builder state for a diagnostic: its level, source name, and span.
+///
+/// NOTE(review): modeled on the C++ `DiagnosticBuilder`, which also
+/// accumulates the message text in a stream; this struct has no message
+/// field yet.
+pub struct DiagnosticBuilder {
+ /// The severity level.
+ pub level: DiagnosticLevel,
+
+ /// The source name (an untyped `ObjectRef` via the `SourceName` alias).
+ pub source_name: SourceName,
+
+ /// The span of the diagnostic.
+ pub span: Span,
+}
+
+// /*! \brief Display diagnostics in a given display format.
+// *
+// * A diagnostic renderer is responsible for converting the
+// * raw diagnostics into consumable output.
+// *
+// * For example the terminal renderer will render a sequence
+// * of compiler diagnostics to std::out and std::err in
+// * a human readable form.
+// */
+// class DiagnosticRendererNode : public Object {
+// public:
+// TypedPackedFunc<void(DiagnosticContext ctx)> renderer;
+
+// // override attr visitor
+// void VisitAttrs(AttrVisitor* v) {}
+
+// static constexpr const char* _type_key = "DiagnosticRenderer";
+// TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
+// };
+
+// class DiagnosticRenderer : public ObjectRef {
+// public:
+// TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
+// TVM_DLL DiagnosticRenderer()
+// : DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}
+
+// void Render(const DiagnosticContext& ctx);
+
+// DiagnosticRendererNode* operator->() {
+// CHECK(get() != nullptr);
+// return static_cast<DiagnosticRendererNode*>(get_mutable());
+// }
+
+// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
+// };
+
+// @tvm._ffi.register_object("DiagnosticRenderer")
+// class DiagnosticRenderer(Object):
+// """
+// A diagnostic renderer, which given a diagnostic context produces a "rendered"
+// form of the diagnostics for either human or computer consumption.
+// """
+
+// def __init__(self, render_func):
+// self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func)
+
+// def render(self, ctx):
+// """
+// Render the provided context.
+
+// Params
+// ------
+// ctx: DiagnosticContext
+// The diagnostic context to render.
+// """
+// return _ffi_api.DiagnosticRendererRender(self, ctx)
+/// Handle to a diagnostic renderer object.
+///
+/// NOTE(review): for now this is an untyped `ObjectRef` alias standing in for
+/// the typed `DiagnosticRenderer` sketched in the commented-out C++/Python.
+pub type DiagnosticRenderer = ObjectRef;
+
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "DiagnosticContext"]
+#[type_key = "DiagnosticContext"]
+/// Node behind the `DiagnosticContext` reference: it records diagnostics
+/// against an `IRModule` and holds the renderer used to display them.
+pub struct DiagnosticContextNode {
+ /// The base object header.
+ pub base: Object,
+
+ /// The Module to report against.
+ pub module: IRModule,
+
+ /// The set of diagnostics to report.
+ pub diagnostics: Array<Diagnostic>,
+
+ /// The renderer set for the context.
+ pub renderer: DiagnosticRenderer,
+}
+
+// FFI bindings into TVM's C++ diagnostics API.
+external! {
+    // Fetch the currently configured diagnostic renderer. The global lives in
+    // the `diagnostics` namespace; the Python wrapper calls it as
+    // `_ffi_api.GetRenderer()`.
+    #[name("diagnostics.GetRenderer")]
+    fn get_renderer() -> DiagnosticRenderer;
+
+    // Wrap a packed function as a diagnostic renderer.
+    #[name("diagnostics.DiagnosticRenderer")]
+    fn diagnostic_renderer(func: Function) -> DiagnosticRenderer;
+
+    // Emit `diagnostic` into `ctx`.
+    #[name("diagnostics.Emit")]
+    fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> ();
+
+    // Render the diagnostics held by `ctx`.
+    #[name("diagnostics.DiagnosticContextRender")]
+    fn diagnostic_context_render(ctx: DiagnosticContext) -> ();
+}
+
+/// A diagnostic context which records active errors
+/// and contains a renderer.
+impl DiagnosticContext {
+    /// Create a context that reports against `module` and renders with
+    /// `renderer`, starting with no recorded diagnostics.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `Array::from_vec` fails for the empty diagnostics list.
+    pub fn new(module: IRModule, renderer: DiagnosticRenderer) -> DiagnosticContext {
+        let node = DiagnosticContextNode {
+            base: Object::base::<DiagnosticContextNode>(),
+            module,
+            diagnostics: Array::from_vec(vec![])
+                .expect("failed to create an empty diagnostics array"),
+            renderer,
+        };
+        DiagnosticContext(Some(crate::runtime::object::ObjectPtr::new(node)))
+    }
+
+    /// Create a context for `module` using the default renderer.
+    ///
+    /// Not implemented yet; always panics via `todo!()`.
+    pub fn default(module: IRModule) -> DiagnosticContext {
+        todo!()
+    }
+
+    /// Emit a diagnostic.
+    pub fn emit(&mut self, diagnostic: Diagnostic) -> Result<()> {
+        emit(self.clone(), diagnostic)
+    }
+
+    /// Render the errors and raise a DiagnosticError exception.
+    pub fn render(&mut self) -> Result<()> {
+        diagnostic_context_render(self.clone())
+    }
+
+    /// Emit a diagnostic and then immediately attempt to render all errors.
+    pub fn emit_fatal(&mut self, diagnostic: Diagnostic) -> Result<()> {
+        self.emit(diagnostic)?;
+        self.render()?;
+        Ok(())
+    }
+}
+
+/// Set (or clear) the custom renderer used for diagnostics.
+///
+/// `opt_func`: when `Some`, the closure is meant to become the active
+/// renderer; when `None`, the current custom renderer is meant to be removed
+/// so rendering returns to the default behavior (mirroring the Python
+/// `override_renderer(render_func)` API).
+///
+/// Not implemented yet: the body is `todo!()` and always panics.
+fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
+where F: Fn(DiagnosticContext) -> ()
+{
+ todo!()
+ // Sketch of the Python implementation this should mirror:
+ // fn ()
+ // diagnostic_renderer(func)
+ // if render_func:
+
+ // def _render_factory():
+ // return DiagnosticRenderer(render_func)
+
+ // register_func("diagnostics.OverrideRenderer", _render_factory, override=True)
+ // else:
+ // _ffi_api.ClearRenderer()
+}
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 126d0fa..8450bd7 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -20,6 +20,7 @@
pub mod arith;
pub mod attrs;
pub mod expr;
+pub mod diagnostics;
pub mod function;
pub mod module;
pub mod op;
[incubator-tvm] 19/23: WIP
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit a9ee3cb34c020a4debe75fc9a194303f22d00892
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 22 11:48:34 2020 -0700
WIP
---
rust/tvm-macros/Cargo.toml | 2 +-
rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++++++++--------
rust/tvm-macros/src/lib.rs | 1 +
rust/tvm-rt/src/object/mod.rs | 2 +-
rust/tvm/src/ir/module.rs | 16 +++++++++++----
5 files changed, 50 insertions(+), 14 deletions(-)
diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml
index 63b8472..8e97d3b 100644
--- a/rust/tvm-macros/Cargo.toml
+++ b/rust/tvm-macros/Cargo.toml
@@ -33,5 +33,5 @@ proc-macro = true
goblin = "^0.2"
proc-macro2 = "^1.0"
quote = "^1.0"
-syn = { version = "1.0.17", features = ["full", "extra-traits"] }
+syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] }
proc-macro-error = "^1.0"
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index de8ada3..44a242c 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -21,9 +21,28 @@ use proc_macro_error::abort;
use quote::quote;
use syn::parse::{Parse, ParseStream, Result};
-use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type};
+
+/// One bodyless function declaration inside `external!`: its outer
+/// attributes (expected to hold the `#[name(...)]` binding), its visibility,
+/// and its signature.
+struct ExternalItem {
+ /// Outer attributes preceding the declaration.
+ attrs: Vec<Attribute>,
+ /// Visibility to give the generated wrapper (may be inherited/private).
+ visibility: Visibility,
+ /// The `fn` signature being bound.
+ sig: Signature,
+}
+
+impl Parse for ExternalItem {
+    /// Parse `#[attrs] vis fn name(args) -> ret;`, a signature with no body.
+    fn parse(input: ParseStream) -> Result<Self> {
+        let attrs = input.call(Attribute::parse_outer)?;
+        let visibility: Visibility = input.parse()?;
+        let sig: Signature = input.parse()?;
+        // The declaration must end in `;`; the token itself is discarded.
+        input.parse::<Semi>()?;
+        Ok(ExternalItem {
+            attrs,
+            visibility,
+            sig,
+        })
+    }
+}
struct External {
+ visibility: Visibility,
tvm_name: String,
ident: Ident,
generics: Generics,
@@ -33,7 +52,8 @@ struct External {
impl Parse for External {
fn parse(input: ParseStream) -> Result<Self> {
- let method: TraitItemMethod = input.parse()?;
+ let method: ExternalItem = input.parse()?;
+ let visibility = method.visibility;
assert_eq!(method.attrs.len(), 1);
let sig = method.sig;
let tvm_name = method.attrs[0].parse_meta()?;
@@ -48,8 +68,7 @@ impl Parse for External {
}
_ => panic!(),
};
- assert_eq!(method.default, None);
- assert!(method.semi_token != None);
+
let ident = sig.ident;
let generics = sig.generics;
let inputs = sig
@@ -61,6 +80,7 @@ impl Parse for External {
let ret_type = sig.output;
Ok(External {
+ visibility,
tvm_name,
ident,
generics,
@@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let mut items = Vec::new();
for external in &ext_input.externs {
+ let visibility = &external.visibility;
let name = &external.ident;
let global_name = format!("global_{}", external.ident);
let global_name = Ident::new(&global_name, Span::call_site());
@@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ty: Type = *pat_type.ty.clone();
(ident, ty)
}
- _ => panic!(),
+ _ => abort! { pat_type,
+ "Only supports type parameters."
+ }
},
- _ => panic!(),
+ pat => abort! {
+ pat, "invalid pattern type for function";
+
+ note = "{:?} is not allowed here", pat;
+ }
})
.unzip();
let ret_type = match &external.ret_type {
ReturnType::Type(_, rtype) => *rtype.clone(),
- _ => panic!(),
+ ReturnType::Default => syn::parse_str::<Type>("()").unwrap(),
};
let global = quote! {
@@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
items.push(global);
let wrapper = quote! {
- pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
+ #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> {
let func_ref: #tvm_rt_crate::Function = #global_name.clone();
let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into();
let res: #ret_type = func_ref(#(#args),*)?;
diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs
index ab75c92..32f2839 100644
--- a/rust/tvm-macros/src/lib.rs
+++ b/rust/tvm-macros/src/lib.rs
@@ -18,6 +18,7 @@
*/
use proc_macro::TokenStream;
+use proc_macro_error::proc_macro_error;
mod external;
mod import_module;
diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs
index 46e0342..e48c017 100644
--- a/rust/tvm-rt/src/object/mod.rs
+++ b/rust/tvm-rt/src/object/mod.rs
@@ -88,7 +88,7 @@ pub trait IsObjectRef:
external! {
#[name("ir.DebugPrint")]
- fn debug_print(object: ObjectRef) -> CString;
+ pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef;
#[name("node.StructuralEqual")]
diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs
index 8918bdc..3b60b0c 100644
--- a/rust/tvm/src/ir/module.rs
+++ b/rust/tvm/src/ir/module.rs
@@ -31,8 +31,10 @@ use crate::runtime::{external, Object, ObjectRef};
use super::expr::GlobalVar;
use super::function::BaseFunc;
use super::source_map::SourceMap;
+use super::{ty::GlobalTypeVar, relay};
// TODO(@jroesch): define type
+
type TypeData = ObjectRef;
type GlobalTypeVar = ObjectRef;
@@ -64,7 +66,7 @@ external! {
fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule;
// Module methods
#[name("ir.Module_Add")]
- fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
+ fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> ();
#[name("ir.Module_AddDef")]
fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> ();
#[name("ir.Module_GetGlobalVar")]
@@ -78,15 +80,15 @@ external! {
#[name("ir.Module_GetGlobalTypeVars")]
fn module_get_global_type_vars() -> Array<GlobalTypeVar>;
#[name("ir.Module_ContainGlobalVar")]
- fn module_get_global_var(name: TVMString) -> bool;
+ fn module_contains_global_var(name: TVMString) -> bool;
#[name("ir.Module_ContainGlobalTypeVar")]
- fn module_get_global_type_var(name: TVMString) -> bool;
+ fn module_contains_global_type_var(name: TVMString) -> bool;
#[name("ir.Module_LookupDef")]
fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef;
#[name("ir.Module_LookupDef_str")]
fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef;
#[name("ir.Module_LookupTag")]
- fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor;
+ fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor;
#[name("ir.Module_FromExpr")]
fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule;
#[name("ir.Module_Import")]
@@ -145,3 +147,9 @@ impl IRModule {
module_lookup_str(self.clone(), name.into())
}
}
+
+#[cfg(test)]
+mod tests {
+ // TODO: add `IRModule` tests; the stub below was left unfinished.
+ // #[test]
+ // fn
+}
[incubator-tvm] 17/23: Fix some CR
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 8e295b79c2e693d7190751532824b463dc9373e5
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Mon Oct 19 19:52:20 2020 -0700
Fix some CR
---
rust/tvm/src/ir/diagnostics/codespan.rs | 6 ++++--
rust/tvm/src/lib.rs | 4 ++--
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
index 9a31691..54fd336 100644
--- a/rust/tvm/src/ir/diagnostics/codespan.rs
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -17,8 +17,10 @@
* under the License.
*/
-/// A TVM diagnostics renderer which uses the Rust `codespan`
-/// library to produce error messages.
+/// A TVM diagnostics renderer which uses the Rust `codespan` library
+/// to produce error messages.
+///
+///
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index ec80ece..7e0682b 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -24,7 +24,7 @@
//! One particular use case is that given optimized deep learning model artifacts,
//! (compiled with TVM) which include a shared library
//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
-//! in Rust idomatically to create a TVM Graph Runtime and
+//! in Rust idiomatically to create a TVM Graph Runtime and
//! run the model for some inputs and get the
//! desired predictions *all in Rust*.
//!
@@ -53,7 +53,7 @@ macro_rules! export {
($($fn_name:expr),*) => {
pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
$(
- let name = String::from(ns) + ::std::stringify!($fn_name);
+ let name = String::fromwe(ns) + ::std::stringify!($fn_name);
tvm::runtime::function::register_override($fn_name, name, true)?;
)*
Ok(())
[incubator-tvm] 09/23: Improve Rust bindings
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 4cd1bbc79a71873676927c22528f5911e3f4072d
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 20:24:43 2020 -0700
Improve Rust bindings
---
rust/tvm/src/ir/diagnostics/codespan.rs | 131 +++++++++++++++++++++
.../src/ir/{diagnostics.rs => diagnostics/mod.rs} | 69 +----------
rust/tvm/src/ir/source_map.rs | 3 +-
rust/tvm/test.rly | 3 +-
src/ir/diagnostic.cc | 1 +
5 files changed, 138 insertions(+), 69 deletions(-)
diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs
new file mode 100644
index 0000000..80a8784
--- /dev/null
+++ b/rust/tvm/src/ir/diagnostics/codespan.rs
@@ -0,0 +1,131 @@
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
+use codespan_reporting::files::SimpleFiles;
+use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+
+use crate::ir::source_map::*;
+use super::*;
+
+/// Which end of a span a position refers to.
+///
+/// NOTE(review): not used anywhere in this file yet.
+enum StartOrEnd {
+ Start,
+ End,
+}
+
+/// Per-file strategy for converting TVM spans into byte ranges.
+enum FileSpanToByteRange {
+ /// The source text is pure ASCII.
+ AsciiSource,
+ /// The source text contains multi-byte UTF-8 characters.
+ Utf8 {
+ /// Map character regions which are larger than one byte to their length.
+ /// NOTE(review): whether keys are byte or character offsets is not yet
+ /// defined by any code here — confirm when this variant is implemented.
+ lengths: HashMap<isize, isize>,
+ /// The full source text.
+ source: String,
+ }
+}
+
+impl FileSpanToByteRange {
+    /// Choose the span-to-byte conversion strategy for one source file.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `source` contains non-ASCII text, because the `Utf8`
+    /// strategy is not implemented yet.
+    fn new(source: String) -> FileSpanToByteRange {
+        // In ASCII text every character is exactly one byte, so no per-file
+        // table is needed.
+        if source.is_ascii() {
+            FileSpanToByteRange::AsciiSource
+        } else {
+            // TODO: build `FileSpanToByteRange::Utf8` by recording the length
+            // of each multi-byte character (e.g. via `char_indices`).
+            unimplemented!("diagnostic rendering for non-ASCII sources is not supported yet")
+        }
+    }
+}
+
+/// Span-to-byte-range conversion tables, keyed by source file name.
+struct SpanToByteRange {
+ map: HashMap<String, FileSpanToByteRange>
+}
+
+impl SpanToByteRange {
+    /// Create an empty table with no registered sources.
+    fn new() -> SpanToByteRange {
+        SpanToByteRange { map: HashMap::new() }
+    }
+
+    /// Register `source` so its spans can later be converted to byte ranges.
+    ///
+    /// Registering the same file again is a no-op. The renderer calls this
+    /// once per diagnostic, and several diagnostics commonly share one file.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the source name or text cannot be read as a string, or if
+    /// `FileSpanToByteRange::new` rejects the text.
+    pub fn add_source(&mut self, source: Source) {
+        let source_name: String = source
+            .source_name
+            .name
+            .as_str()
+            .expect("diagnostic source name is not a valid string")
+            .into();
+
+        self.map.entry(source_name).or_insert_with(|| {
+            let text: String = source
+                .source
+                .as_str()
+                .expect("diagnostic source text is not a valid string")
+                .into();
+            FileSpanToByteRange::new(text)
+        });
+    }
+}
+
+/// A byte range within the file identified by `file_id`.
+///
+/// NOTE(review): not used anywhere in this file yet; whether `end_pos` is
+/// inclusive is not established by any code here.
+struct ByteRange<FileId> {
+ file_id: FileId,
+ start_pos: usize,
+ end_pos: usize,
+}
+
+
+/// Convert a TVM diagnostic into a codespan diagnostic that carries its
+/// severity and message.
+///
+/// # Panics
+///
+/// Panics if the diagnostic message cannot be read as a string.
+pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
+    let severity = match diag.level {
+        DiagnosticLevel::Error => Severity::Error,
+        DiagnosticLevel::Warning => Severity::Warning,
+        DiagnosticLevel::Note => Severity::Note,
+        DiagnosticLevel::Help => Severity::Help,
+        DiagnosticLevel::Bug => Severity::Bug,
+    };
+
+    let message: String = diag
+        .message
+        .as_str()
+        .expect("diagnostic message is not a valid string")
+        .into();
+
+    // TODO: attach a primary `Label` once TVM spans can be converted to byte
+    // ranges (see `SpanToByteRange`). A hard-coded file, range and label text
+    // would report a fabricated location and message for every diagnostic.
+    CDiagnostic::new(severity).with_message(message)
+}
+
+/// State for the codespan renderer. `init` wraps it in `Arc<Mutex<_>>`.
+struct DiagnosticState {
+ /// codespan's file database. NOTE(review): nothing adds files to it yet.
+ files: SimpleFiles<String, String>,
+ /// Per-file span-to-byte-range conversion tables.
+ span_map: SpanToByteRange,
+}
+
+impl DiagnosticState {
+    /// Start with an empty file database and an empty span table.
+    fn new() -> DiagnosticState {
+        let files = SimpleFiles::new();
+        let span_map = SpanToByteRange::new();
+        DiagnosticState { files, span_map }
+    }
+}
+
+/// Register the source file of every diagnostic in `diag_ctx` with `state`,
+/// then print each diagnostic's message to stdout.
+///
+/// # Panics
+///
+/// Panics if a diagnostic's span names a source that is missing from the
+/// module's source map.
+fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) {
+    let source_map = diag_ctx.module.source_map.clone();
+    for diagnostic in diag_ctx.diagnostics.clone() {
+        let source = source_map
+            .source_map
+            .get(&diagnostic.span.source_name)
+            .unwrap_or_else(|err| {
+                panic!("diagnostic refers to a source missing from the source map: {:?}", err)
+            });
+        state.span_map.add_source(source);
+        println!("Diagnostic: {}", diagnostic.message);
+    }
+}
+
+/// Install the codespan-based renderer as TVM's diagnostic renderer.
+///
+/// NOTE(review): the render closure is currently a no-op. Its body is
+/// commented out, so `diag_state` is never used and `renderer` is never
+/// called.
+pub fn init() -> Result<()> {
+ let diag_state = Arc::new(Mutex::new(DiagnosticState::new()));
+ let render_fn = move |diag_ctx: DiagnosticContext| {
+ // let mut guard = diag_state.lock().unwrap();
+ // renderer(&mut *guard, diag_ctx);
+ };
+
+ override_renderer(Some(render_fn))?;
+ Ok(())
+}
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics/mod.rs
similarity index 76%
rename from rust/tvm/src/ir/diagnostics.rs
rename to rust/tvm/src/ir/diagnostics/mod.rs
index 4975a45..fce214a 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics/mod.rs
@@ -18,7 +18,7 @@
*/
use super::module::IRModule;
-use super::span::Span;
+use super::span::*;
use crate::runtime::function::Result;
use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
use crate::runtime::{
@@ -32,7 +32,7 @@ use crate::runtime::{
/// and the DiagnosticRenderer.
use tvm_macros::{external, Object};
-type SourceName = ObjectRef;
+pub mod codespan;
// Get the the diagnostic renderer.
external! {
@@ -229,68 +229,3 @@ where
}
}
}
-
-pub mod codespan {
- use super::*;
-
- use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity};
- use codespan_reporting::files::SimpleFiles;
- use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
-
- enum StartOrEnd {
- Start,
- End,
- }
-
- // struct SpanToBytes {
- // inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
- // }
-
- struct ByteRange<FileId> {
- file_id: FileId,
- start_pos: usize,
- end_pos: usize,
- }
-
- // impl SpanToBytes {
- // fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange<FileId> {
-
- // }
- // }
-
- pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
- let severity = match diag.level {
- DiagnosticLevel::Error => Severity::Error,
- DiagnosticLevel::Warning => Severity::Warning,
- DiagnosticLevel::Note => Severity::Note,
- DiagnosticLevel::Help => Severity::Help,
- DiagnosticLevel::Bug => Severity::Bug,
- };
-
- let file_id = "foo".into(); // diag.span.source_name;
-
- let message: String = diag.message.as_str().unwrap().into();
- let inner_message: String = "expected `String`, found `Nat`".into();
- let diagnostic = CDiagnostic::new(severity)
- .with_message(message)
- .with_code("EXXX")
- .with_labels(vec![
- Label::primary(file_id, 328..331).with_message(inner_message)
- ]);
-
- diagnostic
- }
-
- pub fn init() -> Result<()> {
- let mut files: SimpleFiles<String, String> = SimpleFiles::new();
- let render_fn = move |diag_ctx: DiagnosticContext| {
- let source_map = diag_ctx.module.source_map.clone();
- for diagnostic in diag_ctx.diagnostics.clone() {
- println!("Diagnostic: {}", diagnostic.message);
- }
- };
-
- override_renderer(Some(render_fn))?;
- Ok(())
- }
-}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index 56c0830..ebe7e46 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -19,6 +19,7 @@
use crate::runtime::map::Map;
use crate::runtime::object::Object;
+use crate::runtime::string::{String as TString};
use super::span::{SourceName, Span};
@@ -37,7 +38,7 @@ pub struct SourceNode {
pub source_name: SourceName,
/// The raw source. */
- source: String,
+ pub source: TString,
// A mapping of line breaks into the raw source.
// std::vector<std::pair<int, int>> line_map;
diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly
index d8b7c69..e9407b0 100644
--- a/rust/tvm/test.rly
+++ b/rust/tvm/test.rly
@@ -1,2 +1,3 @@
#[version = "0.0.5"]
-fn @main(%x: int32) -> float32 { %x }
+
+def @main(%x: int32) -> float32 { %x }
diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc
index f9299e3..a4fee1e 100644
--- a/src/ir/diagnostic.cc
+++ b/src/ir/diagnostic.cc
@@ -113,6 +113,7 @@ TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender")
});
DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) {
+ CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function";
auto n = make_object<DiagnosticContextNode>();
n->module = module;
n->renderer = renderer;
[incubator-tvm] 16/23: More cleanup
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 187435099c4d80be2c200791e7853f8518f421e1
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Fri Oct 16 16:45:12 2020 -0700
More cleanup
---
rust/compiler-ext/src/lib.rs | 22 ++++++++--------------
rust/tvm/src/lib.rs | 7 ++++---
2 files changed, 12 insertions(+), 17 deletions(-)
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 346f40f..5f83f7b 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -18,25 +18,19 @@
*/
use env_logger;
-use tvm;
-use tvm::runtime::function::register_override;
+use tvm::export;
-fn test_fn() -> Result<(), tvm::Error> {
- println!("Hello Greg from Rust!");
- Ok(())
+fn diagnostics() -> Result<(), tvm::Error> {
+ tvm::ir::diagnostics::codespan::init()
}
-fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> {
- println!("The message: {}", message);
- Ok(())
-}
-
-tvm::export!(test_fn, test_fn2);
+export!(diagnostics);
#[no_mangle]
-fn compiler_ext_initialize() -> i32 {
+extern fn compiler_ext_initialize() -> i32 {
let _ = env_logger::try_init();
- tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext");
- log::debug!("done!");
+ tvm_export("rust_ext")
+ .expect("failed to initialize the Rust compiler extensions.");
+ log::debug!("Loaded the Rust compiler extension.");
return 0;
}
diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs
index d193f09..ec80ece 100644
--- a/rust/tvm/src/lib.rs
+++ b/rust/tvm/src/lib.rs
@@ -50,10 +50,11 @@ pub use runtime::version;
#[macro_export]
macro_rules! export {
- ($($fn_names:expr),*) => {
+ ($($fn_name:expr),*) => {
pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> {
$(
- register_override($fn_name, concat!($ns, stringfy!($fn_name)), true)?;
+ let name = String::from(ns) + ::std::stringify!($fn_name);
+ tvm::runtime::function::register_override($fn_name, name, true)?;
)*
Ok(())
}
@@ -65,7 +66,7 @@ macro_rules! export_mod {
($ns:expr, $($mod_name:expr),*) => {
pub fn tvm_mod_export() -> Result<(), tvm::Error> {
$(
- $mod_names::tvm_export($ns)?;
+ $mod_name::tvm_export($ns)?;
)*
Ok(())
}
[incubator-tvm] 06/23: Update CMake and delete old API
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit db245535a4ea5b8bed3cdddd26a94825d2fcc9b5
Author: Jared Roesch <ro...@gmail.com>
AuthorDate: Thu Oct 15 01:03:03 2020 -0700
Update CMake and delete old API
---
CMakeLists.txt | 1 +
cmake/modules/RustExt.cmake | 25 ++-
include/tvm/parser/source_map.h | 2 -
rust/compiler-ext/Cargo.toml | 5 +-
rust/compiler-ext/src/lib.rs | 334 ++--------------------------------------
rust/tvm-rt/Cargo.toml | 15 +-
rust/tvm-sys/Cargo.toml | 1 +
rust/tvm-sys/build.rs | 1 +
rust/tvm/Cargo.toml | 22 ++-
rust/tvm/src/bin/tyck.rs | 1 -
rust/tvm/src/ir/diagnostics.rs | 42 +++--
rust/tvm/src/ir/mod.rs | 2 +-
rust/tvm/src/ir/relay/mod.rs | 3 +-
rust/tvm/src/ir/source_map.rs | 61 ++++++++
rust/tvm/src/ir/span.rs | 95 +++++++++---
src/ir/expr.cc | 11 ++
src/parser/source_map.cc | 11 --
17 files changed, 237 insertions(+), 395 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9f82754..58bb2f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
+include(cmake/modules/RustExt.cmake)
include(CheckCXXCompilerFlag)
if(NOT MSVC)
diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake
index 45e46bd..2ad726e9 100644
--- a/cmake/modules/RustExt.cmake
+++ b/cmake/modules/RustExt.cmake
@@ -1,7 +1,14 @@
-if(USE_RUST_EXT)
- set(RUST_SRC_DIR "rust")
- set(CARGO_OUT_DIR "rust/target"
- set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib")
+if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF)
+ set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust")
+ set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target")
+
+ if(USE_RUST_EXT STREQUAL "STATIC")
+ set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a")
+ elseif(USE_RUST_EXT STREQUAL "DYNAMIC")
+ set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so")
+ else()
+ message(FATAL_ERROR "invalid setting for RUST_EXT")
+ endif()
add_custom_command(
OUTPUT "${COMPILER_EXT_PATH}"
@@ -9,5 +16,11 @@ if(USE_RUST_EXT)
MAIN_DEPENDENCY "${RUST_SRC_DIR}"
WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")
- target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
-endif(USE_RUST_EXT)
+ add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}")
+
+ # TODO(@jroesch, @tkonolige): move this to CMake target
+ # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
+ list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH})
+
+ add_definitions(-DRUST_COMPILER_EXT=1)
+endif()
diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h
index 424af5c..a160c22 100644
--- a/include/tvm/parser/source_map.h
+++ b/include/tvm/parser/source_map.h
@@ -103,8 +103,6 @@ class SourceMap : public ObjectRef {
TVM_DLL SourceMap() : SourceMap(Map<SourceName, Source>()) {}
- TVM_DLL static SourceMap Global();
-
void Add(const Source& source);
SourceMapNode* operator->() {
diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml
index 76d10eb..3b13bc5 100644
--- a/rust/compiler-ext/Cargo.toml
+++ b/rust/compiler-ext/Cargo.toml
@@ -6,8 +6,11 @@ edition = "2018"
# TODO(@jroesch): would be cool to figure out how to statically link instead.
[lib]
-crate-type = ["cdylib"]
+crate-type = ["staticlib", "cdylib"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+tvm = { path = "../tvm", default-features = false, features = ["static-linking"] }
+log = "*"
+env_logger = "*"
diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs
index 58bdd0c..3e37d21 100644
--- a/rust/compiler-ext/src/lib.rs
+++ b/rust/compiler-ext/src/lib.rs
@@ -17,321 +17,19 @@
* under the License.
*/
- use std::os::raw::c_int;
- use tvm::initialize;
- use tvm::ir::{tir, PrimExpr};
- use tvm::runtime::function::register_override;
- use tvm::runtime::map::Map;
- use tvm::runtime::object::{IsObject, IsObjectRef};
-
- use ordered_float::NotNan;
-
- mod interval;
- mod math;
-
- use math::{BoundsMap, Expr, RecExpr};
- use tvm::ir::arith::ConstIntBound;
- use tvm_rt::{ObjectRef, array::Array};
-
- macro_rules! downcast_match {
- ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => {
- $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+
- { $default }
- }
- }
-
- #[derive(Default)]
- struct VarMap {
- vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>,
- objs: Vec<ObjectRef>,
- }
-
- impl VarMap {
- // FIXME this should eventually do the right thing for TVM variables
- // right now it depends on them having unique names
- fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol {
- let sym = egg::Symbol::from(var.name_hint.as_str().unwrap());
- for (_, sym2) in &self.vars {
- if sym == *sym2 {
- return sym;
- }
- }
-
- self.vars.push((var, sym));
- sym
- }
-
- fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var {
- for (v, sym2) in &self.vars {
- if sym == *sym2 {
- return v.clone();
- }
- }
- panic!("Should have found a var")
- }
-
- fn push_obj(&mut self, obj: impl IsObjectRef) -> usize {
- let i = self.objs.len();
- self.objs.push(obj.upcast());
- i
- }
-
- fn get_obj<T: IsObjectRef>(&self, i: usize) -> T {
- self.objs[i].clone().downcast().expect("bad downcast")
- }
- }
-
- fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr {
- fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id {
- macro_rules! r {
- ($e:expr) => {
- build(vars, &$e, recexpr)
- };
- }
-
- let dt = recexpr.add(Expr::DataType(p.datatype));
- let e = downcast_match!(p; {
- tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]),
- tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]),
- tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]),
-
- tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]),
- tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]),
- tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]),
- tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]),
-
- tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]),
- tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]),
-
- tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]),
- tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]),
-
- tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]),
- tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]),
- tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]),
- tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]),
- tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]),
- tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]),
-
- tir::And => Expr::And([dt, r!(p.a), r!(p.b)]),
- tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]),
- tir::Not => Expr::Not([dt, r!(p.value)]),
-
- tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]),
-
- tir::Let => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
- Expr::Let([dt, sym, r!(p.value), r!(p.body)])
- }
- tir::Var => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p)));
- Expr::Var([dt, sym])
- }
- tir::IntImm => {
- let int = recexpr.add(Expr::Int(p.value));
- Expr::IntImm([dt, int])
- }
- tir::FloatImm => {
- let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap()));
- Expr::FloatImm([dt, float])
- }
- tir::Cast => Expr::Cast([dt, r!(p.value)]),
-
- tir::Call => {
- let op = vars.push_obj(p.op.clone());
- let mut arg_ids = vec![dt];
- for i in 0..p.args.len() {
- let arg: PrimExpr = p.args.get(i as isize).expect("array get fail");
- arg_ids.push(r!(arg));
- }
- Expr::Call(op, arg_ids)
- },
- tir::Load => {
- let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone())));
- Expr::Load([dt, sym, r!(p.index), r!(p.predicate)])
- },
- else => {
- println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap());
- Expr::Object(vars.push_obj(p.clone()))
- }
- });
-
- recexpr.add(e)
- }
-
- let mut recexpr = Default::default();
- build(vars, prim, &mut recexpr);
- recexpr
- }
-
- fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr {
- fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr {
- let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]);
- let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap();
- let prim: PrimExpr = match nodes.last().expect("cannot be empty") {
- Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- },
- Expr::IntImm([dt, v]) => {
- let value = nodes[usize::from(*v)].to_int().unwrap();
- tir::IntImm::new(get_dt(dt), value).upcast()
- }
- Expr::FloatImm([dt, v]) => {
- let value = nodes[usize::from(*v)].to_float().unwrap();
- tir::FloatImm::new(get_dt(dt), value).upcast()
- }
- Expr::Let([dt, s, value, body]) => {
- let var = match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- };
- tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast()
- }
- Expr::Load([dt, s, value, body]) => {
- let var = match &nodes[usize::from(*s)] {
- Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(),
- n => panic!("Expected a symbol, got {:?}", n),
- };
- tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast()
- }
-
- Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(),
-
- Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(),
- Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(),
-
- Expr::Ramp([dt, a, b, c]) => {
- let len = &nodes[usize::from(*c)];
- let i = len
- .to_int()
- .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len));
- tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast()
- }
- Expr::Broadcast([dt, val, lanes]) => {
- let lanes = &nodes[usize::from(*lanes)];
- let lanes = lanes
- .to_int()
- .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes));
- println!("dt: {}", get_dt(dt));
- tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast()
- }
-
- Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(),
- Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(),
- Expr::Call(expr, args) => {
- let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect();
- let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args");
- tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast()
- }
-
- Expr::Object(i) => vars.get_obj(*i),
- node => panic!("I don't know how to extract {:?}", node),
- };
- assert_ne!(prim.datatype.bits(), 0);
- assert_ne!(prim.datatype.lanes(), 0);
- prim
- }
- build(vars, recexpr.as_ref())
- }
-
- fn run(
- input: PrimExpr,
- expected: Option<PrimExpr>,
- map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, String> {
- use egg::{CostFunction, Extractor};
-
- let mut bounds = BoundsMap::default();
- for (k, v) in map {
- if let Ok(var) = k.downcast_clone::<tir::Var>() {
- let sym: egg::Symbol = var.name_hint.as_str().unwrap().into();
- bounds.insert(sym, (v.min_value, v.max_value));
- } else {
- println!("Non var in bounds map: {}", tvm::ir::as_text(k));
- }
- }
-
- let mut vars = VarMap::default();
- let expr = to_egg(&mut vars, &input);
- let mut runner = math::default_runner();
- runner.egraph.analysis.bounds = bounds;
-
- let mut runner = runner.with_expr(&expr).run(&math::rules());
- // runner.print_report();
- let mut extractor = Extractor::new(&runner.egraph, math::CostFn);
- let root = runner.egraph.find(runner.roots[0]);
- let (cost, best) = extractor.find_best(root);
- if let Some(expected) = expected {
- let mut expected_vars = VarMap::default();
- let expected_expr = to_egg(&mut expected_vars, &expected);
- let expected_root = runner.egraph.add_expr(&expected_expr);
- if expected_root != root {
- return Err(format!(
- "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n",
- expected_expr.pretty(40),
- best.pretty(40)
- ));
- }
- let expected_cost = math::CostFn.cost_rec(&expected_expr);
- if expected_cost != cost {
- let msg = format!(
- "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n",
- expected_cost,
- expected_expr.pretty(40),
- cost,
- best.pretty(40)
- );
- if cost < expected_cost {
- println!("egg wins: {}", msg)
- } else {
- return Err(msg);
- }
- }
- }
- log::info!(" returning... {}", best.pretty(60));
- Ok(from_egg(&vars, &best))
- }
-
- fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> {
- log::debug!("map: {:?}", map);
- run(prim, None, map).map_err(tvm::Error::CallFailed)
- }
-
- fn simplify_and_check(
- prim: PrimExpr,
- check: PrimExpr,
- map: Map<PrimExpr, ConstIntBound>,
- ) -> Result<PrimExpr, tvm::Error> {
- log::debug!("check map: {:?}", map);
- run(prim, Some(check), map).map_err(tvm::Error::CallFailed)
- }
-
- initialize!({
- let _ = env_logger::try_init();
- // NOTE this print prevents a segfault (on Linux) for now...
- println!("Initializing simplifier... ");
- register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier");
- register_override(simplify_and_check, "egg.simplify_and_check", true)
- .expect("failed to initialize simplifier");
- log::debug!("done!");
- });
-
\ No newline at end of file
+use env_logger;
+use tvm;
+use tvm::runtime::function::register_override;
+
+fn test_fn() -> Result<(), tvm::Error> {
+ println!("Hello from Rust!");
+ Ok(())
+}
+
+#[no_mangle]
+fn compiler_ext_initialize() -> i32 {
+ let _ = env_logger::try_init();
+ register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier");
+ log::debug!("done!");
+ return 0;
+}
diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml
index acece5a..9660943 100644
--- a/rust/tvm-rt/Cargo.toml
+++ b/rust/tvm-rt/Cargo.toml
@@ -28,19 +28,26 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
+[features]
+default = ["dynamic-linking"]
+dynamic-linking = ["tvm-sys/bindings"]
+static-linking = []
+blas = ["ndarray/blas"]
+
[dependencies]
thiserror = "^1.0"
ndarray = "0.12"
num-traits = "0.2"
-tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
tvm-macros = { version = "0.1", path = "../tvm-macros" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
memoffset = "0.5.6"
+[dependencies.tvm-sys]
+version = "0.1"
+default-features = false
+path = "../tvm-sys/"
+
[dev-dependencies]
anyhow = "^1.0"
-
-[features]
-blas = ["ndarray/blas"]
diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml
index 4e3fc98..c25a5bf 100644
--- a/rust/tvm-sys/Cargo.toml
+++ b/rust/tvm-sys/Cargo.toml
@@ -23,6 +23,7 @@ license = "Apache-2.0"
edition = "2018"
[features]
+default = ["bindings"]
bindings = []
[dependencies]
diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs
index 05806c0..2d86c4b 100644
--- a/rust/tvm-sys/build.rs
+++ b/rust/tvm-sys/build.rs
@@ -60,6 +60,7 @@ fn main() -> Result<()> {
if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm");
+ println!("cargo:rustc-link-lib=dylib=llvm-10");
println!("cargo:rustc-link-search={}/build", tvm_home);
}
diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml
index 71a4b93..153a195 100644
--- a/rust/tvm/Cargo.toml
+++ b/rust/tvm/Cargo.toml
@@ -28,14 +28,24 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"
+[features]
+default = ["python", "dynamic-linking"]
+dynamic-linking = ["tvm-rt/dynamic-linking"]
+static-linking = ["tvm-rt/static-linking"]
+blas = ["ndarray/blas"]
+python = ["pyo3"]
+
+[dependencies.tvm-rt]
+version = "0.1"
+default-features = false
+path = "../tvm-rt/"
+
[dependencies]
thiserror = "^1.0"
anyhow = "^1.0"
lazy_static = "1.1"
ndarray = "0.12"
num-traits = "0.2"
-tvm-rt = { version = "0.1", path = "../tvm-rt/" }
-tvm-sys = { version = "0.1", path = "../tvm-sys/" }
tvm-macros = { version = "*", path = "../tvm-macros/" }
paste = "0.1"
mashup = "0.1"
@@ -44,8 +54,6 @@ pyo3 = { version = "0.11.1", optional = true }
codespan-reporting = "0.9.5"
structopt = { version = "0.3" }
-[features]
-default = ["python"]
-
-blas = ["ndarray/blas"]
-python = ["pyo3"]
+[[bin]]
+name = "tyck"
+required-features = ["dynamic-linking"]
diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs
index 9300412..b869012 100644
--- a/rust/tvm/src/bin/tyck.rs
+++ b/rust/tvm/src/bin/tyck.rs
@@ -6,7 +6,6 @@ use structopt::StructOpt;
use tvm::ir::diagnostics::codespan;
use tvm::ir::IRModule;
-
#[derive(Debug, StructOpt)]
#[structopt(name = "tyck", about = "Parse and type check a Relay program.")]
struct Opt {
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index d306185..b76e43f 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -17,17 +17,20 @@
* under the License.
*/
+use super::module::IRModule;
+use super::span::Span;
+use crate::runtime::function::Result;
+use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
+use crate::runtime::{
+ array::Array,
+ function::{self, Function, ToFunction},
+ string::String as TString,
+};
/// The diagnostic interface to TVM, used for reporting and rendering
/// diagnostic information by the compiler. This module exposes
/// three key abstractions: a Diagnostic, the DiagnosticContext,
/// and the DiagnosticRenderer.
-
-use tvm_macros::{Object, external};
-use super::module::IRModule;
-use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString};
-use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
-use crate::runtime::function::Result;
-use super::span::Span;
+use tvm_macros::{external, Object};
type SourceName = ObjectRef;
@@ -134,7 +137,6 @@ pub struct DiagnosticRendererNode {
// memory layout
}
-
// def render(self, ctx):
// """
// Render the provided context.
@@ -169,7 +171,8 @@ pub struct DiagnosticContextNode {
/// and contains a renderer.
impl DiagnosticContext {
pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext
- where F: Fn(DiagnosticContext) -> () + 'static
+ where
+ F: Fn(DiagnosticContext) -> () + 'static,
{
let renderer = diagnostic_renderer(render_func.to_function()).unwrap();
let node = DiagnosticContextNode {
@@ -210,21 +213,16 @@ impl DiagnosticContext {
// If the render_func is None it will remove the current custom renderer
// and return to default behavior.
fn override_renderer<F>(opt_func: Option<F>) -> Result<()>
-where F: Fn(DiagnosticContext) -> () + 'static
+where
+ F: Fn(DiagnosticContext) -> () + 'static,
{
-
match opt_func {
None => clear_renderer(),
Some(func) => {
let func = func.to_function();
- let render_factory = move || {
- diagnostic_renderer(func.clone()).unwrap()
- };
+ let render_factory = move || diagnostic_renderer(func.clone()).unwrap();
- function::register_override(
- render_factory,
- "diagnostics.OverrideRenderer",
- true)?;
+ function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?;
Ok(())
}
@@ -243,9 +241,9 @@ pub mod codespan {
End,
}
- struct SpanToBytes {
- inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
- }
+ // struct SpanToBytes {
+ // inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
+ // }
struct ByteRange<FileId> {
file_id: FileId,
@@ -276,7 +274,7 @@ pub mod codespan {
.with_message(message)
.with_code("EXXX")
.with_labels(vec![
- Label::primary(file_id, 328..331).with_message(inner_message),
+ Label::primary(file_id, 328..331).with_message(inner_message)
]);
diagnostic
diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs
index 8450bd7..401b6c2 100644
--- a/rust/tvm/src/ir/mod.rs
+++ b/rust/tvm/src/ir/mod.rs
@@ -19,8 +19,8 @@
pub mod arith;
pub mod attrs;
-pub mod expr;
pub mod diagnostics;
+pub mod expr;
pub mod function;
pub mod module;
pub mod op;
diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs
index e539221..4b09128 100644
--- a/rust/tvm/src/ir/relay/mod.rs
+++ b/rust/tvm/src/ir/relay/mod.rs
@@ -28,6 +28,7 @@ use super::attrs::Attrs;
use super::expr::BaseExprNode;
use super::function::BaseFuncNode;
use super::ty::{Type, TypeNode};
+use super::span::Span;
use tvm_macros::Object;
use tvm_rt::NDArray;
@@ -51,7 +52,7 @@ impl ExprNode {
span: ObjectRef::null(),
checked_type: Type::from(TypeNode {
base: Object::base_object::<TypeNode>(),
- span: ObjectRef::null(),
+ span: Span::empty(),
}),
}
}
diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs
index e69de29..e6c0371 100644
--- a/rust/tvm/src/ir/source_map.rs
+++ b/rust/tvm/src/ir/source_map.rs
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use crate::runtime::map::Map;
+use crate::runtime::object::Object;
+
+/// A program source in any language.
+///
+/// Could represent the source from an ML framework or a source of an IRModule.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Source"]
+#[ref_key = "Source"]
+struct SourceNode {
+ pub base: Object,
+ /*! \brief The source name. */
+ SourceName source_name;
+
+ /*! \brief The raw source. */
+ String source;
+
+ /*! \brief A mapping of line breaks into the raw source. */
+ std::vector<std::pair<int, int>> line_map;
+}
+
+
+// class Source : public ObjectRef {
+// public:
+// TVM_DLL Source(SourceName src_name, std::string source);
+// TVM_DLL tvm::String GetLine(int line);
+
+// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode);
+// };
+
+
+/// A mapping from a unique source name to source fragments.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceMap"]
+#[ref_key = "SourceMap"]
+struct SourceMapNode {
+ pub base: Object,
+ /// The source mapping.
+ pub source_map: Map<SourceName, Source>,
+}
diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs
index d2e19a2..c54fd51 100644
--- a/rust/tvm/src/ir/span.rs
+++ b/rust/tvm/src/ir/span.rs
@@ -1,22 +1,75 @@
/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-use crate::runtime::ObjectRef;
-
-pub type Span = ObjectRef;
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+use crate::runtime::{ObjectRef, Object, String as TString};
+use tvm_macros::Object;
+
+/// A source file name, contained in a Span.
+
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "SourceName"]
+#[ref_name = "SourceName"]
+pub struct SourceNameNode {
+ pub base: Object,
+ pub name: TString,
+}
+
+// /*!
+// * \brief The source name of a file span.
+// * \sa SourceNameNode, Span
+// */
+// class SourceName : public ObjectRef {
+// public:
+// /*!
+// * \brief Get an SourceName for a given operator name.
+// * Will raise an error if the source name has not been registered.
+// * \param name Name of the operator.
+// * \return SourceName valid throughout program lifetime.
+// */
+// TVM_DLL static SourceName Get(const String& name);
+
+// TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode);
+// };
+
+/// Span information for diagnostic purposes.
+#[repr(C)]
+#[derive(Object)]
+#[type_key = "Span"]
+#[ref_name = "Span"]
+pub struct SpanNode {
+ pub base: Object,
+ /// The source name.
+ pub source_name: SourceName,
+ /// The line number.
+ pub line: i32,
+ /// The column offset.
+ pub column: i32,
+ /// The end line number.
+ pub end_line: i32,
+ /// The end column number.
+ pub end_column: i32,
+}
+
+impl Span {
+ pub fn empty() -> Span {
+ todo!()
+ }
+}
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 67e5cea..5110eef 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -192,4 +192,15 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) {
return ss.str();
});
+
+
} // namespace tvm
+
+#ifdef RUST_COMPILER_EXT
+
+extern "C" {
+ int compiler_ext_initialize();
+ static int test = compiler_ext_initialize();
+}
+
+#endif
diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc
index 7ac978c..7340f69 100644
--- a/src/parser/source_map.cc
+++ b/src/parser/source_map.cc
@@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) {
return line_text;
}
-// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
-// .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) {
-// auto* node = static_cast<const SourceNameNode*>(ref.get());
-// p->stream << "SourceName(" << node->name << ", " << node << ")";
-// });
-
TVM_REGISTER_NODE_TYPE(SourceMapNode);
SourceMap::SourceMap(Map<SourceName, Source> source_map) {
@@ -91,11 +85,6 @@ SourceMap::SourceMap(Map<SourceName, Source> source_map) {
data_ = std::move(n);
}
-// TODO(@jroesch): fix this
-static SourceMap global_source_map = SourceMap(Map<SourceName, Source>());
-
-SourceMap SourceMap::Global() { return global_source_map; }
-
void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); }
TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) {
[incubator-tvm] 03/23: WIP
Posted by jr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch cargo-build
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit cb3785680cb343eb295ea8e0585c87ca42db4323
Author: Jared Roesch <jr...@octoml.ai>
AuthorDate: Tue Oct 13 11:04:37 2020 -0700
WIP
---
rust/tvm/src/ir/diagnostics.rs | 78 ++++++++++++++++++++----------------------
1 file changed, 37 insertions(+), 41 deletions(-)
diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs
index e434d3f..d306185 100644
--- a/rust/tvm/src/ir/diagnostics.rs
+++ b/rust/tvm/src/ir/diagnostics.rs
@@ -24,7 +24,7 @@
use tvm_macros::{Object, external};
use super::module::IRModule;
-use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString};
+use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString};
use crate::runtime::object::{Object, ObjectPtr, ObjectRef};
use crate::runtime::function::Result;
use super::span::Span;
@@ -121,42 +121,19 @@ pub struct DiagnosticBuilder {
// * of compiler diagnostics to std::out and std::err in
// * a human readable form.
// */
-// class DiagnosticRendererNode : public Object {
-// public:
-// TypedPackedFunc<void(DiagnosticContext ctx)> renderer;
-
-// // override attr visitor
-// void VisitAttrs(AttrVisitor* v) {}
-
-// static constexpr const char* _type_key = "DiagnosticRenderer";
-// TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
-// };
-
-// class DiagnosticRenderer : public ObjectRef {
-// public:
-// TVM_DLL DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)> render);
-// TVM_DLL DiagnosticRenderer()
-// : DiagnosticRenderer(TypedPackedFunc<void(DiagnosticContext ctx)>()) {}
-
-// void Render(const DiagnosticContext& ctx);
-
-// DiagnosticRendererNode* operator->() {
-// CHECK(get() != nullptr);
-// return static_cast<DiagnosticRendererNode*>(get_mutable());
-// }
-
-// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode);
-// };
-
-// @tvm._ffi.register_object("DiagnosticRenderer")
-// class DiagnosticRenderer(Object):
-// """
-// A diagnostic renderer, which given a diagnostic context produces a "rendered"
-// form of the diagnostics for either human or computer consumption.
-// """
+#[repr(C)]
+#[derive(Object)]
+#[ref_name = "DiagnosticRenderer"]
+#[type_key = "DiagnosticRenderer"]
+/// A diagnostic renderer, which given a diagnostic context produces a "rendered"
+/// form of the diagnostics for either human or computer consumption.
+pub struct DiagnosticRendererNode {
+ /// The base type.
+ pub base: Object,
+ // TODO(@jroesch): we can't easily expose packed functions due to
+ // memory layout
+}
-// def __init__(self, render_func):
-// self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func)
// def render(self, ctx):
// """
@@ -168,7 +145,6 @@ pub struct DiagnosticBuilder {
// The diagnostic context to render.
// """
// return _ffi_api.DiagnosticRendererRender(self, ctx
-pub type DiagnosticRenderer = ObjectRef;
#[repr(C)]
#[derive(Object)]
@@ -227,8 +203,7 @@ impl DiagnosticContext {
}
}
-// Sets a custom renderer for diagnostics.
-
+// Override the global diagnostics renderer.
// Params
// ------
// render_func: Option[Callable[[DiagnosticContext], None]]
@@ -263,6 +238,27 @@ pub mod codespan {
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::termcolor::{ColorChoice, StandardStream};
+ enum StartOrEnd {
+ Start,
+ End,
+ }
+
+ struct SpanToBytes {
+ inner: HashMap<std::String, HashMap<usize, (StartOrEnd,
+ }
+
+ struct ByteRange<FileId> {
+ file_id: FileId,
+ start_pos: usize,
+ end_pos: usize,
+ }
+
+ // impl SpanToBytes {
+ // fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange<FileId> {
+
+ // }
+ // }
+
pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic<String> {
let severity = match diag.level {
DiagnosticLevel::Error => Severity::Error,
@@ -290,9 +286,9 @@ pub mod codespan {
let mut files: SimpleFiles<String, String> = SimpleFiles::new();
let render_fn = move |diag_ctx: DiagnosticContext| {
// let source_map = diag_ctx.module.source_map;
- for diagnostic in diag_ctx.diagnostics {
+ // for diagnostic in diag_ctx.diagnostics {
- }
+ // }
panic!("render_fn");
};