You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by bo...@apache.org on 2023/01/18 13:42:36 UTC
[tvm] branch main updated: [TVMScript] Use TVMScript for all TIR Printing (#13795)
This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new da99e9d1b5 [TVMScript] Use TVMScript for all TIR Printing (#13795)
da99e9d1b5 is described below
commit da99e9d1b5208e9a23e0b8e5b45da6e633f05415
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Wed Jan 18 05:42:24 2023 -0800
[TVMScript] Use TVMScript for all TIR Printing (#13795)
---
CMakeLists.txt | 2 +-
include/tvm/ir/module.h | 28 -----
include/tvm/ir/transform.h | 1 -
include/tvm/relay/base.h | 28 +++++
include/tvm/{ir => relay}/error.h | 13 +--
include/tvm/relay/expr.h | 1 -
include/tvm/relay/expr_functor.h | 2 +-
include/tvm/relay/pattern_functor.h | 2 +-
python/tvm/ir/__init__.py | 1 -
python/tvm/ir/affine_type.py | 2 +-
python/tvm/ir/base.py | 31 -----
python/tvm/ir/expr.py | 32 +++++-
python/tvm/ir/module.py | 28 +++++
python/tvm/ir/op.py | 31 ++++-
python/tvm/ir/tensor_type.py | 2 +-
python/tvm/micro/model_library_format.py | 7 +-
python/tvm/relay/__init__.py | 1 +
python/tvm/relay/base.py | 39 ++++++-
python/tvm/relay/dataflow_pattern/__init__.py | 29 ++++-
python/tvm/relay/expr.py | 34 ++++--
python/tvm/relay/function.py | 29 ++++-
python/tvm/script/__init__.py | 1 -
python/tvm/script/printer/__init__.py | 1 -
python/tvm/script/printer/printer.py | 54 ---------
rust/tvm/src/ir/expr.rs | 2 +-
src/ir/transform.cc | 6 +-
src/relay/analysis/annotated_region_set.cc | 2 +-
src/relay/analysis/annotated_region_set.h | 2 +-
src/relay/analysis/kind_check.cc | 2 +-
src/relay/analysis/match_exhaustion.cc | 2 +-
src/relay/analysis/type_solver.h | 2 +-
src/relay/backend/contrib/ethosu/codegen.cc | 2 +-
src/relay/backend/contrib/ethosu/compiler_attrs.cc | 2 +-
src/relay/backend/contrib/ethosu/preprocess.cc | 2 +-
src/relay/backend/contrib/uma/relay_to_tir.cc | 2 +-
src/relay/backend/vm/compiler.cc | 2 +-
src/relay/backend/vm/compiler.h | 2 +-
src/relay/collage/partition_rule.h | 2 +-
src/relay/ir/base.cc | 5 +
src/{ => relay}/ir/error.cc | 11 +-
src/relay/op/tensor/transform.cc | 2 +-
src/relay/op/tensor/transform.h | 2 +-
src/relay/op/type_relations.h | 2 +-
src/{ => relay}/printer/doc.cc | 4 +-
src/{ => relay}/printer/doc.h | 9 +-
src/{ => relay}/printer/meta_data.h | 13 +--
.../printer/model_library_format_printer.cc | 6 +-
src/{ => relay}/printer/relay_text_printer.cc | 13 +--
src/{ => relay}/printer/text_printer.cc | 9 +-
src/{ => relay}/printer/text_printer.h | 47 +++-----
src/{ => relay}/printer/tir_text_printer.cc | 28 ++---
src/{ => relay}/printer/tir_text_printer_debug.cc | 4 +-
src/{ => relay}/printer/tir_text_printer_debug.h | 10 +-
src/{ => relay}/printer/tvmscript_printer.cc | 85 +++++++-------
src/relay/transforms/merge_compiler_regions.cc | 2 +-
src/relay/transforms/partition_graph.cc | 2 +-
src/script/printer/printer.cc | 7 --
src/tir/schedule/error.cc | 6 +-
src/tir/transforms/install_debug_spans.cc | 4 +-
tests/python/relay/test_ir_parser.py | 10 +-
.../test_meta_schedule_schedule_rule_mlt.py | 3 +-
tests/python/unittest/test_tir_nodes.py | 126 ---------------------
.../test_tir_transform_lower_warp_memory.py | 9 +-
.../test_tvmscript_printer_syntax_sugar.py | 69 -----------
.../python/unittest/test_tvmscript_printer_tir.py | 42 +++++++
65 files changed, 447 insertions(+), 514 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a8d8b733e..36f7d37962 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -288,7 +288,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/topi/*.cc
src/driver/*.cc
src/parser/*.cc
- src/printer/*.cc
src/support/*.cc
src/script/*.cc
)
@@ -317,6 +316,7 @@ tvm_file_glob(GLOB RELAY_BACKEND_SRCS
)
tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS
src/relay/ir/*.cc
+ src/relay/printer/*.cc
)
tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS
src/relay/qnn/*.cc
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f26e640f6c..4cd357d418 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -446,34 +446,6 @@ class IRModule : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};
-/*!
- * \brief Pretty print a node for debug purposes.
- *
- * \param node The node to be printed.
- * \return The text reperesentation.
- * \note This function does not show version or meta-data.
- * Use AsText if you want to store the text.
- * \sa AsText.
- */
-TVM_DLL String PrettyPrint(const ObjectRef& node);
-
-/*!
- * \brief Render the node as a string in the text format.
- *
- * \param node The node to be rendered.
- * \param show_meta_data Whether to print meta data section.
- * \param annotate An optional callback function for attaching
- * additional comment block to an expr.
- *
- * \note We support a limited set of IR nodes that are part of
- * relay IR and
- *
- * \sa PrettyPrint.
- * \return The text representation.
- */
-TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
- runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
-
namespace attr {
// Following are attributes for IRModule only.
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index febcca5c01..473e629168 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -57,7 +57,6 @@
#define TVM_IR_TRANSFORM_H_
#include <tvm/ir/diagnostic.h>
-#include <tvm/ir/error.h>
#include <tvm/ir/instrument.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/container/array.h>
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
index e94bd2756e..2825bcfc65 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/relay/base.h
@@ -120,6 +120,34 @@ class Id : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
+/*!
+ * \brief Pretty print a node for debug purposes.
+ *
+ * \param node The node to be printed.
+ * \return The text reperesentation.
+ * \note This function does not show version or meta-data.
+ * Use AsText if you want to store the text.
+ * \sa AsText.
+ */
+TVM_DLL String PrettyPrint(const ObjectRef& node);
+
+/*!
+ * \brief Render the node as a string in the text format.
+ *
+ * \param node The node to be rendered.
+ * \param show_meta_data Whether to print meta data section.
+ * \param annotate An optional callback function for attaching
+ * additional comment block to an expr.
+ *
+ * \note We support a limited set of IR nodes that are part of
+ * relay IR and
+ *
+ * \sa PrettyPrint.
+ * \return The text representation.
+ */
+TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true,
+ runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
+
} // namespace relay
} // namespace tvm
diff --git a/include/tvm/ir/error.h b/include/tvm/relay/error.h
similarity index 97%
rename from include/tvm/ir/error.h
rename to include/tvm/relay/error.h
index 6ff61781ac..be34e2b8ae 100644
--- a/include/tvm/ir/error.h
+++ b/include/tvm/relay/error.h
@@ -16,13 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-/*!
- * \file tvm/ir/error.h
- * \brief Utilities for error tracking and reporting.
- */
-#ifndef TVM_IR_ERROR_H_
-#define TVM_IR_ERROR_H_
+#ifndef TVM_RELAY_ERROR_H_
+#define TVM_RELAY_ERROR_H_
#include <tvm/ir/module.h>
#include <tvm/ir/span.h>
@@ -33,6 +28,7 @@
#include <vector>
namespace tvm {
+namespace relay {
/*!
* \brief A wrapper around std::stringstream to build error.
*
@@ -181,5 +177,6 @@ class ErrorReporter {
std::unordered_map<ObjectRef, GlobalVar, ObjectPtrHash, ObjectPtrEqual> node_to_gv_;
};
+} // namespace relay
} // namespace tvm
-#endif // TVM_IR_ERROR_H_
+#endif // TVM_RELAY_ERROR_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 6847a53caa..854050464d 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -57,7 +57,6 @@ using BaseFunc = tvm::BaseFunc;
using BaseFuncNode = tvm::BaseFuncNode;
using GlobalVar = tvm::GlobalVar;
using GlobalVarNode = tvm::GlobalVarNode;
-using tvm::PrettyPrint;
/*!
* \brief Constant tensor, backed by an NDArray on the cpu(0) device.
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 280a1f8a6c..2a295c9da7 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -25,9 +25,9 @@
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_
-#include <tvm/ir/error.h>
#include <tvm/node/functor.h>
#include <tvm/relay/adt.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>
diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h
index 711d8323f1..9d2b6689b2 100644
--- a/include/tvm/relay/pattern_functor.h
+++ b/include/tvm/relay/pattern_functor.h
@@ -25,8 +25,8 @@
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_
-#include <tvm/ir/error.h>
#include <tvm/node/functor.h>
+#include <tvm/relay/error.h>
#include <string>
#include <unordered_map>
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 9e81dd5519..4f63cbecd9 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -27,7 +27,6 @@ from .base import (
Span,
assert_structural_equal,
load_json,
- pretty_print,
save_json,
structural_equal,
structural_hash,
diff --git a/python/tvm/ir/affine_type.py b/python/tvm/ir/affine_type.py
index 8d185ae59a..24126f94b9 100644
--- a/python/tvm/ir/affine_type.py
+++ b/python/tvm/ir/affine_type.py
@@ -32,7 +32,7 @@ class AffineType(Node):
return not self.__eq__(other)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
+ from tvm.relay import pretty_print # pylint: disable=import-outside-toplevel
return pretty_print(self)
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index a1e1d20d88..b84a83d558 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -23,40 +23,9 @@ from tvm.runtime import Object
from . import _ffi_api, json_compact
-def pretty_print(obj: Object) -> None:
- """Pretty print the object."""
- return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member
-
-
class Node(Object):
"""Base class of all IR Nodes, implements astext function."""
- def astext(self, show_meta_data=True, annotate=None):
- """Get the text format of the expression.
-
- Parameters
- ----------
- show_meta_data : bool
- Whether to include meta data section in the text
- if there is meta data.
-
- annotate: Optional[Object->str]
- Optionally annotate function to provide additional
- information in the comment block.
-
- Returns
- -------
- text : str
- The text format of the expression.
-
- Notes
- -----
- The meta data section is necessary to fully parse the text format.
- However, it can contain dumps that are big (e.g constant weights),
- so it can be helpful to skip printing the meta data section.
- """
- return _ffi_api.AsText(self, show_meta_data, annotate)
-
@tvm._ffi.register_object("SourceName")
class SourceName(Object):
diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py
index e16cd5ea9e..52af8407b7 100644
--- a/python/tvm/ir/expr.py
+++ b/python/tvm/ir/expr.py
@@ -17,9 +17,9 @@
"""Common expressions data structures in the IR."""
import tvm._ffi
-from .base import Node
-from . import _ffi_api
from ..runtime import const, convert
+from . import _ffi_api
+from .base import Node
class BaseExpr(Node):
@@ -91,6 +91,34 @@ class GlobalVar(RelayExpr):
"Do not know how to handle GlobalVar.__call__ for types {}".format(arg_types)
)
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ from tvm.relay import astext # pylint: disable=import-outside-toplevel
+
+ return astext(self, show_meta_data, annotate)
+
@tvm._ffi.register_object
class Range(Node):
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index b184c3b0c3..51410049ec 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -287,6 +287,34 @@ class IRModule(Node):
return _ffi_api.Module_WithAttr(self, attr_key, attr_value)
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ from tvm.relay import astext # pylint: disable=import-outside-toplevel
+
+ return astext(self, show_meta_data, annotate)
+
def script(
self,
*,
diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py
index 49ac72b887..70aba97951 100644
--- a/python/tvm/ir/op.py
+++ b/python/tvm/ir/op.py
@@ -17,8 +17,9 @@
# pylint: disable=invalid-name
"""Primitive operators in the TVM IR."""
import tvm._ffi
-from .expr import RelayExpr
+
from . import _ffi_api
+from .expr import RelayExpr
@tvm._ffi.register_object("Op")
@@ -28,6 +29,34 @@ class Op(RelayExpr):
def __init__(self):
raise RuntimeError("Cannot create op, use get instead")
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ from tvm.relay import astext # pylint: disable=import-outside-toplevel
+
+ return astext(self, show_meta_data, annotate)
+
@staticmethod
def get(op_name):
"""Get the Op for a given name
diff --git a/python/tvm/ir/tensor_type.py b/python/tvm/ir/tensor_type.py
index 7313f3c2b4..495e0fe868 100644
--- a/python/tvm/ir/tensor_type.py
+++ b/python/tvm/ir/tensor_type.py
@@ -56,6 +56,6 @@ class TensorType(Type):
return tuple(int(x) for x in self.shape)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
+ from tvm.relay import pretty_print # pylint: disable=import-outside-toplevel
return pretty_print(self)
diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py
index 0f30c39ad4..fc32fe34d6 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -27,12 +27,13 @@ import typing
import tvm
from tvm.micro import get_standalone_crt_dir
+
from .._ffi import get_global_func
from ..contrib import utils
from ..driver import build_module
-from ..relay.backend import executor_factory
-from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
from ..relay import param_dict
+from ..relay.backend import executor_factory
+from ..relay.backend.name_transforms import prefix_generated_name, to_c_variable_style
from ..tir import expr
# This should be kept identical to runtime::symbol::tvm_module_main
@@ -528,7 +529,7 @@ def _write_tir_and_build_operator_memory_map(src_dir, targets, ir_module_by_targ
# TODO(mbs): The device type is not unique, better would be to use target.kind.name
target_device_type = target.get_target_device_type()
ir_mod = ir_module_by_target[target]
- printer = get_global_func("tir.ModelLibraryFormatPrinter")(False, None, False)
+ printer = get_global_func("relay.ir.ModelLibraryFormatPrinter")(False, None, False)
with open(src_dir / f"tir-{target_device_type}.txt", "w") as f:
f.write(printer["print"](ir_mod))
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 97842738e5..5e5d1d5f18 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -29,6 +29,7 @@ from . import adt
from . import prelude
from . import loops
from . import scope_builder
+from .base import pretty_print, astext
from . import transform
from . import analysis
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 323a8f6e5a..8667bfb1df 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -17,15 +17,50 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, unused-import
"""The base node types for the Relay language."""
import os
-import tvm._ffi
+import tvm._ffi
+from tvm.ir import Node as RelayNode
+from tvm.ir import SourceName, Span
from tvm.runtime import Object
-from tvm.ir import SourceName, Span, Node as RelayNode
+from . import _ffi_api
__STD_PATH__ = os.path.join(os.path.dirname(os.path.realpath(__file__)), "std")
+def pretty_print(obj: Object) -> None:
+ """Pretty print the object."""
+ return _ffi_api.PrettyPrint(obj) # type: ignore # pylint: disable=no-member
+
+
+def astext(obj: Object, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ obj : Object
+ The object to be printed.
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ return _ffi_api.AsText(obj, show_meta_data, annotate) # type: ignore # pylint: disable=no-member
+
+
@tvm._ffi.register_func("tvm.relay.std_path")
def _std_path():
return __STD_PATH__
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index 6c29825bc0..6e19cafa74 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -26,6 +26,7 @@ from ... import ir as _ir
from ...ir import make_node
from ...ir.base import Node
from ...runtime import Object
+from ..base import astext, pretty_print
from ..op import get
from . import _ffi as ffi
@@ -47,10 +48,34 @@ class DFPattern(Node):
"""Base class of all Patterns."""
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
-
return pretty_print(self)
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ return astext(self, show_meta_data, annotate)
+
def __call__(self, *args):
args = list(args)
if len(args) == 1 and args[0] is None:
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 7d60e89b59..cb14552ac1 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -30,7 +30,7 @@ from tvm.runtime import ndarray as _nd
from . import _ffi_api
from . import ty as _ty
-from .base import RelayNode
+from .base import RelayNode, astext, pretty_print
# alias relay expr as Expr.
Expr = RelayExpr
@@ -62,10 +62,34 @@ class ExprWithOp(RelayExpr):
return _ffi_api.cast(self, dtype)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
-
return pretty_print(self)
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ return astext(self, show_meta_data, annotate)
+
def __neg__(self):
return _op_make.negative(self)
@@ -719,8 +743,6 @@ class StorageInfo(Node):
self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
-
return pretty_print(self)
@property
@@ -750,6 +772,4 @@ class StaticMemoryPlan(Node):
self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
-
return pretty_print(self)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index ef33564500..dc0636a9b3 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -23,6 +23,7 @@ from tvm.ir import BaseFunc
from tvm.runtime import convert
from . import _ffi_api
+from .base import astext, pretty_print
from .expr import Call
@@ -68,10 +69,34 @@ class Function(BaseFunc):
return Call(self, args, None, None)
def __str__(self):
- from tvm.ir import pretty_print # pylint: disable=import-outside-toplevel
-
return pretty_print(self)
+ def astext(self, show_meta_data=True, annotate=None):
+ """Get the text format of the expression.
+
+ Parameters
+ ----------
+ show_meta_data : bool
+ Whether to include meta data section in the text
+ if there is meta data.
+
+ annotate: Optional[Object->str]
+ Optionally annotate function to provide additional
+ information in the comment block.
+
+ Returns
+ -------
+ text : str
+ The text format of the expression.
+
+ Notes
+ -----
+ The meta data section is necessary to fully parse the text format.
+ However, it can contain dumps that are big (e.g constant weights),
+ so it can be helpful to skip printing the meta data section.
+ """
+ return astext(self, show_meta_data, annotate)
+
@tvm._ffi.register_func("relay.FunctionWithFields")
def FunctionWithFields(
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 82bb698f27..9283727ad4 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -18,4 +18,3 @@
from .parser import ir, ir_module
from .parser import parse as from_source
from .parser import tir
-from .printer import script
diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py
index dc37ea1ff6..01d89dacbf 100644
--- a/python/tvm/script/printer/__init__.py
+++ b/python/tvm/script/printer/__init__.py
@@ -20,4 +20,3 @@ This package provides a set of APIs to print supported TVM IR into TVMScript
in a roundtrippable way.
"""
from . import default
-from .printer import script
diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py
deleted file mode 100644
index 2ce6329dca..0000000000
--- a/python/tvm/script/printer/printer.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""The printer interface"""
-from typing import Optional
-
-from tvm.runtime.object_path import ObjectPath
-
-from . import _ffi_api
-
-
-def script(
- obj,
- indent_space: int = 4,
- print_line_number: bool = False,
- num_context_lines: int = -1,
- path_to_underline: Optional[ObjectPath] = None,
-):
- """Print a TVM IR as a TVMScript text format.
-
- Parameters
- ----------
- obj : object
- An TVM object representing TVM IR
- indent_space : int = 4
- The number of spaces to indent
- print_line_number : bool = False
- Whether to print line number
- num_context_lines : int = -1
- The number of context lines to print. -1 means all lines.
- path_to_underline : Optional[ObjectPath]
- The path to underline in the script.
-
- Returns
- -------
- script : str
- The TVMScript text format
- """
- return _ffi_api.Script( # type: ignore # pylint: disable=no-member
- obj, indent_space, print_line_number, num_context_lines, path_to_underline
- )
diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs
index 03d8a49207..1a0e7aea39 100644
--- a/rust/tvm/src/ir/expr.rs
+++ b/rust/tvm/src/ir/expr.rs
@@ -90,7 +90,7 @@ impl GlobalVar {
// TODO: figure out how to type the last argument runtime::TypedPackedFunc<String(ObjectRef)> annotate)
external! {
- #[name("ir.AsText")]
+ #[name("relay.ir.AsText")]
fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString;
}
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index bfd0a59175..9a669493cc 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -587,7 +587,11 @@ TVM_REGISTER_GLOBAL("transform.OverrideInstruments")
Pass PrintIR(String header, bool show_meta_data) {
auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
- LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
+ if (const auto* f = runtime::Registry::Get("relay.PrintIR")) {
+ (*f)(mod, header, show_meta_data);
+ } else {
+ LOG(INFO) << "PrintIR(" << header << "):\n" << mod;
+ }
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc
index 53c680b722..ef21604d8a 100644
--- a/src/relay/analysis/annotated_region_set.cc
+++ b/src/relay/analysis/annotated_region_set.cc
@@ -19,7 +19,7 @@
#include "annotated_region_set.h"
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <unordered_map>
diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h
index aca4239791..443bd5ec1d 100644
--- a/src/relay/analysis/annotated_region_set.h
+++ b/src/relay/analysis/annotated_region_set.h
@@ -27,9 +27,9 @@
#ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
#define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc
index 65b8516cb1..f7a5e7bf2d 100644
--- a/src/relay/analysis/kind_check.cc
+++ b/src/relay/analysis/kind_check.cc
@@ -31,9 +31,9 @@
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
-#include <tvm/ir/error.h>
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/error.h>
namespace tvm {
namespace relay {
diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc
index 2a90b911b6..05d5b36e36 100644
--- a/src/relay/analysis/match_exhaustion.cc
+++ b/src/relay/analysis/match_exhaustion.cc
@@ -27,8 +27,8 @@
* code correctness, since hitting an unmatched case results in a
* dynamic error unless exhaustiveness is checked in advance.
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/adt.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h
index 3bde1a1e37..7940e347b3 100644
--- a/src/relay/analysis/type_solver.h
+++ b/src/relay/analysis/type_solver.h
@@ -24,8 +24,8 @@
#ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
#define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc
index afa17750d8..a622f96c81 100644
--- a/src/relay/backend/contrib/ethosu/codegen.cc
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -24,9 +24,9 @@
* Codegen.
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
index 42add45b01..6c825a1890 100644
--- a/src/relay/backend/contrib/ethosu/compiler_attrs.cc
+++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -17,9 +17,9 @@
* under the License.
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc
index 571a56ad97..a0e0ac772f 100644
--- a/src/relay/backend/contrib/ethosu/preprocess.cc
+++ b/src/relay/backend/contrib/ethosu/preprocess.cc
@@ -16,9 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/backend/contrib/uma/relay_to_tir.cc b/src/relay/backend/contrib/uma/relay_to_tir.cc
index 8aed694531..ca3ae0ebec 100644
--- a/src/relay/backend/contrib/uma/relay_to_tir.cc
+++ b/src/relay/backend/contrib/uma/relay_to_tir.cc
@@ -23,9 +23,9 @@
* \brief this file contains the target hooks for the Universal Modular Accelerator Interface (UMA).
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 9ba90b9f67..fb23c4cc08 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -25,11 +25,11 @@
#include "compiler.h"
#include <tvm/driver/driver_api.h>
-#include <tvm/ir/error.h>
#include <tvm/parser/parser.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 163ec39901..9160ce0e2e 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
#define TVM_RELAY_BACKEND_VM_COMPILER_H_
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/collage/partition_rule.h b/src/relay/collage/partition_rule.h
index 19e7f3cceb..ca68c9b086 100644
--- a/src/relay/collage/partition_rule.h
+++ b/src/relay/collage/partition_rule.h
@@ -31,7 +31,7 @@
#include <string>
#include <vector>
-#include "../../printer/doc.h"
+#include "../printer/doc.h"
#include "./candidate_partition.h"
#include "./combiner_rule.h"
#include "./sub_graph.h"
diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc
index 5f7b8747a7..5f91302608 100644
--- a/src/relay/ir/base.cc
+++ b/src/relay/ir/base.cc
@@ -51,5 +51,10 @@ TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span
}
});
+TVM_REGISTER_GLOBAL("relay.PrintIR")
+ .set_body_typed([](ObjectRef mod, String header, bool show_metadata) {
+ LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata);
+ });
+
} // namespace relay
} // namespace tvm
diff --git a/src/ir/error.cc b/src/relay/ir/error.cc
similarity index 97%
rename from src/ir/error.cc
rename to src/relay/ir/error.cc
index 26448d0400..940efd91aa 100644
--- a/src/ir/error.cc
+++ b/src/relay/ir/error.cc
@@ -16,13 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-/*!
- * \file ir/error.cc
- * \brief Utilities for error tracking and reporting.
- */
-#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
+#include <tvm/relay/base.h>
+#include <tvm/relay/error.h>
// clang-format off
#include <string>
@@ -31,6 +27,7 @@
// clang-format on
namespace tvm {
+namespace relay {
template <typename T, typename U>
using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
@@ -137,5 +134,5 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node,
}
this->node_to_gv_.insert({node, global});
}
-
+} // namespace relay
} // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index c41eb0f8ad..5c5cd6f4b7 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -23,8 +23,8 @@
*/
#include "transform.h"
-#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/packed_func.h>
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 3c638a59f4..6c88aec8b9 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -24,8 +24,8 @@
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
-#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/op_attr_types.h>
#include <algorithm>
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 6d6d5f70c0..740766172d 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
#define TVM_RELAY_OP_TYPE_RELATIONS_H_
-#include <tvm/ir/error.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/type.h>
#include <string>
diff --git a/src/printer/doc.cc b/src/relay/printer/doc.cc
similarity index 98%
rename from src/printer/doc.cc
rename to src/relay/printer/doc.cc
index b06995fb12..79313c9a58 100644
--- a/src/printer/doc.cc
+++ b/src/relay/printer/doc.cc
@@ -30,9 +30,10 @@
#include <sstream>
#include <vector>
-#include "../support/str_escape.h"
+#include "../../support/str_escape.h"
namespace tvm {
+namespace relay {
/*!
* \brief Represent a piece of text in the doc.
@@ -157,4 +158,5 @@ Doc Doc::Concat(const std::vector<Doc>& vec, const Doc& sep) {
}
return seq;
}
+} // namespace relay
} // namespace tvm
diff --git a/src/printer/doc.h b/src/relay/printer/doc.h
similarity index 97%
rename from src/printer/doc.h
rename to src/relay/printer/doc.h
index dc6ba8952f..36f26d9bd2 100644
--- a/src/printer/doc.h
+++ b/src/relay/printer/doc.h
@@ -23,8 +23,8 @@
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
*/
-#ifndef TVM_PRINTER_DOC_H_
-#define TVM_PRINTER_DOC_H_
+#ifndef TVM_RELAY_PRINTER_DOC_H_
+#define TVM_RELAY_PRINTER_DOC_H_
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
@@ -35,6 +35,7 @@
#include <vector>
namespace tvm {
+namespace relay {
/*!
* \brief Doc atom node for the ADT.
@@ -162,6 +163,6 @@ class Doc {
/*! \brief Internal doc stream. */
std::vector<DocAtom> stream_;
};
-
+} // namespace relay
} // namespace tvm
-#endif // TVM_PRINTER_DOC_H_
+#endif // TVM_RELAY_PRINTER_DOC_H_
diff --git a/src/printer/meta_data.h b/src/relay/printer/meta_data.h
similarity index 95%
rename from src/printer/meta_data.h
rename to src/relay/printer/meta_data.h
index ddf0d78087..2dfd594de7 100644
--- a/src/printer/meta_data.h
+++ b/src/relay/printer/meta_data.h
@@ -16,13 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-
-/*!
- * \file tvm/printer/meta_data.h
- * \brief Meta data context for printers.
- */
-#ifndef TVM_PRINTER_META_DATA_H_
-#define TVM_PRINTER_META_DATA_H_
+#ifndef TVM_RELAY_PRINTER_META_DATA_H_
+#define TVM_RELAY_PRINTER_META_DATA_H_
#include <tvm/node/serialization.h>
@@ -32,6 +27,7 @@
#include "doc.h"
namespace tvm {
+namespace relay {
/*!
* \brief Meta data context for Printers
*
@@ -140,5 +136,6 @@ class TextMetaDataContext {
/*! \brief map from meta data into its string representation */
std::unordered_map<ObjectRef, Doc, ObjectPtrHash, ObjectPtrEqual> meta_repr_;
};
+} // namespace relay
} // namespace tvm
-#endif // TVM_PRINTER_META_DATA_H_
+#endif // TVM_RELAY_PRINTER_META_DATA_H_
diff --git a/src/printer/model_library_format_printer.cc b/src/relay/printer/model_library_format_printer.cc
similarity index 96%
rename from src/printer/model_library_format_printer.cc
rename to src/relay/printer/model_library_format_printer.cc
index 4220aa00f5..76d0f1423d 100644
--- a/src/printer/model_library_format_printer.cc
+++ b/src/relay/printer/model_library_format_printer.cc
@@ -26,7 +26,7 @@
#include "text_printer.h"
namespace tvm {
-namespace printer {
+namespace relay {
class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
public:
@@ -69,7 +69,7 @@ class ModelLibraryFormatPrinter : public ::tvm::runtime::ModuleNode {
TextPrinter text_printer_;
};
-TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter")
+TVM_REGISTER_GLOBAL("relay.ir.ModelLibraryFormatPrinter")
.set_body_typed([](bool show_meta_data,
const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
bool show_warning) {
@@ -77,5 +77,5 @@ TVM_REGISTER_GLOBAL("tir.ModelLibraryFormatPrinter")
make_object<ModelLibraryFormatPrinter>(show_meta_data, annotate, show_warning));
});
-} // namespace printer
+} // namespace relay
} // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc
similarity index 99%
rename from src/printer/relay_text_printer.cc
rename to src/relay/printer/relay_text_printer.cc
index 76cac28b07..cc86f9b564 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/relay/printer/relay_text_printer.cc
@@ -40,10 +40,10 @@
#include <tvm/target/virtual_device.h>
#include <tvm/tir/function.h>
-#include "../ir/attr_functor.h"
-#include "../parser/meta_ref.h"
-#include "../relay/analysis/dependency_graph.h"
-#include "../support/scalars.h"
+#include "../../ir/attr_functor.h"
+#include "../../parser/meta_ref.h"
+#include "../../support/scalars.h"
+#include "../analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
@@ -970,10 +970,5 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) {
return doc;
}
-TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) {
- auto text = AsText(node, false, nullptr);
- return text;
-});
-
} // namespace relay
} // namespace tvm
diff --git a/src/printer/text_printer.cc b/src/relay/printer/text_printer.cc
similarity index 95%
rename from src/printer/text_printer.cc
rename to src/relay/printer/text_printer.cc
index 4d4113fef6..f51f7c3dfa 100644
--- a/src/printer/text_printer.cc
+++ b/src/relay/printer/text_printer.cc
@@ -23,7 +23,7 @@
* that can be parsed by a parser.
*/
-#include "text_printer.h"
+#include "./text_printer.h"
#include <tvm/tir/function.h>
@@ -31,6 +31,7 @@
#include <string>
namespace tvm {
+namespace relay {
static const char* kSemVer = "0.0.5";
@@ -124,8 +125,8 @@ String AsText(const ObjectRef& node, bool show_meta_data,
return doc.str();
}
-TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint);
-
-TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText);
+TVM_REGISTER_GLOBAL("relay.ir.PrettyPrint").set_body_typed(PrettyPrint);
+TVM_REGISTER_GLOBAL("relay.ir.AsText").set_body_typed(AsText);
+} // namespace relay
} // namespace tvm
diff --git a/src/printer/text_printer.h b/src/relay/printer/text_printer.h
similarity index 95%
rename from src/printer/text_printer.h
rename to src/relay/printer/text_printer.h
index 925c2ebf49..707bbec5ad 100644
--- a/src/printer/text_printer.h
+++ b/src/relay/printer/text_printer.h
@@ -23,8 +23,8 @@
* that can be parsed by a parser.
*/
-#ifndef TVM_PRINTER_TEXT_PRINTER_H_
-#define TVM_PRINTER_TEXT_PRINTER_H_
+#ifndef TVM_RELAY_PRINTER_TEXT_PRINTER_H_
+#define TVM_RELAY_PRINTER_TEXT_PRINTER_H_
#include <tvm/ir/module.h>
#include <tvm/ir/type_functor.h>
@@ -41,19 +41,16 @@
#include <unordered_set>
#include <vector>
-#include "../ir/attr_functor.h"
-#include "../relay/analysis/dependency_graph.h"
+#include "../../ir/attr_functor.h"
+#include "../analysis/dependency_graph.h"
#include "doc.h"
#include "meta_data.h"
-#include "text_printer.h"
-
-namespace tvm {
-class TextPrinter;
-} // namespace tvm
namespace tvm {
namespace relay {
+class TextPrinter;
+
class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
public PatternFunctor<Doc(const Pattern&)>,
public TypeFunctor<Doc(const Type&)>,
@@ -227,14 +224,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
DependencyGraph dg_;
class AttrPrinter;
friend class AttrPrinter;
- friend class tvm::TextPrinter;
+ friend class tvm::relay::TextPrinter;
};
-} // namespace relay
-} // namespace tvm
-
-namespace tvm {
-namespace tir {
+using namespace ::tvm::tir;
/*!
* \brief Meta node collector
@@ -274,7 +267,7 @@ class MetaCollector : public StmtExprVisitor {
};
class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
- public ExprFunctor<Doc(const PrimExpr&)>,
+ public tir::ExprFunctor<Doc(const PrimExpr&)>,
public TypeFunctor<Doc(const Type&)> {
public:
explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
@@ -298,7 +291,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitExpr_(const FloatImmNode* op) override;
Doc VisitExpr_(const StringImmNode* op) override;
Doc VisitExpr_(const CastNode* op) override;
- Doc VisitExpr_(const VarNode* op) override;
+ Doc VisitExpr_(const tir::VarNode* op) override;
Doc VisitExpr_(const AddNode* op) override;
Doc VisitExpr_(const SubNode* op) override;
Doc VisitExpr_(const MulNode* op) override;
@@ -323,8 +316,8 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitExpr_(const LoadNode* op) override;
Doc VisitExpr_(const RampNode* op) override;
Doc VisitExpr_(const BroadcastNode* op) override;
- Doc VisitExpr_(const LetNode* op) override;
- Doc VisitExpr_(const CallNode* op) override;
+ Doc VisitExpr_(const tir::LetNode* op) override;
+ Doc VisitExpr_(const tir::CallNode* op) override;
Doc VisitExpr_(const ShuffleNode* op) override;
Doc VisitExpr_(const ReduceNode* op) override;
Doc VisitExprDefault_(const Object* op) override;
@@ -357,7 +350,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief meta collector */
MetaCollector meta_collector_;
/*! \brief Map from Var to Doc */
- std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
+ std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
@@ -365,7 +358,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
- friend class tvm::TextPrinter;
+ friend class TextPrinter;
Doc VisitType_(const PrimTypeNode* node) override;
Doc VisitType_(const PointerTypeNode* node) override;
@@ -396,7 +389,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
template <typename T>
static Doc PrintConstScalar(DataType dtype, const T& data);
Doc GetUniqueName(std::string prefix);
- Doc AllocVar(const Var& var);
+ Doc AllocVar(const tir::Var& var);
Doc AllocConst(const AllocateConst& var);
Doc AllocBuf(const Buffer& buffer);
Doc AllocProducer(const DataProducer& buffer);
@@ -412,11 +405,6 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
runtime::TypedPackedFunc<std::string(Stmt)> annotate);
-} // namespace tir
-} // namespace tvm
-
-namespace tvm {
-
class TextPrinter {
public:
explicit TextPrinter(bool show_meta_data,
@@ -441,7 +429,7 @@ class TextPrinter {
/*! \brief Relay Text Printer */
relay::RelayTextPrinter relay_text_printer_;
/*! \brief TIR Text Printer */
- tir::TIRTextPrinter tir_text_printer_;
+ TIRTextPrinter tir_text_printer_;
bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); }
@@ -472,6 +460,7 @@ class TextPrinter {
Doc PrintMod(const IRModule& mod);
};
+} // namespace relay
} // namespace tvm
-#endif // TVM_PRINTER_TEXT_PRINTER_H_
+#endif // TVM_RELAY_PRINTER_TEXT_PRINTER_H_
diff --git a/src/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc
similarity index 97%
rename from src/printer/tir_text_printer.cc
rename to src/relay/printer/tir_text_printer.cc
index 4d74cc6d5a..eb089bd0d7 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/relay/printer/tir_text_printer.cc
@@ -36,13 +36,13 @@
#include <algorithm>
#include <string>
-#include "../tir/transforms/ir_utils.h"
+#include "../../tir/transforms/ir_utils.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
-namespace tir {
+namespace relay {
Doc TIRTextPrinter::Print(const ObjectRef& node) {
if (!node.defined()) return Doc::Text("(nullptr)");
@@ -93,9 +93,9 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
memo_buf_.clear();
// ordered vars associated with buffers, for consistent printing
- std::vector<Var> buffer_vars_ordered;
+ std::vector<tir::Var> buffer_vars_ordered;
- for (Var v : op->params) {
+ for (tir::Var v : op->params) {
auto buffer_map_find = op->buffer_map.find(v);
if (buffer_map_find != op->buffer_map.end()) {
auto map_data = *buffer_map_find;
@@ -132,7 +132,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
if (memo_buf_.size() != 0) {
Doc buffer_doc;
std::vector<Doc> buffer_docs;
- for (const Var& v : buffer_vars_ordered) {
+ for (const tir::Var& v : buffer_vars_ordered) {
const Buffer buf = op->buffer_map[v];
buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
}
@@ -144,7 +144,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
if (op->buffer_map.size() != 0) {
// print buffer_map
std::vector<Doc> buffer_map_doc;
- for (const Var& v : buffer_vars_ordered) {
+ for (const tir::Var& v : buffer_vars_ordered) {
const Buffer buf = op->buffer_map[v];
buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
}
@@ -302,9 +302,9 @@ Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
return doc;
}
-Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
- const Var& var = GetRef<Var>(op);
- return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+Doc TIRTextPrinter::VisitExpr_(const tir::VarNode* op) {
+ const tir::Var& var = GetRef<tir::Var>(op);
+ return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
}
#define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \
@@ -401,13 +401,13 @@ Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
return doc;
}
-Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+Doc TIRTextPrinter::VisitExpr_(const tir::LetNode* op) {
Doc doc;
doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
return doc;
}
-Doc TIRTextPrinter::VisitExpr_(const CallNode* op) {
+Doc TIRTextPrinter::VisitExpr_(const tir::CallNode* op) {
Doc doc;
std::vector<Doc> func_args;
if (auto* ptr_op = op->op.as<OpNode>()) {
@@ -771,7 +771,7 @@ Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
return Doc::Text(unique_prefix);
}
-Doc TIRTextPrinter::AllocVar(const Var& var) {
+Doc TIRTextPrinter::AllocVar(const tir::Var& var) {
const auto& it = memo_var_.find(var);
if (it != memo_var_.end()) {
return it->second;
@@ -831,7 +831,7 @@ Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
return doc;
}
-bool TIRTextPrinter::GetVarName(Var v, std::string* s) {
+bool TIRTextPrinter::GetVarName(tir::Var v, std::string* s) {
auto it = memo_var_.find(v);
if (it == memo_var_.end()) {
return false;
@@ -841,5 +841,5 @@ bool TIRTextPrinter::GetVarName(Var v, std::string* s) {
return true;
}
-} // namespace tir
+} // namespace relay
} // namespace tvm
diff --git a/src/printer/tir_text_printer_debug.cc b/src/relay/printer/tir_text_printer_debug.cc
similarity index 98%
rename from src/printer/tir_text_printer_debug.cc
rename to src/relay/printer/tir_text_printer_debug.cc
index 6c29558f72..914d8877d2 100644
--- a/src/printer/tir_text_printer_debug.cc
+++ b/src/relay/printer/tir_text_printer_debug.cc
@@ -29,7 +29,7 @@
#include <string>
namespace tvm {
-namespace tir {
+namespace relay {
std::optional<std::string> span_text(const Span& span) {
if (!span.defined()) {
@@ -93,5 +93,5 @@ Doc TIRTextPrinterDebug::VisitExpr(const PrimExpr& e) {
return TIRTextPrinter::VisitExpr(e);
}
-} // namespace tir
+} // namespace relay
} // namespace tvm
diff --git a/src/printer/tir_text_printer_debug.h b/src/relay/printer/tir_text_printer_debug.h
similarity index 90%
rename from src/printer/tir_text_printer_debug.h
rename to src/relay/printer/tir_text_printer_debug.h
index d0046034cf..f7cb7a6554 100644
--- a/src/printer/tir_text_printer_debug.h
+++ b/src/relay/printer/tir_text_printer_debug.h
@@ -23,8 +23,8 @@
* that can be parsed by a parser.
*/
-#ifndef TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
-#define TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#ifndef TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#define TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
#include <tuple>
#include <vector>
@@ -32,7 +32,7 @@
#include "text_printer.h"
namespace tvm {
-namespace tir {
+namespace relay {
class TIRTextPrinterDebug : public TIRTextPrinter {
public:
@@ -64,7 +64,7 @@ class TIRTextPrinterDebug : public TIRTextPrinter {
std::vector<std::tuple<const PrimExprNode*, size_t>> exprs_by_line_;
};
-} // namespace tir
+} // namespace relay
} // namespace tvm
-#endif // TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
+#endif // TVM_RELAY_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_
diff --git a/src/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc
similarity index 96%
rename from src/printer/tvmscript_printer.cc
rename to src/relay/printer/tvmscript_printer.cc
index c578bc53d3..0966110950 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/relay/printer/tvmscript_printer.cc
@@ -39,13 +39,15 @@
#include <algorithm>
#include <utility>
-#include "../tir/transforms/ir_utils.h"
+#include "../../tir/transforms/ir_utils.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
namespace tvm {
-namespace tir {
+namespace relay {
+
+using namespace tvm::tir;
enum class ExprPrecedence : int {
/*! \brief Identity(e.g., IntImm, Var) and function call(e.g., floordiv, min) */
@@ -77,14 +79,14 @@ enum class ExprPrecedence : int {
*/
class BufferUsageFinder : public StmtExprVisitor {
public:
- static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
+ static Map<tir::Var, Array<Buffer>> FindUsage(Map<tir::Var, Array<Buffer>> usage, Stmt body) {
BufferUsageFinder visitor(std::move(usage));
visitor.VisitStmt(body);
return std::move(visitor.usage_);
}
- void VisitExpr_(const VarNode* op) final {
- Var var = GetRef<Var>(op);
+ void VisitExpr_(const tir::VarNode* op) final {
+ tir::Var var = GetRef<tir::Var>(op);
if (!usage_.count(var)) {
usage_.Set(var, {});
}
@@ -107,7 +109,7 @@ class BufferUsageFinder : public StmtExprVisitor {
}
private:
- explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}
+ explicit BufferUsageFinder(Map<tir::Var, Array<Buffer>> usage) : usage_(usage) {}
void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
@@ -124,7 +126,7 @@ class BufferUsageFinder : public StmtExprVisitor {
}
// The search result.
- Map<Var, Array<Buffer>> usage_;
+ Map<tir::Var, Array<Buffer>> usage_;
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
@@ -139,7 +141,7 @@ class BufferUsageFinder : public StmtExprVisitor {
* subexpression to decide whether or not parentheses is needed.
*/
class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
- public ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
+ public tir::ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
public TypeFunctor<Doc(const Type&)> {
public:
explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta,
@@ -167,20 +169,20 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief meta collector */
- MetaCollector meta_collector_;
+ relay::MetaCollector meta_collector_;
/*! \brief map from Function to GlobalVar */
std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
/*! \brief var collector (var defined by For/Loop/Block) */
- std::unordered_set<const VarNode*> var_not_in_headers_;
+ std::unordered_set<const tir::VarNode*> var_not_in_headers_;
/*!
* \brief buffer collector
* (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion)
*/
std::unordered_set<const BufferNode*> buf_not_in_headers_;
/*! \brief Map from Var to thread env name */
- std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
+ std::unordered_map<tir::Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
/*! \brief Map from Var to Doc */
- std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
+ std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Declaration Doc */
@@ -194,7 +196,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
/*! \brief loop stack without annotations */
std::vector<For> simple_loop_stack_;
/*! \brief the maps from loop_vars to the loops */
- std::unordered_map<const VarNode*, For> loop_var_map_;
+ std::unordered_map<const tir::VarNode*, For> loop_var_map_;
/*!
* \brief simple block vars remap from loop vars
* simple_remap requires:
@@ -210,12 +212,12 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
* LetStmt or Allocate that generates their data pointer, rather
* than in the header.
*/
- Map<Var, Array<Buffer>> buffer_var_usage_;
+ Map<tir::Var, Array<Buffer>> buffer_var_usage_;
/*! \brief Analyzer to simplify some expressions. */
arith::Analyzer ana_;
Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
- Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
+ Doc VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override;
@@ -243,8 +245,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override;
- Doc VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) override;
- Doc VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) override;
+ Doc VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) override;
+ Doc VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override;
@@ -297,9 +299,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
Doc GetUniqueName(std::string prefix);
- Doc AllocVar(const Var& var);
+ Doc AllocVar(const tir::Var& var);
Doc AllocBuf(const Buffer& buffer);
- void TryDeallocVar(const Var& var);
+ void TryDeallocVar(const tir::Var& var);
bool ContainsOptionalInfo(const Stmt& stmt);
/*!
* \brief Check if a buffer declaration satisfies:
@@ -338,7 +340,9 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
* \return A boolean indicating whether the input loop depends on previous loops
*/
bool DependOnPrevLoops(const ForNode* for_op) {
- auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
+ auto f_check = [&var_map = this->loop_var_map_](const tir::VarNode* v) {
+ return var_map.count(v);
+ };
return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
}
@@ -494,7 +498,7 @@ Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
return Doc::Text(unique_prefix);
}
-Doc TVMScriptPrinter::AllocVar(const Var& var) {
+Doc TVMScriptPrinter::AllocVar(const tir::Var& var) {
const auto& it = memo_var_.find(var);
if (it != memo_var_.end()) {
return it->second;
@@ -522,8 +526,8 @@ Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
if (!buf->strides.empty()) {
doc << ", strides=" << Print(buf->strides);
}
- if (buf->elem_offset->IsInstance<VarNode>()) {
- Var elem_offset = Downcast<Var>(buf->elem_offset);
+ if (buf->elem_offset->IsInstance<tir::VarNode>()) {
+ tir::Var elem_offset = Downcast<tir::Var>(buf->elem_offset);
if (memo_var_.find(elem_offset) != memo_var_.end()) {
doc << ", elem_offset=" << Print(buf->elem_offset);
} else {
@@ -585,7 +589,7 @@ bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) {
* \brief Try to dealloc vars out of space and leave the index to coming vars.
* \note It is not a necessary step.
*/
-void TVMScriptPrinter::TryDeallocVar(const Var& var) {
+void TVMScriptPrinter::TryDeallocVar(const tir::Var& var) {
auto it = memo_var_.find(var);
ICHECK(it != memo_var_.end());
std::string print_name = it->second.str();
@@ -695,7 +699,7 @@ Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
int n_var = static_cast<int>(op->rhs.size());
doc << tir_prefix_ << ".comm_reducer(lambda ";
- for (const Var& v_lhs : op->lhs) {
+ for (const tir::Var& v_lhs : op->lhs) {
doc << Print(v_lhs) << ", ";
}
for (int i = 0; i < n_var; ++i) {
@@ -789,10 +793,10 @@ Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precede
return doc;
}
-Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
- const Var& var = GetRef<Var>(op);
- return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+ const tir::Var& var = GetRef<tir::Var>(op);
+ return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
}
bool WillPrintConstScalar(const PrimExpr& expr) {
@@ -938,7 +942,7 @@ Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_pr
return doc;
}
-Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", "
@@ -946,7 +950,7 @@ Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_preceden
return doc;
}
-Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) {
+Doc TVMScriptPrinter::VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
if (auto* ptr_op = op->op.as<OpNode>()) {
@@ -1090,7 +1094,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
namespace {
bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
- const Var& buffer_var = allocate->buffer_var;
+ const tir::Var& buffer_var = allocate->buffer_var;
const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
if (!decl_buffer) {
return false;
@@ -1468,8 +1472,8 @@ Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) {
auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var,
const PrimExpr& value) -> bool {
if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false;
- if (!value->IsInstance<VarNode>()) return false;
- const Var& var = Downcast<Var>(value);
+ if (!value->IsInstance<tir::VarNode>()) return false;
+ const tir::Var& var = Downcast<tir::Var>(value);
auto it = loop_var_map_.find(var.get());
return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) &&
expr_equal(it->second->extent, iter_var->dom->extent);
@@ -1763,7 +1767,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
}
// print var declaration
Doc header_var;
- std::vector<const VarNode*> vars;
+ std::vector<const tir::VarNode*> vars;
for (const auto& it : memo_var_) {
if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) {
vars.push_back(it.first.get());
@@ -1777,20 +1781,21 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
}
}
if (!vars.empty()) {
- std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) {
- return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
+ std::sort(vars.begin(), vars.end(), [&](const tir::VarNode* a, const tir::VarNode* b) {
+ return memo_var_[GetRef<tir::Var>(a)].str() < memo_var_[GetRef<tir::Var>(b)].str();
});
for (const auto& var : vars) {
- auto type = GetRef<Var>(var)->type_annotation;
+ auto type = GetRef<tir::Var>(var)->type_annotation;
if (auto* ptr_type = type.as<PointerTypeNode>()) {
auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
ICHECK(prim_type);
- header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_
+ header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
<< ".buffer_var(";
header_var << PrintDType(prim_type->dtype) << ", "
<< Doc::StrLiteral(ptr_type->storage_scope) << ")";
} else {
- header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_ << ".var(";
+ header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
+ << ".var(";
header_var << PrintDType(var->dtype) << ")";
}
}
@@ -2013,5 +2018,5 @@ String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix,
TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
-} // namespace tir
+} // namespace relay
} // namespace tvm
diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc
index d18c17e63c..d70c7480e9 100644
--- a/src/relay/transforms/merge_compiler_regions.cc
+++ b/src/relay/transforms/merge_compiler_regions.cc
@@ -30,9 +30,9 @@
* as external functions.
*/
-#include <tvm/ir/error.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index f6cdf6d1ca..32ca2878fd 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -29,10 +29,10 @@
* external functions, and they will use the provided compiler for codegen.
*/
-#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc
index 9ebdcb1e99..878b380a37 100644
--- a/src/script/printer/printer.cc
+++ b/src/script/printer/printer.cc
@@ -23,18 +23,11 @@ namespace tvm {
namespace script {
namespace printer {
-String Script(ObjectRef obj, int indent_spaces, bool print_line_numbers, int num_context_lines,
- Optional<ObjectPath> path_to_underline) {
- return DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root()), indent_spaces,
- print_line_numbers, num_context_lines, path_to_underline);
-}
-
Default* Default::Instance() {
static Default inst;
return &inst;
}
-TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script);
TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix")
.set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; });
TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType")
diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc
index 55d751c331..1aae0202ac 100644
--- a/src/tir/schedule/error.cc
+++ b/src/tir/schedule/error.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "../../printer/text_printer.h"
#include "./utils.h"
namespace tvm {
@@ -52,10 +51,11 @@ String ScheduleError::RenderReport(const String& primitive) const {
}
return it->second;
});
-
+ const auto* f = runtime::Registry::Get("script.AsTVMScriptWithDiagnostic");
+ ICHECK(f != nullptr);
os << "ScheduleError: An error occurred in the schedule primitive '" << primitive
<< "'.\n\nThe IR with diagnostic is:\n"
- << AsTVMScriptWithDiagnostic(mod, "T", false, annotate);
+ << ((*f)(mod, "T", false, annotate).operator String());
// print error message
os << "Error message: " << msg;
diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc
index bc9002ee84..c97070e1bf 100644
--- a/src/tir/transforms/install_debug_spans.cc
+++ b/src/tir/transforms/install_debug_spans.cc
@@ -30,7 +30,7 @@
#include <string>
#include <utility>
-#include "../../printer/tir_text_printer_debug.h"
+#include "../../relay/printer/tir_text_printer_debug.h"
namespace tvm {
namespace tir {
@@ -42,7 +42,7 @@ Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt)
DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) {
// Determine the line that each stmt/expr will be printed on
- tvm::tir::TIRTextPrinterDebug printer(false);
+ tvm::relay::TIRTextPrinterDebug printer(false);
// Fill in the stmts and exprs' line info
auto result = printer.Print(stmt).str();
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index 5ea6d7e5de..08fa01f0b3 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -14,15 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Union
+
import numpy as np
import pytest
-
import tvm
-import tvm.testing
-from tvm import relay
import tvm.relay.testing
+import tvm.testing
from numpy import isclose
-from typing import Union
+from tvm import relay
SEMVER = '#[version = "0.0.5"]\n'
@@ -74,7 +74,7 @@ def graph_equal(lhs, rhs):
def roundtrip_expr(expr):
- text = tvm.relay.Expr.astext(expr, show_meta_data=False)
+ text = expr.astext()
x = tvm.parser.parse_expr(text)
assert_graph_equal(x, expr)
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index bb96022794..f40d942749 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -885,5 +885,4 @@ def test_max_pool_blocked():
if __name__ == "__main__":
- # tvm.testing.main()
- test_cache_read_specify_consumer()
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index d4ae84a556..2806c7b2fc 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -343,7 +343,6 @@ def test_prim_func():
func = tvm.tir.PrimFunc([x, y, b], stmt)
# make sure we can print
- func.astext()
assert func.buffer_map[func.params[2]].same_as(b)
assert len(func.buffer_map) == 1
@@ -399,130 +398,5 @@ def test_intimm_cond():
assert x == 1
-def test_block_blockrealize():
- x = tvm.tir.Var("x", "int32")
- y = tvm.tir.Var("y", "int32")
- vx = tvm.tir.IterVar((16, 16), "vx", 0)
- vx_var = vx.var
- vy = tvm.tir.IterVar((16, 16), "vy", 2)
- vy_var = vy.var
- A = tvm.tir.decl_buffer((16), "float32")
- B = tvm.tir.decl_buffer((16, 16), "float32")
- alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32")
- match_buffer = tvm.tir.decl_buffer((16, 16), "float32")
- init_body = tvm.tir.BufferStore(A, 0.0, [vx_var])
- body = tvm.tir.BufferStore(
- A,
- tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]),
- [vx_var],
- )
- reads = [
- tvm.tir.BufferRegion(
- B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)]
- )
- ]
- writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])]
- block_match_buffer = tvm.tir.MatchBufferRegion(
- match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)])
- )
-
- block = tvm.tir.Block(
- [vx, vy],
- reads,
- writes,
- "block",
- body,
- init=init_body,
- alloc_buffers=[alloc_buffer],
- match_buffers=[block_match_buffer],
- annotations={"attr_key": "attr_value"},
- )
-
- # Checking Block
- assert isinstance(block, tvm.tir.Block)
- # Checking iter_vars
- assert block.iter_vars[0] == vx
- assert block.iter_vars[1] == vy
- # Checking reads/writes region
- assert isinstance(block.reads[0], tvm.tir.BufferRegion)
- assert block.reads[0].buffer == B
- assert block.reads[0].region[0].min == vx_var
- assert block.reads[0].region[1].min == vy_var
- assert isinstance(block.writes[0], tvm.tir.BufferRegion)
- assert block.writes[0].buffer == A
- assert block.writes[0].region[0].min == vx_var
- assert block.writes[0].region[0].extent == 1
- # Checking name_hint
- assert block.name_hint == "block"
- # Checking body
- assert block.body == body
- # Checking init
- assert block.init == init_body
- # Checking alloc_buffers
- assert block.alloc_buffers[0] == alloc_buffer
- # Checking match_buffers
- assert block.match_buffers[0].buffer == match_buffer
- assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion)
- assert block.match_buffers[0].source.buffer == B
- assert block.match_buffers[0].source.region[0].min == 0
- assert block.match_buffers[0].source.region[0].extent == 16
-
- # Checking BlockRealize
- block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block)
- assert isinstance(block_realize, tvm.tir.BlockRealize)
- assert block_realize.iter_values[0] == x
- assert block_realize.iter_values[1] == y
- assert block_realize.predicate == tvm.tir.const(True, "bool")
- assert block_realize.block == block
-
- # make sure we can print using ReprPrinter
- str(block)
- str(block_realize)
- # make sure we can print using TIRTextPrinter
- func = tvm.tir.PrimFunc([], block_realize)
- output = func.astext()
- assert output.find("meta[tir.BlockRealise]") == -1
- assert output.find("bind") != -1
- assert output.find("reads") != -1
- assert output.find("writes") != -1
- assert output.find("alloc_buffer") != -1
- assert output.find("match_buffer") != -1
- assert output.find("attr") != -1
- assert output.find("with init()") != -1
-
-
-def test_tir_allocate():
- dtype = "int8"
- storage_scope = "global"
- ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
- a = te.var("buffer", ptype)
- allocate = tvm.tir.Allocate(
- buffer_var=a,
- dtype=dtype,
- extents=[2, 2],
- condition=tvm.get_global_func("tir.const_true")(dtype, None),
- body=tvm.tir.Evaluate(2 + 1),
- annotations={
- "attr1": "foo",
- "attr2": "bar",
- },
- )
- assert allocate.buffer_var == a
- assert allocate.dtype == "int8"
- assert list(allocate.extents) == [2, 2]
- assert allocate.annotations["attr1"] == "foo"
- assert allocate.annotations["attr2"] == "bar"
-
- # make sure we can print using TIRTextPrinter
- func = tvm.tir.PrimFunc([], allocate)
- output = func.astext()
- assert (
- output.find(
- 'allocate(buffer: Pointer(global int8), int8, [2, 2]), storage_scope = global, annotations = {"attr2": "bar", "attr1": "foo"})'
- )
- != -1
- )
-
-
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 48af3ebaf5..d4abc26bb2 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -14,14 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import numpy as np
+import pytest
import tvm
+import tvm.testing
from tvm import te
from tvm.contrib.nvcc import have_fp16
-import numpy as np
-import tvm.testing
-import pytest
-
@tvm.testing.requires_cuda
def test_lower_warp_memory_local_scope():
@@ -320,7 +319,7 @@ def test_lower_warp_memory_same_thread():
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["f_kernel0"]
- assert "tvm_warp_shuffle" not in fdevice.astext()
+ assert "tvm_warp_shuffle" not in fdevice.script()
@tvm.testing.requires_cuda
diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
deleted file mode 100644
index 1bccb8188c..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-import tvm.testing
-from tvm.script.parser import tir as T
-from tvm.script import script
-
-
-def _test(obj, expected: str):
- assert script(obj).strip() == expected.strip()
-
-
-def test_remap():
- @T.prim_func
- def block_with_remap_implicitly():
- for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
- with T.block("update"):
- v0 = T.axis.spatial(128, i0 + 1)
- v1 = T.axis.spatial(128, i1)
- v2 = T.axis.reduce(128, i2)
- v3 = T.axis.spatial(128, i3 - 1)
- v4 = T.axis.reduce(128, i4)
- v5 = T.axis.spatial(128, i5)
- pass
-
- @T.prim_func
- def block_with_remap_explicitly():
- for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
- with T.block("update"):
- v0 = T.axis.spatial(128, i0 + 1)
- v1, v2 = T.axis.remap("SR", [i1, i2])
- v3 = T.axis.spatial(128, i3 - 1)
- v4, v5 = T.axis.remap("RS", [i4, i5])
- pass
-
- expected_output = """@T.prim_func
-def main():
- with T.block("root"):
- T.reads()
- T.writes()
- for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
- with T.block("update"):
- v0 = T.axis.spatial(128, i0 + 1)
- v1, v2 = T.axis.remap("SR", [i1, i2])
- v3 = T.axis.spatial(128, i3 - 1)
- v4, v5 = T.axis.remap("RS", [i4, i5])
- T.reads()
- T.writes()
- T.evaluate(0)"""
- _test(block_with_remap_implicitly, expected_output)
- _test(block_with_remap_explicitly, expected_output)
-
-
-if __name__ == "__main__":
- tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
index 9c15fbc889..d62a1cd12c 100644
--- a/tests/python/unittest/test_tvmscript_printer_tir.py
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -598,6 +598,47 @@ def test_tuple_type():
_assert_print(obj, "T.Tuple(T.float32, T.int32)")
+def test_remap():
+ from tvm.script import tir as T
+
+ @T.prim_func
+ def block_with_remap_implicitly():
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1 = T.axis.spatial(128, i1)
+ v2 = T.axis.reduce(128, i2)
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4 = T.axis.reduce(128, i4)
+ v5 = T.axis.spatial(128, i5)
+
+ @T.prim_func
+ def block_with_remap_explicitly():
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1, v2 = T.axis.remap("SR", [i1, i2])
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4, v5 = T.axis.remap("RS", [i4, i5])
+
+ expected_output = """@T.prim_func
+def main():
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
+ with T.block("update"):
+ v0 = T.axis.spatial(128, i0 + 1)
+ v1, v2 = T.axis.remap("SR", [i1, i2])
+ v3 = T.axis.spatial(128, i3 - 1)
+ v4, v5 = T.axis.remap("RS", [i4, i5])
+ T.reads()
+ T.writes()
+ T.evaluate(0)"""
+ _assert_print(block_with_remap_explicitly, expected_output)
+ _assert_print(block_with_remap_implicitly, expected_output)
+
+
if __name__ == "__main__":
test_prim_func()
test_block_realize()
@@ -639,3 +680,4 @@ if __name__ == "__main__":
test_prim_type()
test_pointer_type()
test_tuple_type()
+ test_remap()