You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/08/27 15:21:23 UTC
[tvm] branch unity updated: [Unity] Add instruments to relay translator (#15601)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d2972f3b42 [Unity] Add instruments to relay translator (#15601)
d2972f3b42 is described below
commit d2972f3b42e8399496c957f6fcc7209daa7ca247
Author: Anirudh Sundar Subramaniam <qu...@quicinc.com>
AuthorDate: Sun Aug 27 20:51:14 2023 +0530
[Unity] Add instruments to relay translator (#15601)
* [Unity] Add instruments to relay translator
Sometimes it's useful to instrument the Relay passes that run during
Relay-to-Relax translation. This patch adds a new `instruments` argument
to the relay translator; the given list of pass instruments is passed on
to the PassContext used while running the Relay prefix passes.
* Fix test case
---
python/tvm/relax/testing/relay_translator.py | 9 ++++++++-
tests/python/relax/test_relay_translator.py | 26 ++++++++++++++++++++++++++
2 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py
index 5316c33ca6..7b09c9ad57 100644
--- a/python/tvm/relax/testing/relay_translator.py
+++ b/python/tvm/relax/testing/relay_translator.py
@@ -18,11 +18,12 @@
# pylint: disable=too-many-nested-blocks, unused-variable
"""Relay to Relax translator."""
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence
import tvm
from tvm import relax, relay
from tvm.ir.module import IRModule
+from tvm.ir.instrument import PassInstrument
from tvm.relax.testing import nn
from tvm.relay.backend.te_compiler import select_implementation
from tvm.runtime import NDArray
@@ -37,6 +38,7 @@ def from_relay(
*,
opt_level: int = 3,
pass_config: Optional[Dict[str, Any]] = None,
+ instruments: Optional[Sequence[PassInstrument]] = None,
disabled_pass: Optional[List[str]] = None,
translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
append_op_attrs: bool = False,
@@ -60,6 +62,10 @@ def from_relay(
pass_config: Optional[Dict[str, Any]]
Pass configuration.
+ instruments : Optional[Sequence[PassInstrument]]
+ The list of pass instrument implementations to be passed onto relay
+ while calling relay passes
+
disabled_pass: Optional[List[str]]
Passes to disable.
@@ -255,6 +261,7 @@ def from_relay(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
+ instruments=instruments,
):
mod = tvm.IRModule.from_expr(func)
mod = seq(mod)
diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py
index 54cd1b243d..6790ae851b 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -312,5 +312,31 @@ def test_append_op_attrs():
assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs
+def test_instruments_support():
+ x = relay.var("x", shape=(10, 16))
+ y = relay.var("y", shape=(10, 16))
+ out = relay.add(x, y)
+ mod = tvm.IRModule.from_expr(out)
+
+ @tvm.instrument.pass_instrument
+ class SampleRunBeforeAfterInstrument:
+ def __init__(self):
+ self.events = []
+
+ def run_before_pass(self, mod, info):
+ self.events.append("run before " + info.name)
+
+ def run_after_pass(self, mod, info):
+ self.events.append("run after " + info.name)
+
+ my_test = SampleRunBeforeAfterInstrument()
+ relax_mod_with_attrs = relay_translator.from_relay(
+ mod["main"], target="llvm", instruments=[my_test]
+ )
+
+ assert "run after " in "".join(my_test.events)
+ assert "run before " in "".join(my_test.events)
+
+
if __name__ == "__main__":
pytest.main([__file__])