You are viewing a plain-text version of this content. It has a canonical link, but that hyperlink ("here" in the original page) is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/11/06 13:20:18 UTC

[tvm] branch main updated: [TensorIR] Print TVMScript with prefix T instead of tir (#9422)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 374e15b  [TensorIR] Print TVMScript with prefix T instead of tir (#9422)
374e15b is described below

commit 374e15b49a1cd79e5387a24e74a52cda24e6f484
Author: Anirudh Sundar <qu...@quicinc.com>
AuthorDate: Sat Nov 6 18:49:37 2021 +0530

    [TensorIR] Print TVMScript with prefix T instead of tir (#9422)
---
 python/tvm/ir/module.py                             |  2 +-
 python/tvm/tir/function.py                          |  2 +-
 src/printer/tvmscript_printer.cc                    | 21 +++++++++++++++++++--
 src/tir/schedule/error.cc                           |  2 +-
 .../python/unittest/test_tvmscript_error_report.py  | 16 ++++++++--------
 5 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 11ef823..1a705b9 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -256,7 +256,7 @@ class IRModule(Node):
     def __repr__(self):
         return self.astext()
 
-    def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
+    def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
         """Print IRModule into TVMScript
 
         Parameters
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index b002ace..ecbcd83 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -143,7 +143,7 @@ class PrimFunc(BaseFunc):
         """
         return _ffi_api.Specialize(self, param_map)  # type: ignore
 
-    def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
+    def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str:
         """Print IRModule into TVMScript
 
         Parameters
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index f43c827..a47712e 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -310,6 +310,17 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
     }
     return doc;
   }
+
+ public:
+  static Doc PrintHeader(const std::string& tir_prefix) {
+    Doc header;
+    if (tir_prefix != "tir") {
+      header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine();
+    } else {
+      header << "# from tvm.script import tir" << Doc::NewLine();
+    }
+    return header;
+  }
 };
 
 Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
@@ -1431,7 +1442,10 @@ Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
 
 String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
   ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
-  return TVMScriptPrinter(tir_prefix, show_meta).Print(mod).str() + "\n";
+  Doc doc;
+  doc << TVMScriptPrinter::PrintHeader(tir_prefix)
+      << TVMScriptPrinter(tir_prefix, show_meta).Print(mod);
+  return doc.str() + "\n";
 }
 
 TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
@@ -1439,7 +1453,10 @@ TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
 String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
                                  runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
   ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
-  return TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod).str() + "\n";
+  Doc doc;
+  doc << TVMScriptPrinter::PrintHeader(tir_prefix)
+      << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod);
+  return doc.str() + "\n";
 }
 
 TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc
index eb72773..4ce5a97 100644
--- a/src/tir/schedule/error.cc
+++ b/src/tir/schedule/error.cc
@@ -52,7 +52,7 @@ String ScheduleError::RenderReport(const String& primitive) const {
 
   os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
      << "'.\n\nThe IR with diagnostic is:\n"
-     << AsTVMScriptWithDiagnostic(mod, "tir", false, annotate);
+     << AsTVMScriptWithDiagnostic(mod, "T", false, annotate);
 
   // print error message
   os << "Error message: " << msg;
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 3098c86..4c7ffd6 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -548,8 +548,8 @@ def test_reorder_fail_block():
         sch.reorder(l, i)
     expected_sub_error_message = (
         "            # tir.Block#0\n"
-        '            with tir.block("B"):\n'
-        "            ^^^^^^^^^^^^^^^^^^^^\n"
+        '            with T.block("B"):\n'
+        "            ^^^^^^^^^^^^^^^^^^\n"
     )
     assert expected_sub_error_message in str(execinfo.value)
 
@@ -561,10 +561,10 @@ def test_reorder_fail_nested_loop_inner():
     with pytest.raises(tvm.tir.ScheduleError) as execinfo:
         sch.reorder(k, i)
     expected_sub_error_message = (
-        "        for i in tir.serial(0, 128):\n"
+        "        for i in T.serial(0, 128):\n"
         "            # tir.For#0\n"
-        "            for j in tir.serial(0, 128):\n"
-        "            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
+        "            for j in T.serial(0, 128):\n"
+        "            ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
     )
     assert expected_sub_error_message in str(execinfo.value)
 
@@ -577,9 +577,9 @@ def test_fuse_fail_nested_loop_outer():
         sch.fuse(k, i)
     expected_sub_error_message = (
         "        # tir.For#1\n"
-        "        for i in tir.serial(0, 128):\n"
-        "        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
-        "            for j in tir.serial(0, 128):\n"
+        "        for i in T.serial(0, 128):\n"
+        "        ^^^^^^^^^^^^^^^^^^^^^^^^^^\n"
+        "            for j in T.serial(0, 128):\n"
     )
     assert expected_sub_error_message in str(execinfo.value)