You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2023/08/27 15:21:23 UTC

[tvm] branch unity updated: [Unity] Add instruments to relay translator (#15601)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d2972f3b42 [Unity] Add instruments to relay translator (#15601)
d2972f3b42 is described below

commit d2972f3b42e8399496c957f6fcc7209daa7ca247
Author: Anirudh Sundar Subramaniam <qu...@quicinc.com>
AuthorDate: Sun Aug 27 20:51:14 2023 +0530

    [Unity] Add instruments to relay translator (#15601)
    
    * [Unity] Add instruments to relay translator
    
    Sometimes it's useful to instrument the relay passes that are run during
    relay-to-relax translation. This patch adds a new `instruments` argument
    to the relay translator; the list is passed on to the PassContext used
    while running the relay prefix passes.
    
    * Fix test case
---
 python/tvm/relax/testing/relay_translator.py |  9 ++++++++-
 tests/python/relax/test_relay_translator.py  | 26 ++++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py
index 5316c33ca6..7b09c9ad57 100644
--- a/python/tvm/relax/testing/relay_translator.py
+++ b/python/tvm/relax/testing/relay_translator.py
@@ -18,11 +18,12 @@
 # pylint: disable=too-many-nested-blocks, unused-variable
 """Relay to Relax translator."""
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Sequence
 
 import tvm
 from tvm import relax, relay
 from tvm.ir.module import IRModule
+from tvm.ir.instrument import PassInstrument
 from tvm.relax.testing import nn
 from tvm.relay.backend.te_compiler import select_implementation
 from tvm.runtime import NDArray
@@ -37,6 +38,7 @@ def from_relay(
     *,
     opt_level: int = 3,
     pass_config: Optional[Dict[str, Any]] = None,
+    instruments: Optional[Sequence[PassInstrument]] = None,
     disabled_pass: Optional[List[str]] = None,
     translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
     append_op_attrs: bool = False,
@@ -60,6 +62,10 @@ def from_relay(
     pass_config: Optional[Dict[str, Any]]
         Pass configuration.
 
+    instruments : Optional[Sequence[PassInstrument]]
+        The list of pass instrument implementations to be passed onto relay
+        while calling relay passes
+
     disabled_pass: Optional[List[str]]
         Passes to disable.
 
@@ -255,6 +261,7 @@ def from_relay(
         opt_level=opt_level,
         config=pass_config,
         disabled_pass=disabled_pass,
+        instruments=instruments,
     ):
         mod = tvm.IRModule.from_expr(func)
         mod = seq(mod)
diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py
index 54cd1b243d..6790ae851b 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -312,5 +312,31 @@ def test_append_op_attrs():
     assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs
 
 
+def test_instruments_support():
+    x = relay.var("x", shape=(10, 16))
+    y = relay.var("y", shape=(10, 16))
+    out = relay.add(x, y)
+    mod = tvm.IRModule.from_expr(out)
+
+    @tvm.instrument.pass_instrument
+    class SampleRunBeforeAfterInstrument:
+        def __init__(self):
+            self.events = []
+
+        def run_before_pass(self, mod, info):
+            self.events.append("run before " + info.name)
+
+        def run_after_pass(self, mod, info):
+            self.events.append("run after " + info.name)
+
+    my_test = SampleRunBeforeAfterInstrument()
+    relax_mod_with_attrs = relay_translator.from_relay(
+        mod["main"], target="llvm", instruments=[my_test]
+    )
+
+    assert "run after " in "".join(my_test.events)
+    assert "run before " in "".join(my_test.events)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])