You are viewing a plain-text version of this content. The canonical version is in the commits@tvm.apache.org mailing-list archive (the original hyperlink to it was lost in this copy).
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/11/06 13:20:18 UTC
[tvm] branch main updated: [TensorIR] Print TVMScript with prefix T instead of tir (#9422)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 374e15b [TensorIR] Print TVMScript with prefix T instead of tir (#9422)
374e15b is described below
commit 374e15b49a1cd79e5387a24e74a52cda24e6f484
Author: Anirudh Sundar <qu...@quicinc.com>
AuthorDate: Sat Nov 6 18:49:37 2021 +0530
[TensorIR] Print TVMScript with prefix T instead of tir (#9422)
---
python/tvm/ir/module.py | 2 +-
python/tvm/tir/function.py | 2 +-
src/printer/tvmscript_printer.cc | 21 +++++++++++++++++++--
src/tir/schedule/error.cc | 2 +-
.../python/unittest/test_tvmscript_error_report.py | 16 ++++++++--------
5 files changed, 30 insertions(+), 13 deletions(-)
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 11ef823..1a705b9 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -256,7 +256,7 @@ class IRModule(Node):
def __repr__(self):
return self.astext()
- def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
+ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Parameters
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index b002ace..ecbcd83 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -143,7 +143,7 @@ class PrimFunc(BaseFunc):
"""
return _ffi_api.Specialize(self, param_map) # type: ignore
- def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
+ def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Parameters
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index f43c827..a47712e 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -310,6 +310,17 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
}
return doc;
}
+
+ public:
+ static Doc PrintHeader(const std::string& tir_prefix) {
+ Doc header;
+ if (tir_prefix != "tir") {
+ header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine();
+ } else {
+ header << "# from tvm.script import tir" << Doc::NewLine();
+ }
+ return header;
+ }
};
Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
@@ -1431,7 +1442,10 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
- return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n";
+ Doc doc;
+ doc << TVMScriptPrinter::PrintHeader(tir_prefix)
+ << TVMScriptPrinter(tir_prefix, show_meta).Print(mod);
+ return doc.str() + "\n";
}
TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
@@ -1439,7 +1453,10 @@ TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
- return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n";
+ Doc doc;
+ doc << TVMScriptPrinter::PrintHeader(tir_prefix)
+ << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod);
+ return doc.str() + "\n";
}
TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc
index eb72773..4ce5a97 100644
--- a/src/tir/schedule/error.cc
+++ b/src/tir/schedule/error.cc
@@ -52,7 +52,7 @@ String ScheduleError::RenderReport(const String& primitive) const {
os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
<< "'.\n\nThe IR with diagnostic is:\n"
- << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate);
+ << AsTVMScriptWithDiagnostic(mod, "T", false, annotate);
// print error message
os << "Error message: " << msg;
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 3098c86..4c7ffd6 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -548,8 +548,8 @@ def test_reorder_fail_block():
sch.reorder(l, i)
expected_sub_error_message = (
" # tir.Block#0\n"
- ' with tir.block("B"):\n'
- " ^^^^^^^^^^^^^^^^^^^^\n"
+ ' with T.block("B"):\n'
+ " ^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)
@@ -561,10 +561,10 @@ def test_reorder_fail_nested_loop_inner():
with pytest.raises(tvm.tir.ScheduleError) as execinfo:
sch.reorder(k, i)
expected_sub_error_message = (
- " for i in tir.serial(0, 128):\n"
+ " for i in T.serial(0, 128):\n"
" # tir.For#0\n"
- " for j in tir.serial(0, 128):\n"
- " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
+ " for j in T.serial(0, 128):\n"
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
)
assert expected_sub_error_message in str(execinfo.value)
@@ -577,9 +577,9 @@ def test_fuse_fail_nested_loop_outer():
sch.fuse(k, i)
expected_sub_error_message = (
" # tir.For#1\n"
- " for i in tir.serial(0, 128):\n"
- " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
- " for j in tir.serial(0, 128):\n"
+ " for i in T.serial(0, 128):\n"
+ " ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
+ " for j in T.serial(0, 128):\n"
)
assert expected_sub_error_message in str(execinfo.value)