You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by bo...@apache.org on 2023/01/18 13:42:36 UTC

[tvm] branch main updated: [TVMScript] Use TVMScript for all TIR Printing (#13795)

This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new da99e9d1b5 [TVMScript] Use TVMScript for all TIR Printing (#13795)
da99e9d1b5 is described below

commit da99e9d1b5208e9a23e0b8e5b45da6e633f05415
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jan 18 05:42:24 2023 -0800

    [TVMScript] Use TVMScript for all TIR Printing (#13795)
---
 CMakeLists.txt                                     |   2 +-
 include/tvm/ir/module.h                            |  28 -----
 include/tvm/ir/transform.h                         |   1 -
 include/tvm/relay/base.h                           |  28 +++++
 include/tvm/{ir => relay}/error.h                  |  13 +--
 include/tvm/relay/expr.h                           |   1 -
 include/tvm/relay/expr_functor.h                   |   2 +-
 include/tvm/relay/pattern_functor.h                |   2 +-
 python/tvm/ir/__init__.py                          |   1 -
 python/tvm/ir/affine_type.py                       |   2 +-
 python/tvm/ir/base.py                              |  31 -----
 python/tvm/ir/expr.py                              |  32 +++++-
 python/tvm/ir/module.py                            |  28 +++++
 python/tvm/ir/op.py                                |  31 ++++-
 python/tvm/ir/tensor_type.py                       |   2 +-
 python/tvm/micro/model_library_format.py           |   7 +-
 python/tvm/relay/__init__.py                       |   1 +
 python/tvm/relay/base.py                           |  39 ++++++-
 python/tvm/relay/dataflow_pattern/__init__.py      |  29 ++++-
 python/tvm/relay/expr.py                           |  34 ++++--
 python/tvm/relay/function.py                       |  29 ++++-
 python/tvm/script/__init__.py                      |   1 -
 python/tvm/script/printer/__init__.py              |   1 -
 python/tvm/script/printer/printer.py               |  54 ---------
 rust/tvm/src/ir/expr.rs                            |   2 +-
 src/ir/transform.cc                                |   6 +-
 src/relay/analysis/annotated_region_set.cc         |   2 +-
 src/relay/analysis/annotated_region_set.h          |   2 +-
 src/relay/analysis/kind_check.cc                   |   2 +-
 src/relay/analysis/match_exhaustion.cc             |   2 +-
 src/relay/analysis/type_solver.h                   |   2 +-
 src/relay/backend/contrib/ethosu/codegen.cc        |   2 +-
 src/relay/backend/contrib/ethosu/compiler_attrs.cc |   2 +-
 src/relay/backend/contrib/ethosu/preprocess.cc     |   2 +-
 src/relay/backend/contrib/uma/relay_to_tir.cc      |   2 +-
 src/relay/backend/vm/compiler.cc                   |   2 +-
 src/relay/backend/vm/compiler.h                    |   2 +-
 src/relay/collage/partition_rule.h                 |   2 +-
 src/relay/ir/base.cc                               |   5 +
 src/{ => relay}/ir/error.cc                        |  11 +-
 src/relay/op/tensor/transform.cc                   |   2 +-
 src/relay/op/tensor/transform.h                    |   2 +-
 src/relay/op/type_relations.h                      |   2 +-
 src/{ => relay}/printer/doc.cc                     |   4 +-
 src/{ => relay}/printer/doc.h                      |   9 +-
 src/{ => relay}/printer/meta_data.h                |  13 +--
 .../printer/model_library_format_printer.cc        |   6 +-
 src/{ => relay}/printer/relay_text_printer.cc      |  13 +--
 src/{ => relay}/printer/text_printer.cc            |   9 +-
 src/{ => relay}/printer/text_printer.h             |  47 +++-----
 src/{ => relay}/printer/tir_text_printer.cc        |  28 ++---
 src/{ => relay}/printer/tir_text_printer_debug.cc  |   4 +-
 src/{ => relay}/printer/tir_text_printer_debug.h   |  10 +-
 src/{ => relay}/printer/tvmscript_printer.cc       |  85 +++++++-------
 src/relay/transforms/merge_compiler_regions.cc     |   2 +-
 src/relay/transforms/partition_graph.cc            |   2 +-
 src/script/printer/printer.cc                      |   7 --
 src/tir/schedule/error.cc                          |   6 +-
 src/tir/transforms/install_debug_spans.cc          |   4 +-
 tests/python/relay/test_ir_parser.py               |  10 +-
 .../test_meta_schedule_schedule_rule_mlt.py        |   3 +-
 tests/python/unittest/test_tir_nodes.py            | 126 ---------------------
 .../test_tir_transform_lower_warp_memory.py        |   9 +-
 .../test_tvmscript_printer_syntax_sugar.py         |  69 -----------
 .../python/unittest/test_tvmscript_printer_tir.py  |  42 +++++++
 65 files changed, 447 insertions(+), 514 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a8d8b733e..36f7d37962 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -288,7 +288,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
     src/topi/*.cc
     src/driver/*.cc
     src/parser/*.cc
-    src/printer/*.cc
     src/support/*.cc
     src/script/*.cc
     )
@@ -317,6 +316,7 @@ tvm_file_glob(GLOB RELAY_BACKEND_SRCS
     )
 tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS
     src/relay/ir/*.cc
+    src/relay/printer/*.cc
     )
 tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS
     src/relay/qnn/*.cc
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f26e640f6c..4cd357d418 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -446,34 +446,6 @@ class IRModule : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
 };
 
-/*!
- * \brief Pretty print a node for debug purposes.
- *
- * \param node The node to be printed.
- * \return The text reperesentation.
- * \note This function does not show version or meta-data.
- *       Use AsText if you want to store the text.
- * \sa AsText.
- */
-TVM_DLL String PrettyPrint(const ObjectRef& node);
-
-/*!
- * \brief Render the node as a string in the text format.
- *
- * \param node The node to be rendered.
- * \param show_meta_data Whether to print meta data section.
- * \param annotate An optional callback function for attaching
- *        additional comment block to an expr.
- *
- * \note We support a limited set of IR nodes that are part of
- *       relay IR and
- *
- * \sa PrettyPrint.
- * \return The text representation.
- */
-TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
-                      runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
-
 namespace attr {
 
 // Following are attributes for IRModule only.
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index febcca5c01..473e629168 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -57,7 +57,6 @@
 #define TVM_IR_TRANSFORM_H_
 
 #include <tvm/ir/diagnostic.h>
-#include <tvm/ir/error.h>
 #include <tvm/ir/instrument.h>
 #include <tvm/ir/module.h>
 #include <tvm/runtime/container/array.h>
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
index e94bd2756e..2825bcfc65 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/relay/base.h
@@ -120,6 +120,34 @@ class Id : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
 };
 
+/*!
+ * \brief Pretty print a node for debug purposes.
+ *
+ * \param node The node to be printed.
+ * \return The text representation.
+ * \note This function does not show version or meta-data.
+ *       Use AsText if you want to store the text.
+ * \sa AsText.
+ */
+TVM_DLL String PrettyPrint(const ObjectRef& node);
+
+/*!
+ * \brief Render the node as a string in the text format.
+ *
+ * \param node The node to be rendered.
+ * \param show_meta_data Whether to print meta data section.
+ * \param annotate An optional callback function for attaching
+ *        additional comment block to an expr.
+ *
+ * \note We support a limited set of IR nodes that are part of
+ *       Relay IR.
+ *
+ * \sa PrettyPrint.
+ * \return The text representation.
+ */
+TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
+                      runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
+
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/include/tvm/ir/error.h b/include/tvm/relay/error.h
similarity index 97%
rename from include/tvm/ir/error.h
rename to include/tvm/relay/error.h
index 6ff61781ac..be34e2b8ae 100644
--- a/include/tvm/ir/error.h
+++ b/include/tvm/relay/error.h
@@ -16,13 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-/*!
- * \file tvm/ir/error.h
- * \brief Utilities for error tracking and reporting.
- */
-#ifndef TVM_IR_ERROR_H_
-#define TVM_IR_ERROR_H_
+#ifndef TVM_RELAY_ERROR_H_
+#define TVM_RELAY_ERROR_H_
 
 #include <tvm/ir/module.h>
 #include <tvm/ir/span.h>
@@ -33,6 +28,7 @@
 #include <vector>
 
 namespace tvm {
+namespace relay {
 /*!
  * \brief A wrapper around std::stringstream to build error.
  *
@@ -181,5 +177,6 @@ class ErrorReporter {
   std::unordered_map<ObjectRef, GlobalVar, ObjectPtrHash, ObjectPtrEqual> node_to_gv_;
 };
 
+}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_IR_ERROR_H_
+#endif  // TVM_RELAY_ERROR_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 6847a53caa..854050464d 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -57,7 +57,6 @@ using BaseFunc = tvm::BaseFunc;
 using BaseFuncNode = tvm::BaseFuncNode;
 using GlobalVar = tvm::GlobalVar;
 using GlobalVarNode = tvm::GlobalVarNode;
-using tvm::PrettyPrint;
 
 /*!
  * \brief Constant tensor, backed by an NDArray on the cpu(0) device.
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 280a1f8a6c..2a295c9da7 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -25,9 +25,9 @@
 #ifndef TVM_RELAY_EXPR_FUNCTOR_H_
 #define TVM_RELAY_EXPR_FUNCTOR_H_
 
-#include <tvm/ir/error.h>
 #include <tvm/node/functor.h>
 #include <tvm/relay/adt.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/function.h>
 #include <tvm/relay/op.h>
diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h
index 711d8323f1..9d2b6689b2 100644
--- a/include/tvm/relay/pattern_functor.h
+++ b/include/tvm/relay/pattern_functor.h
@@ -25,8 +25,8 @@
 #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
 #define TVM_RELAY_PATTERN_FUNCTOR_H_
 
-#include <tvm/ir/error.h>
 #include <tvm/node/functor.h>
+#include <tvm/relay/error.h>
 
 #include <string>
 #include <unordered_map>
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 9e81dd5519..4f63cbecd9 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -27,7 +27,6 @@ from .base import (
     Span,
     assert_structural_equal,
     load_json,
-    pretty_print,
     save_json,
     structural_equal,
     structural_hash,
diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py
index 8d185ae59a..24126f94b9 100644
--- a/python/tvm/ir/affine_type.py
+++ b/python/tvm/ir/affine_type.py
@@ -32,7 +32,7 @@ class AffineType(Node):
         return not self.__eq__(other)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
+        from tvm.relay import pretty_print  # pylint: disable=import-outside-toplevel
 
         return pretty_print(self)
 
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index a1e1d20d88..b84a83d558 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -23,40 +23,9 @@ from tvm.runtime import Object
 from . import _ffi_api, json_compact
 
 
-def pretty_print(obj: Object) -> None:
-    """Pretty print the object."""
-    return _ffi_api.PrettyPrint(obj)  # type: ignore # pylint: disable=no-member
-
-
 class Node(Object):
     """Base class of all IR Nodes, implements astext function."""
 
-    def astext(self, show_meta_data=True, annotate=None):
-        """Get the text format of the expression.
-
-        Parameters
-        ----------
-        show_meta_data : bool
-            Whether to include meta data section in the text
-            if there is meta data.
-
-        annotate: Optional[Object->str]
-            Optionally annotate function to provide additional
-            information in the comment block.
-
-        Returns
-        -------
-        text : str
-            The text format of the expression.
-
-        Notes
-        -----
-        The meta data section is necessary to fully parse the text format.
-        However, it can contain dumps that are big (e.g constant weights),
-        so it can be helpful to skip printing the meta data section.
-        """
-        return _ffi_api.AsText(self, show_meta_data, annotate)
-
 
 @tvm._ffi.register_object("SourceName")
 class SourceName(Object):
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index e16cd5ea9e..52af8407b7 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -17,9 +17,9 @@
 """Common expressions data structures in the IR."""
 import tvm._ffi
 
-from .base import Node
-from . import _ffi_api
 from ..runtime import const, convert
+from . import _ffi_api
+from .base import Node
 
 
 class BaseExpr(Node):
@@ -91,6 +91,34 @@ class GlobalVar(RelayExpr):
             "Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
         )
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        from tvm.relay import astext  # pylint: disable=import-outside-toplevel
+
+        return astext(self, show_meta_data, annotate)
+
 
 @tvm._ffi.register_object
 class Range(Node):
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index b184c3b0c3..51410049ec 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -287,6 +287,34 @@ class IRModule(Node):
 
         return _ffi_api.Module_WithAttr(self, attr_key, attr_value)
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        from tvm.relay import astext  # pylint: disable=import-outside-toplevel
+
+        return astext(self, show_meta_data, annotate)
+
     def script(
         self,
         *,
diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py
index 49ac72b887..70aba97951 100644
--- a/python/tvm/ir/op.py
+++ b/python/tvm/ir/op.py
@@ -17,8 +17,9 @@
 # pylint: disable=invalid-name
 """Primitive operators in the TVM IR."""
 import tvm._ffi
-from .expr import RelayExpr
+
 from . import _ffi_api
+from .expr import RelayExpr
 
 
 @tvm._ffi.register_object("Op")
@@ -28,6 +29,34 @@ class Op(RelayExpr):
     def __init__(self):
         raise RuntimeError("Cannot create op, use get instead")
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        from tvm.relay import astext  # pylint: disable=import-outside-toplevel
+
+        return astext(self, show_meta_data, annotate)
+
     @staticmethod
     def get(op_name):
         """Get the Op for a given name
diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py
index 7313f3c2b4..495e0fe868 100644
--- a/python/tvm/ir/tensor_type.py
+++ b/python/tvm/ir/tensor_type.py
@@ -56,6 +56,6 @@ class TensorType(Type):
         return tuple(int(x) for x in self.shape)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
+        from tvm.relay import pretty_print  # pylint: disable=import-outside-toplevel
 
         return pretty_print(self)
diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py
index 0f30c39ad4..fc32fe34d6 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -27,12 +27,13 @@ import typing
 
 import tvm
 from tvm.micro import get_standalone_crt_dir
+
 from .._ffi import get_global_func
 from ..contrib import utils
 from ..driver import build_module
-from ..relay.backend import executor_factory
-from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
 from ..relay import param_dict
+from ..relay.backend import executor_factory
+from ..relay.backend.name_transforms import prefix_generated_name, to_c_variable_style
 from ..tir import expr
 
 # This should be kept identical to runtime::symbol::tvm_module_main
@@ -528,7 +529,7 @@ def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_targ
         # TODO(mbs): The device type is not unique, better would be to use target.kind.name
         target_device_type = target.get_target_device_type()
         ir_mod = ir_module_by_target[target]
-        printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False)
+        printer = get_global_func("relay.ir.ModelLibraryFormatPrinter")(False, None, False)
         with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
             f.write(printer["print"](ir_mod))
 
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 97842738e5..5e5d1d5f18 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -29,6 +29,7 @@ from . import adt
 from . import prelude
 from . import loops
 from . import scope_builder
+from .base import pretty_print, astext
 
 from . import transform
 from . import analysis
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 323a8f6e5a..8667bfb1df 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -17,15 +17,50 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, unused-import
 """The base node types for the Relay language."""
 import os
-import tvm._ffi
 
+import tvm._ffi
+from tvm.ir import Node as RelayNode
+from tvm.ir import SourceName, Span
 from tvm.runtime import Object
-from tvm.ir import SourceName, Span, Node as RelayNode
 
+from . import _ffi_api
 
 __STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
 
 
+def pretty_print(obj: Object) -> None:
+    """Pretty print the object."""
+    return _ffi_api.PrettyPrint(obj)  # type: ignore # pylint: disable=no-member
+
+
+def astext(obj: Object, show_meta_data=True, annotate=None):
+    """Get the text format of the expression.
+
+    Parameters
+    ----------
+    obj : Object
+        The object to be printed.
+    show_meta_data : bool
+        Whether to include meta data section in the text
+        if there is meta data.
+    annotate: Optional[Object->str]
+        Optionally annotate function to provide additional
+        information in the comment block.
+
+    Returns
+    -------
+    text : str
+        The text format of the expression.
+
+    Notes
+    -----
+    The meta data section is necessary to fully parse the text format.
+    However, it can contain dumps that are big (e.g constant weights),
+    so it can be helpful to skip printing the meta data section.
+    """
+    return _ffi_api.AsText(obj, show_meta_data, annotate)  # type: ignore # pylint: disable=no-member
+
+
 @tvm._ffi.register_func("tvm.relay.std_path")
 def _std_path():
     return __STD_PATH__
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index 6c29825bc0..6e19cafa74 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -26,6 +26,7 @@ from ... import ir as _ir
 from ...ir import make_node
 from ...ir.base import Node
 from ...runtime import Object
+from ..base import astext, pretty_print
 from ..op import get
 from . import _ffi as ffi
 
@@ -47,10 +48,34 @@ class DFPattern(Node):
     """Base class of all Patterns."""
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
-
         return pretty_print(self)
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        return astext(self, show_meta_data, annotate)
+
     def __call__(self, *args):
         args = list(args)
         if len(args) == 1 and args[0] is None:
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 7d60e89b59..cb14552ac1 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -30,7 +30,7 @@ from tvm.runtime import ndarray as _nd
 
 from . import _ffi_api
 from . import ty as _ty
-from .base import RelayNode
+from .base import RelayNode, astext, pretty_print
 
 # alias relay expr as Expr.
 Expr = RelayExpr
@@ -62,10 +62,34 @@ class ExprWithOp(RelayExpr):
         return _ffi_api.cast(self, dtype)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
-
         return pretty_print(self)
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        return astext(self, show_meta_data, annotate)
+
     def __neg__(self):
         return _op_make.negative(self)
 
@@ -719,8 +743,6 @@ class StorageInfo(Node):
         self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
-
         return pretty_print(self)
 
     @property
@@ -750,6 +772,4 @@ class StaticMemoryPlan(Node):
         self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
-
         return pretty_print(self)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index ef33564500..dc0636a9b3 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -23,6 +23,7 @@ from tvm.ir import BaseFunc
 from tvm.runtime import convert
 
 from . import _ffi_api
+from .base import astext, pretty_print
 from .expr import Call
 
 
@@ -68,10 +69,34 @@ class Function(BaseFunc):
         return Call(self, args, None, None)
 
     def __str__(self):
-        from tvm.ir import pretty_print  # pylint: disable=import-outside-toplevel
-
         return pretty_print(self)
 
+    def astext(self, show_meta_data=True, annotate=None):
+        """Get the text format of the expression.
+
+        Parameters
+        ----------
+        show_meta_data : bool
+            Whether to include meta data section in the text
+            if there is meta data.
+
+        annotate: Optional[Object->str]
+            Optionally annotate function to provide additional
+            information in the comment block.
+
+        Returns
+        -------
+        text : str
+            The text format of the expression.
+
+        Notes
+        -----
+        The meta data section is necessary to fully parse the text format.
+        However, it can contain dumps that are big (e.g constant weights),
+        so it can be helpful to skip printing the meta data section.
+        """
+        return astext(self, show_meta_data, annotate)
+
 
 @tvm._ffi.register_func("relay.FunctionWithFields")
 def FunctionWithFields(
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 82bb698f27..9283727ad4 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -18,4 +18,3 @@
 from .parser import ir, ir_module
 from .parser import parse as from_source
 from .parser import tir
-from .printer import script
diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py
index dc37ea1ff6..01d89dacbf 100644
--- a/python/tvm/script/printer/__init__.py
+++ b/python/tvm/script/printer/__init__.py
@@ -20,4 +20,3 @@ This package provides a set of APIs to print supported TVM IR into TVMScript
 in a roundtrippable way.
 """
 from . import default
-from .printer import script
diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py
deleted file mode 100644
index 2ce6329dca..0000000000
--- a/python/tvm/script/printer/printer.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""The printer interface"""
-from typing import Optional
-
-from tvm.runtime.object_path import ObjectPath
-
-from . import _ffi_api
-
-
-def script(
-    obj,
-    indent_space: int = 4,
-    print_line_number: bool = False,
-    num_context_lines: int = -1,
-    path_to_underline: Optional[ObjectPath] = None,
-):
-    """Print a TVM IR as a TVMScript text format.
-
-    Parameters
-    ----------
-    obj : object
-        An TVM object representing TVM IR
-    indent_space : int = 4
-        The number of spaces to indent
-    print_line_number : bool = False
-        Whether to print line number
-    num_context_lines : int = -1
-        The number of context lines to print. -1 means all lines.
-    path_to_underline : Optional[ObjectPath]
-        The path to underline in the script.
-
-    Returns
-    -------
-    script : str
-        The TVMScript text format
-    """
-    return _ffi_api.Script(  # type: ignore # pylint: disable=no-member
-        obj, indent_space, print_line_number, num_context_lines, path_to_underline
-    )
diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs
index 03d8a49207..1a0e7aea39 100644
--- a/rust/tvm/src/ir/expr.rs
+++ b/rust/tvm/src/ir/expr.rs
@@ -90,7 +90,7 @@ impl GlobalVar {
 
 // TODO: figure out how to type the last argument runtime::TypedPackedFunc<String(ObjectRef)> annotate)
 external! {
-    #[name("ir.AsText")]
+    #[name("relay.ir.AsText")]
     fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString;
 }
 
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index bfd0a59175..9a669493cc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -587,7 +587,11 @@ TVM_REGISTER_GLOBAL("transform.OverrideInstruments")
 
 Pass PrintIR(String header, bool show_meta_data) {
   auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
-    LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
+    if (const auto* f = runtime::Registry::Get("relay.PrintIR")) {
+      (*f)(mod, header, show_meta_data);
+    } else {
+      LOG(INFO) << "PrintIR(" << header << "):\n" << mod;
+    }
     return mod;
   };
   return CreateModulePass(pass_func, 0, "PrintIR", {});
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index 53c680b722..ef21604d8a 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -19,7 +19,7 @@
 
 #include "annotated_region_set.h"
 
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 
 #include <unordered_map>
diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h
index aca4239791..443bd5ec1d 100644
--- a/src/relay/analysis/annotated_region_set.h
+++ b/src/relay/analysis/annotated_region_set.h
@@ -27,9 +27,9 @@
 #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
 #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc
index 65b8516cb1..f7a5e7bf2d 100644
--- a/src/relay/analysis/kind_check.cc
+++ b/src/relay/analysis/kind_check.cc
@@ -31,9 +31,9 @@
  * We check this by ensuring the `dtype` field of a Tensor always
  * contains a data type such as `int`, `float`, `uint`.
  */
-#include <tvm/ir/error.h>
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
+#include <tvm/relay/error.h>
 
 namespace tvm {
 namespace relay {
diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc
index 2a90b911b6..05d5b36e36 100644
--- a/src/relay/analysis/match_exhaustion.cc
+++ b/src/relay/analysis/match_exhaustion.cc
@@ -27,8 +27,8 @@
  * code correctness, since hitting an unmatched case results in a
  * dynamic error unless exhaustiveness is checked in advance.
  */
-#include <tvm/ir/error.h>
 #include <tvm/relay/adt.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/pattern_functor.h>
 
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 3bde1a1e37..7940e347b3 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -24,8 +24,8 @@
 #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
 #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
 
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc
index afa17750d8..a622f96c81 100644
--- a/src/relay/backend/contrib/ethosu/codegen.cc
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -24,9 +24,9 @@
  * Codegen.
  */
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
index 42add45b01..6c825a1890 100644
--- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
+++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -17,9 +17,9 @@
  * under the License.
  */
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc
index 571a56ad97..a0e0ac772f 100644
--- a/src/relay/backend/contrib/ethosu/preprocess.cc
+++ b/src/relay/backend/contrib/ethosu/preprocess.cc
@@ -16,9 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/uma/relay_to_tir.cc b/src/relay/backend/contrib/uma/relay_to_tir.cc
index 8aed694531..ca3ae0ebec 100644
--- a/src/relay/backend/contrib/uma/relay_to_tir.cc
+++ b/src/relay/backend/contrib/uma/relay_to_tir.cc
@@ -23,9 +23,9 @@
  * \brief this file contains the target hooks for the Universal Modular Accelerator Interface (UMA).
  */
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 9ba90b9f67..fb23c4cc08 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -25,11 +25,11 @@
 #include "compiler.h"
 
 #include <tvm/driver/driver_api.h>
-#include <tvm/ir/error.h>
 #include <tvm/parser/parser.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/device_copy.h>
 #include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/qnn/transform.h>
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 163ec39901..9160ce0e2e 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
 #define TVM_RELAY_BACKEND_VM_COMPILER_H_
 
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h
index 19e7f3cceb..ca68c9b086 100644
--- a/src/relay/collage/partition_rule.h
+++ b/src/relay/collage/partition_rule.h
@@ -31,7 +31,7 @@
 #include <string>
 #include <vector>
 
-#include "../../printer/doc.h"
+#include "../printer/doc.h"
 #include "./candidate_partition.h"
 #include "./combiner_rule.h"
 #include "./sub_graph.h"
diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc
index 5f7b8747a7..5f91302608 100644
--- a/src/relay/ir/base.cc
+++ b/src/relay/ir/base.cc
@@ -51,5 +51,10 @@ TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span
   }
 });
 
+TVM_REGISTER_GLOBAL("relay.PrintIR")
+    .set_body_typed([](ObjectRef mod, String header, bool show_metadata) {
+      LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata);
+    });
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/ir/error.cc b/src/relay/ir/error.cc
similarity index 97%
rename from src/ir/error.cc
rename to src/relay/ir/error.cc
index 26448d0400..940efd91aa 100644
--- a/src/ir/error.cc
+++ b/src/relay/ir/error.cc
@@ -16,13 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-/*!
- * \file ir/error.cc
- * \brief Utilities for error tracking and reporting.
- */
-#include <tvm/ir/error.h>
 #include <tvm/ir/module.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/error.h>
 
 // clang-format off
 #include <string>
@@ -31,6 +27,7 @@
 // clang-format on
 
 namespace tvm {
+namespace relay {
 
 template <typename T, typename U>
 using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
@@ -137,5 +134,5 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node,
   }
   this->node_to_gv_.insert({node, global});
 }
-
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index c41eb0f8ad..5c5cd6f4b7 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -23,8 +23,8 @@
  */
 #include "transform.h"
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/runtime/packed_func.h>
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 3c638a59f4..6c88aec8b9 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -24,8 +24,8 @@
 #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
 #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/op_attr_types.h>
 
 #include <algorithm>
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 6d6d5f70c0..740766172d 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -25,7 +25,7 @@
 #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
 #define TVM_RELAY_OP_TYPE_RELATIONS_H_
 
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/type.h>
 
 #include <string>
diff --git a/src/printer/doc.cc b/src/relay/printer/doc.cc
similarity index 98%
rename from src/printer/doc.cc
rename to src/relay/printer/doc.cc
index b06995fb12..79313c9a58 100644
--- a/src/printer/doc.cc
+++ b/src/relay/printer/doc.cc
@@ -30,9 +30,10 @@
 #include <sstream>
 #include <vector>
 
-#include "../support/str_escape.h"
+#include "../../support/str_escape.h"
 
 namespace tvm {
+namespace relay {
 
 /*!
  * \brief Represent a piece of text in the doc.
@@ -157,4 +158,5 @@ Doc Doc::Concat(const std::vector<Doc>& vec, const Doc& sep) {
   }
   return seq;
 }
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/doc.h b/src/relay/printer/doc.h
similarity index 97%
rename from src/printer/doc.h
rename to src/relay/printer/doc.h
index dc6ba8952f..36f26d9bd2 100644
--- a/src/printer/doc.h
+++ b/src/relay/printer/doc.h
@@ -23,8 +23,8 @@
  *
  *  Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
  */
-#ifndef TVM_PRINTER_DOC_H_
-#define TVM_PRINTER_DOC_H_
+#ifndef TVM_RELAY_PRINTER_DOC_H_
+#define TVM_RELAY_PRINTER_DOC_H_
 
 #include <tvm/node/node.h>
 #include <tvm/runtime/data_type.h>
@@ -35,6 +35,7 @@
 #include <vector>
 
 namespace tvm {
+namespace relay {
 
 /*!
  * \brief Doc atom node for the ADT.
@@ -162,6 +163,6 @@ class Doc {
   /*! \brief Internal doc stream. */
   std::vector<DocAtom> stream_;
 };
-
+}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_PRINTER_DOC_H_
+#endif  // TVM_RELAY_PRINTER_DOC_H_
diff --git a/src/printer/meta_data.h b/src/relay/printer/meta_data.h
similarity index 95%
rename from src/printer/meta_data.h
rename to src/relay/printer/meta_data.h
index ddf0d78087..2dfd594de7 100644
--- a/src/printer/meta_data.h
+++ b/src/relay/printer/meta_data.h
@@ -16,13 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
-/*!
- * \file tvm/printer/meta_data.h
- * \brief Meta data context for printers.
- */
-#ifndef TVM_PRINTER_META_DATA_H_
-#define TVM_PRINTER_META_DATA_H_
+#ifndef TVM_RELAY_PRINTER_META_DATA_H_
+#define TVM_RELAY_PRINTER_META_DATA_H_
 
 #include <tvm/node/serialization.h>
 
@@ -32,6 +27,7 @@
 #include "doc.h"
 
 namespace tvm {
+namespace relay {
 /*!
  * \brief Meta data context for Printers
  *
@@ -140,5 +136,6 @@ class TextMetaDataContext {
   /*! \brief map from meta data into its string representation */
   std::unordered_map<ObjectRef, Doc, ObjectPtrHash, ObjectPtrEqual> meta_repr_;
 };
+}  // namespace relay
 }  // namespace tvm
-#endif  // TVM_PRINTER_META_DATA_H_
+#endif  // TVM_RELAY_PRINTER_META_DATA_H_
diff --git a/src/printer/model_library_format_printer.cc b/src/relay/printer/model_library_format_printer.cc
similarity index 96%
rename from src/printer/model_library_format_printer.cc
rename to src/relay/printer/model_library_format_printer.cc
index 4220aa00f5..76d0f1423d 100644
--- a/src/printer/model_library_format_printer.cc
+++ b/src/relay/printer/model_library_format_printer.cc
@@ -26,7 +26,7 @@
 #include "text_printer.h"
 
 namespace tvm {
-namespace printer {
+namespace relay {
 
 class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
  public:
@@ -69,7 +69,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
   TextPrinter text_printer_;
 };
 
-TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter")
+TVM_REGISTER_GLOBAL("relay.ir.ModelLibraryFormatPrinter")
     .set_body_typed([](bool show_meta_data,
                        const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
                        bool show_warning) {
@@ -77,5 +77,5 @@ TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter")
           make_object<ModelLibraryFormatPrinter>(show_meta_data, annotate, show_warning));
     });
 
-}  // namespace printer
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc
similarity index 99%
rename from src/printer/relay_text_printer.cc
rename to src/relay/printer/relay_text_printer.cc
index 76cac28b07..cc86f9b564 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/relay/printer/relay_text_printer.cc
@@ -40,10 +40,10 @@
 #include <tvm/target/virtual_device.h>
 #include <tvm/tir/function.h>
 
-#include "../ir/attr_functor.h"
-#include "../parser/meta_ref.h"
-#include "../relay/analysis/dependency_graph.h"
-#include "../support/scalars.h"
+#include "../../ir/attr_functor.h"
+#include "../../parser/meta_ref.h"
+#include "../../support/scalars.h"
+#include "../analysis/dependency_graph.h"
 #include "doc.h"
 #include "meta_data.h"
 #include "text_printer.h"
@@ -970,10 +970,5 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) {
   return doc;
 }
 
-TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
-  auto text = AsText(node, false, nullptr);
-  return text;
-});
-
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/text_printer.cc b/src/relay/printer/text_printer.cc
similarity index 95%
rename from src/printer/text_printer.cc
rename to src/relay/printer/text_printer.cc
index 4d4113fef6..f51f7c3dfa 100644
--- a/src/printer/text_printer.cc
+++ b/src/relay/printer/text_printer.cc
@@ -23,7 +23,7 @@
  *        that can be parsed by a parser.
  */
 
-#include "text_printer.h"
+#include "./text_printer.h"
 
 #include <tvm/tir/function.h>
 
@@ -31,6 +31,7 @@
 #include <string>
 
 namespace tvm {
+namespace relay {
 
 static const char* kSemVer = "0.0.5";
 
@@ -124,8 +125,8 @@ String AsText(const ObjectRef& node, bool show_meta_data,
   return doc.str();
 }
 
-TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint);
-
-TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText);
+TVM_REGISTER_GLOBAL("relay.ir.PrettyPrint").set_body_typed(PrettyPrint);
+TVM_REGISTER_GLOBAL("relay.ir.AsText").set_body_typed(AsText);
 
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/text_printer.h b/src/relay/printer/text_printer.h
similarity index 95%
rename from src/printer/text_printer.h
rename to src/relay/printer/text_printer.h
index 925c2ebf49..707bbec5ad 100644
--- a/src/printer/text_printer.h
+++ b/src/relay/printer/text_printer.h
@@ -23,8 +23,8 @@
  *        that can be parsed by a parser.
  */
 
-#ifndef TVM_PRINTER_TEXT_PRINTER_H_
-#define TVM_PRINTER_TEXT_PRINTER_H_
+#ifndef TVM_RELAY_PRINTER_TEXT_PRINTER_H_
+#define TVM_RELAY_PRINTER_TEXT_PRINTER_H_
 
 #include <tvm/ir/module.h>
 #include <tvm/ir/type_functor.h>
@@ -41,19 +41,16 @@
 #include <unordered_set>
 #include <vector>
 
-#include "../ir/attr_functor.h"
-#include "../relay/analysis/dependency_graph.h"
+#include "../../ir/attr_functor.h"
+#include "../analysis/dependency_graph.h"
 #include "doc.h"
 #include "meta_data.h"
-#include "text_printer.h"
-
-namespace tvm {
-class TextPrinter;
-}  // namespace tvm
 
 namespace tvm {
 namespace relay {
 
+class TextPrinter;
+
 class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
                          public PatternFunctor<Doc(const Pattern&)>,
                          public TypeFunctor<Doc(const Type&)>,
@@ -227,14 +224,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
   DependencyGraph dg_;
   class AttrPrinter;
   friend class AttrPrinter;
-  friend class tvm::TextPrinter;
+  friend class tvm::relay::TextPrinter;
 };
 
-}  // namespace relay
-}  // namespace tvm
-
-namespace tvm {
-namespace tir {
+using namespace ::tvm::tir;
 
 /*!
  *  \brief Meta node collector
@@ -274,7 +267,7 @@ class MetaCollector : public StmtExprVisitor {
 };
 
 class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
-                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public tir::ExprFunctor<Doc(const PrimExpr&)>,
                        public TypeFunctor<Doc(const Type&)> {
  public:
   explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
@@ -298,7 +291,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitExpr_(const FloatImmNode* op) override;
   Doc VisitExpr_(const StringImmNode* op) override;
   Doc VisitExpr_(const CastNode* op) override;
-  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const tir::VarNode* op) override;
   Doc VisitExpr_(const AddNode* op) override;
   Doc VisitExpr_(const SubNode* op) override;
   Doc VisitExpr_(const MulNode* op) override;
@@ -323,8 +316,8 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitExpr_(const LoadNode* op) override;
   Doc VisitExpr_(const RampNode* op) override;
   Doc VisitExpr_(const BroadcastNode* op) override;
-  Doc VisitExpr_(const LetNode* op) override;
-  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const tir::LetNode* op) override;
+  Doc VisitExpr_(const tir::CallNode* op) override;
   Doc VisitExpr_(const ShuffleNode* op) override;
   Doc VisitExpr_(const ReduceNode* op) override;
   Doc VisitExprDefault_(const Object* op) override;
@@ -357,7 +350,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   /*! \brief meta collector */
   MetaCollector meta_collector_;
   /*! \brief Map from Var to Doc */
-  std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
+  std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
   /*! \brief Map from Buffer to Doc */
   std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
   /*! \brief Map from Buffer to Doc */
@@ -365,7 +358,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   /*! \brief name allocation map */
   std::unordered_map<std::string, int> name_alloc_map_;
 
-  friend class tvm::TextPrinter;
+  friend class TextPrinter;
 
   Doc VisitType_(const PrimTypeNode* node) override;
   Doc VisitType_(const PointerTypeNode* node) override;
@@ -396,7 +389,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
   template <typename T>
   static Doc PrintConstScalar(DataType dtype, const T& data);
   Doc GetUniqueName(std::string prefix);
-  Doc AllocVar(const Var& var);
+  Doc AllocVar(const tir::Var& var);
   Doc AllocConst(const AllocateConst& var);
   Doc AllocBuf(const Buffer& buffer);
   Doc AllocProducer(const DataProducer& buffer);
@@ -412,11 +405,6 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
 String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
                                  runtime::TypedPackedFunc<std::string(Stmt)> annotate);
 
-}  // namespace tir
-}  // namespace tvm
-
-namespace tvm {
-
 class TextPrinter {
  public:
   explicit TextPrinter(bool show_meta_data,
@@ -441,7 +429,7 @@ class TextPrinter {
   /*! \brief Relay Text Printer */
   relay::RelayTextPrinter relay_text_printer_;
   /*! \brief TIR Text Printer */
-  tir::TIRTextPrinter tir_text_printer_;
+  TIRTextPrinter tir_text_printer_;
 
   bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); }
 
@@ -472,6 +460,7 @@ class TextPrinter {
 
   Doc PrintMod(const IRModule& mod);
 };
+}  // namespace relay
 }  // namespace tvm
 
-#endif  // TVM_PRINTER_TEXT_PRINTER_H_
+#endif  // TVM_RELAY_PRINTER_TEXT_PRINTER_H_
diff --git a/src/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc
similarity index 97%
rename from src/printer/tir_text_printer.cc
rename to src/relay/printer/tir_text_printer.cc
index 4d74cc6d5a..eb089bd0d7 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/relay/printer/tir_text_printer.cc
@@ -36,13 +36,13 @@
 #include <algorithm>
 #include <string>
 
-#include "../tir/transforms/ir_utils.h"
+#include "../../tir/transforms/ir_utils.h"
 #include "doc.h"
 #include "meta_data.h"
 #include "text_printer.h"
 
 namespace tvm {
-namespace tir {
+namespace relay {
 
 Doc TIRTextPrinter::Print(const ObjectRef& node) {
   if (!node.defined()) return Doc::Text("(nullptr)");
@@ -93,9 +93,9 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
   memo_buf_.clear();
 
   // ordered vars associated with buffers, for consistent printing
-  std::vector<Var> buffer_vars_ordered;
+  std::vector<tir::Var> buffer_vars_ordered;
 
-  for (Var v : op->params) {
+  for (tir::Var v : op->params) {
     auto buffer_map_find = op->buffer_map.find(v);
     if (buffer_map_find != op->buffer_map.end()) {
       auto map_data = *buffer_map_find;
@@ -132,7 +132,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
   if (memo_buf_.size() != 0) {
     Doc buffer_doc;
     std::vector<Doc> buffer_docs;
-    for (const Var& v : buffer_vars_ordered) {
+    for (const tir::Var& v : buffer_vars_ordered) {
       const Buffer buf = op->buffer_map[v];
       buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
     }
@@ -144,7 +144,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
   if (op->buffer_map.size() != 0) {
     // print buffer_map
     std::vector<Doc> buffer_map_doc;
-    for (const Var& v : buffer_vars_ordered) {
+    for (const tir::Var& v : buffer_vars_ordered) {
       const Buffer buf = op->buffer_map[v];
       buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
     }
@@ -302,9 +302,9 @@ Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
   return doc;
 }
 
-Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
-  const Var& var = GetRef<Var>(op);
-  return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+Doc TIRTextPrinter::VisitExpr_(const tir::VarNode* op) {
+  const tir::Var& var = GetRef<tir::Var>(op);
+  return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
 }
 
 #define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \
@@ -401,13 +401,13 @@ Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
   return doc;
 }
 
-Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+Doc TIRTextPrinter::VisitExpr_(const tir::LetNode* op) {
   Doc doc;
   doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
   return doc;
 }
 
-Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
+Doc TIRTextPrinter::VisitExpr_(const tir::CallNode* op) {
   Doc doc;
   std::vector<Doc> func_args;
   if (auto* ptr_op = op->op.as<OpNode>()) {
@@ -771,7 +771,7 @@ Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
   return Doc::Text(unique_prefix);
 }
 
-Doc TIRTextPrinter::AllocVar(const Var& var) {
+Doc TIRTextPrinter::AllocVar(const tir::Var& var) {
   const auto& it = memo_var_.find(var);
   if (it != memo_var_.end()) {
     return it->second;
@@ -831,7 +831,7 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
   return doc;
 }
 
-bool TIRTextPrinter::GetVarName(Var v, std::string* s) {
+bool TIRTextPrinter::GetVarName(tir::Var v, std::string* s) {
   auto it = memo_var_.find(v);
   if (it == memo_var_.end()) {
     return false;
@@ -841,5 +841,5 @@ bool TIRTextPrinter::GetVarName(Var v, std::string* s) {
   return true;
 }
 
-}  // namespace tir
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/tir_text_printer_debug.cc b/src/relay/printer/tir_text_printer_debug.cc
similarity index 98%
rename from src/printer/tir_text_printer_debug.cc
rename to src/relay/printer/tir_text_printer_debug.cc
index 6c29558f72..914d8877d2 100644
--- a/src/printer/tir_text_printer_debug.cc
+++ b/src/relay/printer/tir_text_printer_debug.cc
@@ -29,7 +29,7 @@
 #include <string>
 
 namespace tvm {
-namespace tir {
+namespace relay {
 
 std::optional<std::string> span_text(const Span& span) {
   if (!span.defined()) {
@@ -93,5 +93,5 @@ Doc TIRTextPrinterDebug::VisitExpr(const PrimExpr& e) {
   return TIRTextPrinter::VisitExpr(e);
 }
 
-}  // namespace tir
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/printer/tir_text_printer_debug.h b/src/relay/printer/tir_text_printer_debug.h
similarity index 90%
rename from src/printer/tir_text_printer_debug.h
rename to src/relay/printer/tir_text_printer_debug.h
index d0046034cf..f7cb7a6554 100644
--- a/src/printer/tir_text_printer_debug.h
+++ b/src/relay/printer/tir_text_printer_debug.h
@@ -23,8 +23,8 @@
  *        that can be parsed by a parser.
  */
 
-#ifndef TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
-#define TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#ifndef TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#define TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
 
 #include <tuple>
 #include <vector>
@@ -32,7 +32,7 @@
 #include "text_printer.h"
 
 namespace tvm {
-namespace tir {
+namespace relay {
 
 class TIRTextPrinterDebug : public TIRTextPrinter {
  public:
@@ -64,7 +64,7 @@ class TIRTextPrinterDebug : public TIRTextPrinter {
   std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;
 };
 
-}  // namespace tir
+}  // namespace relay
 }  // namespace tvm
 
-#endif  // TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#endif  // TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
diff --git a/src/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc
similarity index 96%
rename from src/printer/tvmscript_printer.cc
rename to src/relay/printer/tvmscript_printer.cc
index c578bc53d3..0966110950 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/relay/printer/tvmscript_printer.cc
@@ -39,13 +39,15 @@
 #include <algorithm>
 #include <utility>
 
-#include "../tir/transforms/ir_utils.h"
+#include "../../tir/transforms/ir_utils.h"
 #include "doc.h"
 #include "meta_data.h"
 #include "text_printer.h"
 
 namespace tvm {
-namespace tir {
+namespace relay {
+
+using namespace tvm::tir;
 
 enum class ExprPrecedence : int {
   /*! \brief Identity(e.g., IntImm, Var) and function call(e.g., floordiv, min) */
@@ -77,14 +79,14 @@ enum class ExprPrecedence : int {
  */
 class BufferUsageFinder : public StmtExprVisitor {
  public:
-  static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
+  static Map<tir::Var, Array<Buffer>> FindUsage(Map<tir::Var, Array<Buffer>> usage, Stmt body) {
     BufferUsageFinder visitor(std::move(usage));
     visitor.VisitStmt(body);
     return std::move(visitor.usage_);
   }
 
-  void VisitExpr_(const VarNode* op) final {
-    Var var = GetRef<Var>(op);
+  void VisitExpr_(const tir::VarNode* op) final {
+    tir::Var var = GetRef<tir::Var>(op);
     if (!usage_.count(var)) {
       usage_.Set(var, {});
     }
@@ -107,7 +109,7 @@ class BufferUsageFinder : public StmtExprVisitor {
   }
 
  private:
-  explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}
+  explicit BufferUsageFinder(Map<tir::Var, Array<Buffer>> usage) : usage_(usage) {}
 
   void VisitBuffer(const Buffer& buffer) {
     if (buffers_visited_.count(buffer.get())) {
@@ -124,7 +126,7 @@ class BufferUsageFinder : public StmtExprVisitor {
   }
 
   // The search result.
-  Map<Var, Array<Buffer>> usage_;
+  Map<tir::Var, Array<Buffer>> usage_;
   // The buffers that have been visited so far, to avoid duplicate
   // entries in the search result.
   std::unordered_set<const BufferNode*> buffers_visited_;
@@ -139,7 +141,7 @@ class BufferUsageFinder : public StmtExprVisitor {
  *          subexpression to decide whether or not parentheses is needed.
  */
 class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
-                         public ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
+                         public tir::ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
                          public TypeFunctor<Doc(const Type&)> {
  public:
   explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta,
@@ -167,20 +169,20 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   /*! \brief meta data context */
   TextMetaDataContext meta_;
   /*! \brief meta collector */
-  MetaCollector meta_collector_;
+  relay::MetaCollector meta_collector_;
   /*! \brief map from Function to GlobalVar */
   std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
   /*! \brief var collector (var defined by For/Loop/Block) */
-  std::unordered_set<const VarNode*> var_not_in_headers_;
+  std::unordered_set<const tir::VarNode*> var_not_in_headers_;
   /*!
    * \brief buffer collector
    *        (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion)
    */
   std::unordered_set<const BufferNode*> buf_not_in_headers_;
   /*! \brief Map from Var to thread env name */
-  std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
+  std::unordered_map<tir::Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
   /*! \brief Map from Var to Doc */
-  std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
+  std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
   /*! \brief Map from Buffer to Doc */
   std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
   /*! \brief Map from Buffer to Declaration Doc */
@@ -194,7 +196,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   /*! \brief loop stack without annotations */
   std::vector<For> simple_loop_stack_;
   /*! \brief the maps from loop_vars to the loops */
-  std::unordered_map<const VarNode*, For> loop_var_map_;
+  std::unordered_map<const tir::VarNode*, For> loop_var_map_;
   /*!
    * \brief simple block vars remap from loop vars
    * simple_remap requires:
@@ -210,12 +212,12 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
    * LetStmt or Allocate that generates their data pointer, rather
    * than in the header.
    */
-  Map<Var, Array<Buffer>> buffer_var_usage_;
+  Map<tir::Var, Array<Buffer>> buffer_var_usage_;
   /*! \brief Analyzer to simplify some expressions. */
   arith::Analyzer ana_;
 
   Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
-  Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
+  Doc VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override;
@@ -243,8 +245,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override;
-  Doc VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) override;
-  Doc VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) override;
+  Doc VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) override;
+  Doc VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override;
   Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override;
@@ -297,9 +299,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
 
   Doc GetUniqueName(std::string prefix);
-  Doc AllocVar(const Var& var);
+  Doc AllocVar(const tir::Var& var);
   Doc AllocBuf(const Buffer& buffer);
-  void TryDeallocVar(const Var& var);
+  void TryDeallocVar(const tir::Var& var);
   bool ContainsOptionalInfo(const Stmt& stmt);
   /*!
    * \brief Check if a buffer declaration satisfies:
@@ -338,7 +340,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
    * \return A boolean indicating whether the input loop depends on previous loops
    */
   bool DependOnPrevLoops(const ForNode* for_op) {
-    auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
+    auto f_check = [&var_map = this->loop_var_map_](const tir::VarNode* v) {
+      return var_map.count(v);
+    };
     return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
   }
 
@@ -494,7 +498,7 @@ Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
   return Doc::Text(unique_prefix);
 }
 
-Doc TVMScriptPrinter::AllocVar(const Var& var) {
+Doc TVMScriptPrinter::AllocVar(const tir::Var& var) {
   const auto& it = memo_var_.find(var);
   if (it != memo_var_.end()) {
     return it->second;
@@ -522,8 +526,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
   if (!buf->strides.empty()) {
     doc << ", strides=" << Print(buf->strides);
   }
-  if (buf->elem_offset->IsInstance<VarNode>()) {
-    Var elem_offset = Downcast<Var>(buf->elem_offset);
+  if (buf->elem_offset->IsInstance<tir::VarNode>()) {
+    tir::Var elem_offset = Downcast<tir::Var>(buf->elem_offset);
     if (memo_var_.find(elem_offset) != memo_var_.end()) {
       doc << ", elem_offset=" << Print(buf->elem_offset);
     } else {
@@ -585,7 +589,7 @@ bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) {
  * \brief Try to dealloc vars out of space and leave the index to coming vars.
  * \note It is not a necessary step.
  */
-void TVMScriptPrinter::TryDeallocVar(const Var& var) {
+void TVMScriptPrinter::TryDeallocVar(const tir::Var& var) {
   auto it = memo_var_.find(var);
   ICHECK(it != memo_var_.end());
   std::string print_name = it->second.str();
@@ -695,7 +699,7 @@ Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
   int n_var = static_cast<int>(op->rhs.size());
 
   doc << tir_prefix_ << ".comm_reducer(lambda ";
-  for (const Var& v_lhs : op->lhs) {
+  for (const tir::Var& v_lhs : op->lhs) {
     doc << Print(v_lhs) << ", ";
   }
   for (int i = 0; i < n_var; ++i) {
@@ -789,10 +793,10 @@ Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precede
   return doc;
 }
 
-Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
-  const Var& var = GetRef<Var>(op);
-  return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+  const tir::Var& var = GetRef<tir::Var>(op);
+  return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
 }
 
 bool WillPrintConstScalar(const PrimExpr& expr) {
@@ -938,7 +942,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_pr
   return doc;
 }
 
-Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
   Doc doc;
   doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", "
@@ -946,7 +950,7 @@ Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_preceden
   return doc;
 }
 
-Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) {
   *out_precedence = ExprPrecedence::kIdentity;
   Doc doc;
   if (auto* ptr_op = op->op.as<OpNode>()) {
@@ -1090,7 +1094,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
 namespace {
 
 bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
-  const Var& buffer_var = allocate->buffer_var;
+  const tir::Var& buffer_var = allocate->buffer_var;
   const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
   if (!decl_buffer) {
     return false;
@@ -1468,8 +1472,8 @@ Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) {
   auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var,
                                              const PrimExpr& value) -> bool {
     if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false;
-    if (!value->IsInstance<VarNode>()) return false;
-    const Var& var = Downcast<Var>(value);
+    if (!value->IsInstance<tir::VarNode>()) return false;
+    const tir::Var& var = Downcast<tir::Var>(value);
     auto it = loop_var_map_.find(var.get());
     return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) &&
            expr_equal(it->second->extent, iter_var->dom->extent);
@@ -1763,7 +1767,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
   }
   // print var declaration
   Doc header_var;
-  std::vector<const VarNode*> vars;
+  std::vector<const tir::VarNode*> vars;
   for (const auto& it : memo_var_) {
     if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) {
       vars.push_back(it.first.get());
@@ -1777,20 +1781,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
     }
   }
   if (!vars.empty()) {
-    std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) {
-      return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
+    std::sort(vars.begin(), vars.end(), [&](const tir::VarNode* a, const tir::VarNode* b) {
+      return memo_var_[GetRef<tir::Var>(a)].str() < memo_var_[GetRef<tir::Var>(b)].str();
     });
     for (const auto& var : vars) {
-      auto type = GetRef<Var>(var)->type_annotation;
+      auto type = GetRef<tir::Var>(var)->type_annotation;
       if (auto* ptr_type = type.as<PointerTypeNode>()) {
         auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
         ICHECK(prim_type);
-        header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_
+        header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
                    << ".buffer_var(";
         header_var << PrintDType(prim_type->dtype) << ", "
                    << Doc::StrLiteral(ptr_type->storage_scope) << ")";
       } else {
-        header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_ << ".var(";
+        header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
+                   << ".var(";
         header_var << PrintDType(var->dtype) << ")";
       }
     }
@@ -2013,5 +2018,5 @@ String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix,
 
 TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
 
-}  // namespace tir
+}  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index d18c17e63c..d70c7480e9 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -30,9 +30,9 @@
  * as external functions.
  */
 
-#include <tvm/ir/error.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index f6cdf6d1ca..32ca2878fd 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -29,10 +29,10 @@
  * external functions, and they will use the provided compiler for codegen.
  */
 
-#include <tvm/ir/error.h>
 #include <tvm/ir/module.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc
index 9ebdcb1e99..878b380a37 100644
--- a/src/script/printer/printer.cc
+++ b/src/script/printer/printer.cc
@@ -23,18 +23,11 @@ namespace tvm {
 namespace script {
 namespace printer {
 
-String Script(ObjectRef obj, int indent_spaces, bool print_line_numbers, int num_context_lines,
-              Optional<ObjectPath> path_to_underline) {
-  return DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root()), indent_spaces,
-                           print_line_numbers, num_context_lines, path_to_underline);
-}
-
 Default* Default::Instance() {
   static Default inst;
   return &inst;
 }
 
-TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script);
 TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix")
     .set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; });
 TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType")
diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc
index 55d751c331..1aae0202ac 100644
--- a/src/tir/schedule/error.cc
+++ b/src/tir/schedule/error.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include "../../printer/text_printer.h"
 #include "./utils.h"
 
 namespace tvm {
@@ -52,10 +51,11 @@ String ScheduleError::RenderReport(const String& primitive) const {
             }
             return it->second;
           });
-
+  const auto* f = runtime::Registry::Get("script.AsTVMScriptWithDiagnostic");
+  ICHECK(f != nullptr);
   os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
      << "'.\n\nThe IR with diagnostic is:\n"
-     << AsTVMScriptWithDiagnostic(mod, "T", false, annotate);
+     << ((*f)(mod, "T", false, annotate).operator String());
 
   // print error message
   os << "Error message: " << msg;
diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc
index bc9002ee84..c97070e1bf 100644
--- a/src/tir/transforms/install_debug_spans.cc
+++ b/src/tir/transforms/install_debug_spans.cc
@@ -30,7 +30,7 @@
 #include <string>
 #include <utility>
 
-#include "../../printer/tir_text_printer_debug.h"
+#include "../../relay/printer/tir_text_printer_debug.h"
 
 namespace tvm {
 namespace tir {
@@ -42,7 +42,7 @@ Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt)
 
 DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) {
   // Determine the line that each stmt/expr will be printed on
-  tvm::tir::TIRTextPrinterDebug printer(false);
+  tvm::relay::TIRTextPrinterDebug printer(false);
 
   // Fill in the stmts and exprs' line info
   auto result = printer.Print(stmt).str();
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index 5ea6d7e5de..08fa01f0b3 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -14,15 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Union
+
 import numpy as np
 import pytest
-
 import tvm
-import tvm.testing
-from tvm import relay
 import tvm.relay.testing
+import tvm.testing
 from numpy import isclose
-from typing import Union
+from tvm import relay
 
 SEMVER = '#[version = "0.0.5"]\n'
 
@@ -74,7 +74,7 @@ def graph_equal(lhs, rhs):
 
 
 def roundtrip_expr(expr):
-    text = tvm.relay.Expr.astext(expr, show_meta_data=False)
+    text = expr.astext()
     x = tvm.parser.parse_expr(text)
     assert_graph_equal(x, expr)
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index bb96022794..f40d942749 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -885,5 +885,4 @@ def test_max_pool_blocked():
 
 
 if __name__ == "__main__":
-    # tvm.testing.main()
-    test_cache_read_specify_consumer()
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index d4ae84a556..2806c7b2fc 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -343,7 +343,6 @@ def test_prim_func():
 
     func = tvm.tir.PrimFunc([x, y, b], stmt)
     # make sure we can print
-    func.astext()
     assert func.buffer_map[func.params[2]].same_as(b)
 
     assert len(func.buffer_map) == 1
@@ -399,130 +398,5 @@ def test_intimm_cond():
     assert x == 1
 
 
-def test_block_blockrealize():
-    x = tvm.tir.Var("x", "int32")
-    y = tvm.tir.Var("y", "int32")
-    vx = tvm.tir.IterVar((16, 16), "vx", 0)
-    vx_var = vx.var
-    vy = tvm.tir.IterVar((16, 16), "vy", 2)
-    vy_var = vy.var
-    A = tvm.tir.decl_buffer((16), "float32")
-    B = tvm.tir.decl_buffer((16, 16), "float32")
-    alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32")
-    match_buffer = tvm.tir.decl_buffer((16, 16), "float32")
-    init_body = tvm.tir.BufferStore(A, 0.0, [vx_var])
-    body = tvm.tir.BufferStore(
-        A,
-        tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]),
-        [vx_var],
-    )
-    reads = [
-        tvm.tir.BufferRegion(
-            B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)]
-        )
-    ]
-    writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])]
-    block_match_buffer = tvm.tir.MatchBufferRegion(
-        match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)])
-    )
-
-    block = tvm.tir.Block(
-        [vx, vy],
-        reads,
-        writes,
-        "block",
-        body,
-        init=init_body,
-        alloc_buffers=[alloc_buffer],
-        match_buffers=[block_match_buffer],
-        annotations={"attr_key": "attr_value"},
-    )
-
-    # Checking Block
-    assert isinstance(block, tvm.tir.Block)
-    # Checking iter_vars
-    assert block.iter_vars[0] == vx
-    assert block.iter_vars[1] == vy
-    # Checking reads/writes region
-    assert isinstance(block.reads[0], tvm.tir.BufferRegion)
-    assert block.reads[0].buffer == B
-    assert block.reads[0].region[0].min == vx_var
-    assert block.reads[0].region[1].min == vy_var
-    assert isinstance(block.writes[0], tvm.tir.BufferRegion)
-    assert block.writes[0].buffer == A
-    assert block.writes[0].region[0].min == vx_var
-    assert block.writes[0].region[0].extent == 1
-    # Checking name_hint
-    assert block.name_hint == "block"
-    # Checking body
-    assert block.body == body
-    # Checking init
-    assert block.init == init_body
-    # Checking alloc_buffers
-    assert block.alloc_buffers[0] == alloc_buffer
-    # Checking match_buffers
-    assert block.match_buffers[0].buffer == match_buffer
-    assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion)
-    assert block.match_buffers[0].source.buffer == B
-    assert block.match_buffers[0].source.region[0].min == 0
-    assert block.match_buffers[0].source.region[0].extent == 16
-
-    # Checking BlockRealize
-    block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block)
-    assert isinstance(block_realize, tvm.tir.BlockRealize)
-    assert block_realize.iter_values[0] == x
-    assert block_realize.iter_values[1] == y
-    assert block_realize.predicate == tvm.tir.const(True, "bool")
-    assert block_realize.block == block
-
-    # make sure we can print using ReprPrinter
-    str(block)
-    str(block_realize)
-    # make sure we can print using TIRTextPrinter
-    func = tvm.tir.PrimFunc([], block_realize)
-    output = func.astext()
-    assert output.find("meta[tir.BlockRealise]") == -1
-    assert output.find("bind") != -1
-    assert output.find("reads") != -1
-    assert output.find("writes") != -1
-    assert output.find("alloc_buffer") != -1
-    assert output.find("match_buffer") != -1
-    assert output.find("attr") != -1
-    assert output.find("with init()") != -1
-
-
-def test_tir_allocate():
-    dtype = "int8"
-    storage_scope = "global"
-    ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
-    a = te.var("buffer", ptype)
-    allocate = tvm.tir.Allocate(
-        buffer_var=a,
-        dtype=dtype,
-        extents=[2, 2],
-        condition=tvm.get_global_func("tir.const_true")(dtype, None),
-        body=tvm.tir.Evaluate(2 + 1),
-        annotations={
-            "attr1": "foo",
-            "attr2": "bar",
-        },
-    )
-    assert allocate.buffer_var == a
-    assert allocate.dtype == "int8"
-    assert list(allocate.extents) == [2, 2]
-    assert allocate.annotations["attr1"] == "foo"
-    assert allocate.annotations["attr2"] == "bar"
-
-    # make sure we can print using TIRTextPrinter
-    func = tvm.tir.PrimFunc([], allocate)
-    output = func.astext()
-    assert (
-        output.find(
-            'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
-        )
-        != -1
-    )
-
-
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 48af3ebaf5..d4abc26bb2 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -14,14 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
+import pytest
 import tvm
+import tvm.testing
 from tvm import te
 from tvm.contrib.nvcc import have_fp16
 
-import numpy as np
-import tvm.testing
-import pytest
-
 
 @tvm.testing.requires_cuda
 def test_lower_warp_memory_local_scope():
@@ -320,7 +319,7 @@ def test_lower_warp_memory_same_thread():
     fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
     mod = tvm.IRModule.from_expr(fdevice)
     fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
-    assert "tvm_warp_shuffle" not in fdevice.astext()
+    assert "tvm_warp_shuffle" not in fdevice.script()
 
 
 @tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
deleted file mode 100644
index 1bccb8188c..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-import tvm.testing
-from tvm.script.parser import tir as T
-from tvm.script import script
-
-
-def _test(obj, expected: str):
-    assert script(obj).strip() == expected.strip()
-
-
-def test_remap():
-    @T.prim_func
-    def block_with_remap_implicitly():
-        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
-            with T.block("update"):
-                v0 = T.axis.spatial(128, i0 + 1)
-                v1 = T.axis.spatial(128, i1)
-                v2 = T.axis.reduce(128, i2)
-                v3 = T.axis.spatial(128, i3 - 1)
-                v4 = T.axis.reduce(128, i4)
-                v5 = T.axis.spatial(128, i5)
-                pass
-
-    @T.prim_func
-    def block_with_remap_explicitly():
-        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
-            with T.block("update"):
-                v0 = T.axis.spatial(128, i0 + 1)
-                v1, v2 = T.axis.remap("SR", [i1, i2])
-                v3 = T.axis.spatial(128, i3 - 1)
-                v4, v5 = T.axis.remap("RS", [i4, i5])
-                pass
-
-    expected_output = """@T.prim_func
-def main():
-    with T.block("root"):
-        T.reads()
-        T.writes()
-        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
-            with T.block("update"):
-                v0 = T.axis.spatial(128, i0 + 1)
-                v1, v2 = T.axis.remap("SR", [i1, i2])
-                v3 = T.axis.spatial(128, i3 - 1)
-                v4, v5 = T.axis.remap("RS", [i4, i5])
-                T.reads()
-                T.writes()
-                T.evaluate(0)"""
-    _test(block_with_remap_implicitly, expected_output)
-    _test(block_with_remap_explicitly, expected_output)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index 9c15fbc889..d62a1cd12c 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -598,6 +598,47 @@ def test_tuple_type():
     _assert_print(obj, "T.Tuple(T.float32, T.int32)")
 
 
+def test_remap():
+    from tvm.script import tir as T
+
+    @T.prim_func
+    def block_with_remap_implicitly():
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1 = T.axis.spatial(128, i1)
+                v2 = T.axis.reduce(128, i2)
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4 = T.axis.reduce(128, i4)
+                v5 = T.axis.spatial(128, i5)
+
+    @T.prim_func
+    def block_with_remap_explicitly():
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1, v2 = T.axis.remap("SR", [i1, i2])
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4, v5 = T.axis.remap("RS", [i4, i5])
+
+    expected_output = """@T.prim_func
+def main():
+    with T.block("root"):
+        T.reads()
+        T.writes()
+        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+            with T.block("update"):
+                v0 = T.axis.spatial(128, i0 + 1)
+                v1, v2 = T.axis.remap("SR", [i1, i2])
+                v3 = T.axis.spatial(128, i3 - 1)
+                v4, v5 = T.axis.remap("RS", [i4, i5])
+                T.reads()
+                T.writes()
+                T.evaluate(0)"""
+    _assert_print(block_with_remap_explicitly, expected_output)
+    _assert_print(block_with_remap_implicitly, expected_output)
+
+
 if __name__ == "__main__":
     test_prim_func()
     test_block_realize()
@@ -639,3 +680,4 @@ if __name__ == "__main__":
     test_prim_type()
     test_pointer_type()
     test_tuple_type()
+    test_remap()