You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/01/14 03:36:58 UTC
[tvm] branch main updated: [TVMScript] IR Fragment Printing (#13742)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c452e6966c [TVMScript] IR Fragment Printing (#13742)
c452e6966c is described below
commit c452e6966c33047512155f42f63aef4e0586d129
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Jan 13 19:36:47 2023 -0800
[TVMScript] IR Fragment Printing (#13742)
This PR introduces support for TIR fragment printing.
Fragment printing makes it possible to print TIR fragments in the text
format, consistent with TVMScript PrimFunc/IRModule printing.
This PR still preserves the legacy ReprPrinter format by introducing an
API `LegacyTIRPrint` for TIR PrimExpr. This method is used in
AutoScheduler and TIR CSE for full backward compatibility.
---
include/tvm/runtime/data_type.h | 2 +-
include/tvm/script/printer/ir_docsifier.h | 12 +-
include/tvm/script/printer/printer.h | 17 +-
include/tvm/tir/expr.h | 3 +
python/tvm/script/ir_builder/tir/ir.py | 8 +-
python/tvm/script/printer/__init__.py | 1 +
python/tvm/script/printer/default.py | 83 +++
python/tvm/script/printer/printer.py | 14 +-
src/auto_scheduler/compute_dag.cc | 27 +-
src/ir/expr.cc | 35 --
src/ir/type.cc | 28 -
.../printer/doc_printer/python_doc_printer.cc | 18 +-
src/script/printer/ir_docsifier.cc | 11 +-
src/script/printer/printer.cc | 22 +-
src/script/printer/tir/block.cc | 42 +-
src/script/printer/tir/buffer.cc | 144 +++--
src/script/printer/tir/expr.cc | 94 +--
src/script/printer/tir/for_loop.cc | 16 +-
src/script/printer/tir/function.cc | 17 +-
src/script/printer/tir/ir.cc | 42 +-
src/script/printer/tir/stmt.cc | 53 +-
src/script/printer/tir/utils.h | 55 +-
src/tir/ir/buffer.cc | 6 -
src/tir/ir/expr.cc | 351 ------------
src/tir/ir/function.cc | 15 -
src/tir/ir/index_map.cc | 3 +-
src/tir/ir/legacy_printer.cc | 270 +++++++++
src/tir/ir/stmt.cc | 403 -------------
src/tir/transforms/common_subexpr_elim.cc | 4 +-
src/tir/transforms/common_subexpr_elim_tools.cc | 4 +-
tests/cpp/expr_test.cc | 2 +-
tests/python/driver/tvmc/test_shape_parser.py | 5 +-
tests/python/relay/aot/test_c_device_api.py | 33 +-
tests/python/relay/aot/test_crt_aot.py | 35 +-
.../test_tvmscript_printer_python_doc_printer.py | 3 +-
.../python/unittest/test_tvmscript_printer_tir.py | 638 +++++++++++++++++++++
.../unittest/test_tvmscript_printer_underlining.py | 12 +-
vta/python/vta/transform.py | 2 +-
38 files changed, 1425 insertions(+), 1105 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 089147798a..f52e95c756 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -348,7 +348,7 @@ inline std::string DLDataType2String(DLDataType t) {
inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle void type
- if (s.length() == 0) {
+ if (s.length() == 0 || s == "void") {
t = DataType::Void();
return t;
}
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h
index e97ddc0234..e426946b56 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -126,10 +126,6 @@ class IRDocsifierNode : public Object {
/*! \brief The name of the variable */
Optional<String> name;
};
- /*!
- * \brief This map connects IR dispatch token to the name of identifier.
- */
- Map<String, String> ir_prefix;
/*!
* \brief The stack of frames.
* \sa FrameNode
@@ -152,7 +148,6 @@ class IRDocsifierNode : public Object {
std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("ir_prefix", &ir_prefix);
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
v->Visit("mod", &mod);
@@ -236,11 +231,8 @@ class IRDocsifierNode : public Object {
class IRDocsifier : public ObjectRef {
public:
using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>;
- /*!
- * \brief Create a IRDocsifier.
- * \param ir_prefix The ir_prefix to use for this IRDocsifier.
- */
- explicit IRDocsifier(Map<String, String> ir_prefix);
+ /*! \brief Create a IRDocsifier. */
+ IRDocsifier();
/*! \brief The registration table for IRDocsifier. */
TVM_DLL static FType& vtable();
diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h
index 31abd7d9ec..289e838b52 100644
--- a/include/tvm/script/printer/printer.h
+++ b/include/tvm/script/printer/printer.h
@@ -22,6 +22,7 @@
#include <tvm/node/node.h>
#include <tvm/script/printer/ir_docsifier.h>
+#include <string>
#include <unordered_map>
#include <vector>
@@ -31,6 +32,8 @@ namespace printer {
/*! \brief Default values in the TVMScript printer */
struct Default {
+ /*! \brief The prefix of IR nodes */
+ std::unordered_map<std::string, std::string> ir_prefix = {{"ir", "I"}, {"tir", "T"}};
/*! \brief Default data type of TIR buffer */
DataType buffer_dtype = DataType::Float(32);
/*! \brief Default data type of integer literals */
@@ -41,28 +44,30 @@ struct Default {
* T.float32/T.float64 wrapper.
*/
DataType float_dtype = DataType::Void();
+ /*! \brief Whether or not to verbose print expressions. */
+ bool verbose_expr = false;
/*! \brief Returns a singleton of the configuration */
static Default* Instance();
+ static std::string& Prefix(const std::string& ir) { return Instance()->ir_prefix.at(ir); }
static DataType& BufferDType() { return Instance()->buffer_dtype; }
static DataType& IntDType() { return Instance()->int_dtype; }
static DataType& FloatDType() { return Instance()->float_dtype; }
+ static bool& VerboseExpr() { return Instance()->verbose_expr; }
};
/*!
* \brief The entry method for TVMScript printing
* \param obj The object to be printed
- * \param ir_prefix The prefix of IR nodes
* \param indent_spaces Number of spaces used for indentation
* \param print_line_numbers Whether to print line numbers
* \param num_context_lines Number of context lines to print around the underlined text
* \param path_to_underline Object path to be underlined
* \return The TVMScript text format
*/
-String Script(ObjectRef obj, //
- Map<String, String> ir_prefix = {{"ir", "I"}, {"tir", "T"}}, //
- int indent_spaces = 4, //
- bool print_line_numbers = false, //
- int num_context_lines = -1, //
+String Script(ObjectRef obj, //
+ int indent_spaces = 4, //
+ bool print_line_numbers = false, //
+ int num_context_lines = -1, //
Optional<ObjectPath> path_to_underline = NullOpt);
/*!
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 689b1c0a17..1d5e8f317a 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -1191,6 +1191,9 @@ class Any : public PrimExpr {
TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
};
+/*! \brief Legacy ReprPrint format for TIR */
+std::string LegacyTIRPrint(const ObjectRef& obj);
+
/*
* \brief Template function to convert Map to unordered_map
* Sometimes useful for API gluing when internal uses unordered_map
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 06a85fa340..d4b280a37f 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1210,17 +1210,17 @@ def buffer_store(buffer: Buffer, value: PrimExpr, indices: List[Union[PrimExpr,
)
-def prefetch(buffer: Buffer, indices: List[PrimExpr]) -> None:
+def prefetch(buffer: Buffer, bounds: List[Range]) -> None:
"""The prefetch hint for a buffer.
Parameters
----------
buffer : Buffer
The buffer to be prefetched.
- indices : List[PrimExpr]
- The indices of the buffer to extract.
+ bounds : List[Range]
+ The range to be prefetched.
"""
- return _ffi_api.Prefetch(buffer, indices) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member
def evaluate(value: PrimExpr) -> None:
diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py
index 25ea619a41..dc37ea1ff6 100644
--- a/python/tvm/script/printer/__init__.py
+++ b/python/tvm/script/printer/__init__.py
@@ -19,4 +19,5 @@ TVMScript Unified Printer
This package provides a set of APIs to print supported TVM IR into TVMScript
in a roundtrippable way.
"""
+from . import default
from .printer import script
diff --git a/python/tvm/script/printer/default.py b/python/tvm/script/printer/default.py
new file mode 100644
index 0000000000..33ca693ebf
--- /dev/null
+++ b/python/tvm/script/printer/default.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The printer configuration"""
+from typing_extensions import Literal
+
+from . import _ffi_api
+
+
+def ir_prefix( # pylint: disable=invalid-name
+ ir: Literal["ir", "tir"],
+ prefix: str,
+) -> None:
+ """Set the prefix for the IR. If not set, the prefix for "tvm.ir" is "I", and for "tvm.tir" is "T".
+
+ Parameters
+ ----------
+ ir : str
+ The IR type, either "ir" or "tir".
+
+ prefix : str
+ The prefix to use.
+ """
+ _ffi_api.DefaultIRPrefix(ir, prefix) # type: ignore # pylint: disable=no-member
+
+
+def buffer_dtype(dtype: str) -> None:
+ """Set the default dtype for buffer. If not set, it is "float32".
+
+ Parameters
+ ----------
+ dtype : str
+ The default dtype for buffer.
+ """
+ _ffi_api.DefaultBufferDType(dtype) # type: ignore # pylint: disable=no-member
+
+
+def int_dtype(dtype: str) -> None:
+ """Set the default dtype for integers. If not set, it is "int32".
+
+ Parameters
+ ----------
+ dtype : str
+ The default dtype for integers.
+ """
+ _ffi_api.DefaultIntDType(dtype) # type: ignore # pylint: disable=no-member
+
+
+def float_dtype(dtype: str) -> None:
+ """Set the default dtype for floating-point numbers. If not set, there is no default,
+ which means every floating-point number will be wrapped with its precise dtype.
+
+ Parameters
+ ----------
+ dtype : str
+ The default dtype for floating-point numbers.
+ """
+ _ffi_api.DefaultFloatDType(dtype) # type: ignore # pylint: disable=no-member
+
+
+def verbose_expr(verbose: bool) -> None:
+ """Whether or not to verbose print expressions. If not, the definition of every variable in an
+ expression will be printed as separate statements. Otherwise, the result will be a one-liner.
+
+ Parameters
+ ----------
+ verbose : bool
+ Whether or not to verbose print expressions.
+ """
+ _ffi_api.VerboseExpr(verbose) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py
index 120ef03f57..2ce6329dca 100644
--- a/python/tvm/script/printer/printer.py
+++ b/python/tvm/script/printer/printer.py
@@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""The printer interface"""
-
-from typing import Mapping, Optional
+from typing import Optional
from tvm.runtime.object_path import ObjectPath
@@ -25,7 +24,6 @@ from . import _ffi_api
def script(
obj,
- ir_prefix: Optional[Mapping[str, str]] = None,
indent_space: int = 4,
print_line_number: bool = False,
num_context_lines: int = -1,
@@ -37,9 +35,6 @@ def script(
----------
obj : object
An TVM object representing TVM IR
- ir_prefix : Optional[Mapping[str, str]]
- A mapping from IR type to the prefix of the script.
- Default to {"ir": "I", "tir": T}
indent_space : int = 4
The number of spaces to indent
print_line_number : bool = False
@@ -54,11 +49,6 @@ def script(
script : str
The TVMScript text format
"""
- if ir_prefix is None:
- ir_prefix = {
- "ir": "I",
- "tir": "T",
- }
return _ffi_api.Script( # type: ignore # pylint: disable=no-member
- obj, ir_prefix, indent_space, print_line_number, num_context_lines, path_to_underline
+ obj, indent_space, print_line_number, num_context_lines, path_to_underline
)
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 5500707fb9..3a92242276 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1270,29 +1270,32 @@ String ComputeDAG::PrintDAG(bool simple_mode) const {
if (pop->body.size() > 1) {
ss << ".v" << k;
}
- if (auto preduce = pop->body[k].as<ReduceNode>()) {
- ICHECK_LT(k, preduce->combiner->result.size());
- PrimExpr combiner = preduce->combiner->result[k];
+ if (auto p_reduce = pop->body[k].as<ReduceNode>()) {
+ ICHECK_LT(k, p_reduce->combiner->result.size());
+ PrimExpr combiner = p_reduce->combiner->result[k];
if (combiner->IsInstance<AddNode>()) {
- ss << " += " << preduce->source[0] << "\n";
+ ss << " += " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<MaxNode>()) {
- ss << " max= " << preduce->source[0] << "\n";
+ ss << " max= " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<MinNode>()) {
- ss << " min= " << preduce->source[0] << "\n";
+ ss << " min= " << LegacyTIRPrint(p_reduce->source[0]) << "\n";
} else if (combiner->IsInstance<SelectNode>()) {
const auto& select = combiner.as<SelectNode>();
- ss << " select(" << select->condition << ", " << select->true_value << ", "
- << select->false_value << ")= " << '(' << preduce->source[0] << ','
- << preduce->source[1] << ")\n";
+ ss << " select(" << LegacyTIRPrint(select->condition) //
+ << ", " << LegacyTIRPrint(select->true_value) //
+ << ", " << LegacyTIRPrint(select->false_value) //
+ << ")= (" << LegacyTIRPrint(p_reduce->source[0]) //
+ << ',' << LegacyTIRPrint(p_reduce->source[1]) //
+ << ")\n";
} else {
- ss << "reduce" << combiner << "\n";
+ ss << "reduce" << LegacyTIRPrint(combiner) << "\n";
}
} else {
auto call = pop->body[k].as<CallNode>();
if (simple_mode && call) {
- ss << " = " << call->op << "\n";
+ ss << " = " << LegacyTIRPrint(call->op) << "\n";
} else {
- ss << " = " << pop->body[k] << "\n";
+ ss << " = " << LegacyTIRPrint(pop->body[k]) << "\n";
}
}
}
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index f097f8f363..7ba99e34d5 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -106,16 +106,6 @@ TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value
TVM_REGISTER_NODE_TYPE(IntImmNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IntImmNode*>(node.get());
- if (op->dtype == DataType::Int(32)) {
- p->stream << op->value;
- } else {
- p->stream << "(" << op->dtype << ")" << op->value;
- }
- });
-
FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";
@@ -149,25 +139,6 @@ TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double valu
TVM_REGISTER_NODE_TYPE(FloatImmNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloatImmNode*>(node.get());
- auto& stream = p->stream;
- switch (op->dtype.bits()) {
- case 64:
- stream << op->value;
- break;
- case 32:
- stream << op->value << 'f';
- break;
- case 16:
- stream << op->value << 'h';
- break;
- default:
- LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
- }
- });
-
Range::Range(PrimExpr begin, PrimExpr end, Span span)
: Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {}
@@ -183,12 +154,6 @@ TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(RangeNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RangeNode*>(node.get());
- p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
- });
-
GlobalVar::GlobalVar(String name_hint, Type type, Span span) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
diff --git a/src/ir/type.cc b/src/ir/type.cc
index fe8e00329b..ee05fd0359 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -37,12 +37,6 @@ TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) {
return PrimType(dtype);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PrimTypeNode*>(ref.get());
- p->stream << node->dtype;
- });
-
PointerType::PointerType(Type element_type, String storage_scope) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
@@ -57,16 +51,6 @@ TVM_REGISTER_GLOBAL("ir.PointerType")
return PointerType(element_type, storage_scope);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const PointerTypeNode*>(ref.get());
- if (!node->storage_scope.empty()) {
- p->stream << node->storage_scope << " ";
- }
- p->Print(node->element_type);
- p->stream << '*';
- });
-
TypeVar::TypeVar(String name, TypeKind kind, Span span) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
@@ -148,12 +132,6 @@ TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array<Type> fields) {
return TupleType(fields);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const TupleTypeNode*>(ref.get());
- p->stream << "TupleTypeNode(" << node->fields << ")";
- });
-
IncompleteType::IncompleteType(TypeKind kind, Span span) {
auto n = make_object<IncompleteTypeNode>();
n->kind = std::move(kind);
@@ -167,12 +145,6 @@ TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) {
return IncompleteType(static_cast<TypeKind>(kind));
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
- auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
- p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
- });
-
RelayRefType::RelayRefType(Type value, Span span) {
ObjectPtr<RelayRefTypeNode> n = make_object<RelayRefTypeNode>();
n->value = std::move(value);
diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc
index 6851baf638..8634236df5 100644
--- a/src/script/printer/doc_printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -549,7 +549,16 @@ void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
MaybePrintCommentWithNewLine(doc);
output_ << "for ";
- PrintDoc(doc->lhs);
+ if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
+ if (tuple->elements.size() == 1) {
+ PrintDoc(tuple->elements[0]);
+ output_ << ",";
+ } else {
+ PrintJoinedDocs(tuple->elements, ", ");
+ }
+ } else {
+ PrintDoc(doc->lhs);
+ }
output_ << " in ";
PrintDoc(doc->rhs);
output_ << ":";
@@ -644,7 +653,12 @@ String DocToPythonScript(Doc doc, int indent_spaces, bool print_line_numbers, in
PythonDocPrinter printer(options);
printer.Append(doc, path_to_underline);
- return printer.GetString();
+ std::string result = printer.GetString();
+ int last_space = result.size();
+ while (last_space > 0 && std::isspace(result[last_space - 1])) {
+ last_space--;
+ }
+ return result.substr(0, last_space);
}
TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript);
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index 8584f36031..4c52ce890c 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -27,7 +27,7 @@ namespace printer {
String GenerateUniqueName(std::string name_hint, std::unordered_set<String>* defined_names) {
for (char& c : name_hint) {
- if (c != 'c' && !std::isalnum(c)) {
+ if (c != '_' && !std::isalnum(c)) {
c = '_';
}
}
@@ -39,10 +39,10 @@ String GenerateUniqueName(std::string name_hint, std::unordered_set<String>* def
}
IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) {
+ ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
String name = GenerateUniqueName(name_hint, &this->defined_names);
DocCreator doc_factory = [name]() { return IdDoc(name); };
- auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
- ICHECK(result.second) << "Duplicated object: " << obj;
+ obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
IdDoc def_doc(name);
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
return def_doc;
@@ -50,8 +50,6 @@ IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const St
void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) {
ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
- ICHECK(!doc_factory()->IsInstance<IdDocNode>())
- << "IRDocsifierNode::Define cannot be used for variable that's mapped to IdDoc.";
obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
}
@@ -146,9 +144,8 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
this->common_prefix = std::move(visitor.common_prefix);
}
-IRDocsifier::IRDocsifier(Map<String, String> ir_prefix) {
+IRDocsifier::IRDocsifier() {
auto n = make_object<IRDocsifierNode>();
- n->ir_prefix = std::move(ir_prefix);
n->dispatch_tokens.push_back("");
data_ = std::move(n);
}
diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc
index 47fd0b89b0..9ebdcb1e99 100644
--- a/src/script/printer/printer.cc
+++ b/src/script/printer/printer.cc
@@ -23,13 +23,10 @@ namespace tvm {
namespace script {
namespace printer {
-String Script(ObjectRef obj, Map<String, String> ir_prefix, int indent_spaces,
- bool print_line_numbers, int num_context_lines,
+String Script(ObjectRef obj, int indent_spaces, bool print_line_numbers, int num_context_lines,
Optional<ObjectPath> path_to_underline) {
- IRDocsifier d(ir_prefix);
- Doc doc = d->AsDoc(obj, ObjectPath::Root());
- return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
- path_to_underline);
+ return DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root()), indent_spaces,
+ print_line_numbers, num_context_lines, path_to_underline);
}
Default* Default::Instance() {
@@ -38,6 +35,19 @@ Default* Default::Instance() {
}
TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script);
+TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix")
+ .set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; });
+TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType")
+ .set_body_typed([](runtime::DataType dtype) { Default::BufferDType() = dtype; });
+TVM_REGISTER_GLOBAL("script.printer.DefaultIntDType").set_body_typed([](runtime::DataType dtype) {
+ Default::IntDType() = dtype;
+});
+TVM_REGISTER_GLOBAL("script.printer.DefaultFloatDType").set_body_typed([](runtime::DataType dtype) {
+ Default::FloatDType() = dtype;
+});
+TVM_REGISTER_GLOBAL("script.printer.VerboseExpr").set_body_typed([](bool verbose_expr) {
+ Default::VerboseExpr() = verbose_expr;
+});
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc
index f6dbf616a5..8f008375ff 100644
--- a/src/script/printer/tir/block.cc
+++ b/src/script/printer/tir/block.cc
@@ -26,14 +26,15 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
Optional<tir::BlockRealize> opt_realize, Optional<ObjectPath> opt_realize_p) {
With<TIRFrame> frame(d, block);
ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined());
- const tir::BlockRealizeNode* realize = opt_realize.value().get();
- const ObjectPathNode* realize_p = opt_realize_p.get();
+ const tir::BlockRealizeNode* realize =
+ opt_realize.defined() ? opt_realize.value().get() : nullptr;
+ const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr;
// Step 1. Handle block var and block bindings
int n_vars = block->iter_vars.size();
for (int i = 0; i < n_vars; ++i) {
tir::IterVar iter_var = block->iter_vars[i];
ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
- ExprDoc rhs = TIR(d)->Attr("axis");
+ ExprDoc rhs = TIR("axis");
if (iter_var->iter_type == tir::IterVarType::kDataPar) {
rhs = rhs->Attr("spatial");
} else if (iter_var->iter_type == tir::IterVarType::kCommReduce) {
@@ -70,7 +71,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
if (realize) {
ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool());
if (!tir::is_one(realize->predicate)) {
- (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("where")->Call(
+ (*frame)->stmts.push_back(ExprStmtDoc(TIR("where")->Call(
{d->AsDoc<ExprDoc>(realize->predicate, realize_p->Attr("predicate"))})));
}
}
@@ -80,18 +81,17 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
for (int i = 0, n = block->reads.size(); i < n; ++i) {
reads.push_back(d->AsDoc<ExprDoc>(block->reads[i], block_p->Attr("reads")->ArrayIndex(i)));
}
- (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("reads")->Call(reads)));
+ (*frame)->stmts.push_back(ExprStmtDoc(TIR("reads")->Call(reads)));
Array<ExprDoc> writes;
for (int i = 0, n = block->writes.size(); i < n; ++i) {
writes.push_back(d->AsDoc<ExprDoc>(block->writes[i], block_p->Attr("writes")->ArrayIndex(i)));
}
- (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("writes")->Call(writes)));
+ (*frame)->stmts.push_back(ExprStmtDoc(TIR("writes")->Call(writes)));
}
// Step 4. Handle block attributes
if (!block->annotations.empty()) {
(*frame)->stmts.push_back(ExprStmtDoc(
- TIR(d)
- ->Attr("block_attr")
+ TIR("block_attr")
->Call({d->AsDoc<ExprDoc>(block->annotations, block_p->Attr("annotations"))})));
}
// Step 5. Handle `alloc_buffer`
@@ -114,13 +114,19 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
tir::Stmt init = block->init.value();
With<TIRFrame> init_frame(d, init);
AsDocBody(init, block_p->Attr("init"), init_frame->get(), d);
- (*frame)->stmts.push_back(
- ScopeDoc(NullOpt, TIR(d)->Attr("init")->Call({}), (*init_frame)->stmts));
+ (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR("init")->Call({}), (*init_frame)->stmts));
}
// Step 8. Handle block body
AsDocBody(block->body, block_p->Attr("body"), frame->get(), d);
- return ScopeDoc(NullOpt, TIR(d)->Attr("block")->Call({LiteralDoc::Str(block->name_hint)}),
- (*frame)->stmts);
+ Array<String> kwargs_keys;
+ Array<ExprDoc> kwargs_values;
+ if (!realize) {
+ kwargs_keys.push_back("no_realize");
+ kwargs_values.push_back(LiteralDoc::Boolean(true));
+ }
+ return ScopeDoc(
+ NullOpt, TIR("block")->Call({LiteralDoc::Str(block->name_hint)}, kwargs_keys, kwargs_values),
+ (*frame)->stmts);
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
@@ -134,16 +140,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return PrintBlock(d, block, p, NullOpt, NullOpt);
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::MatchBufferRegion>(
- "", [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc {
- Frame frame = d->frames.back();
- ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d);
- ExprDoc src_buffer = d->AsDoc<ExprDoc>(stmt->source, p->Attr("source"));
- ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer", {src_buffer}, p->Attr("buffer"),
- d->frames.back(), d);
- return AssignDoc(lhs, rhs, NullOpt);
- });
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BlockNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BlockRealizeNode>(ReprPrint);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
index 3e1d71af4a..b9eef12abc 100644
--- a/src/script/printer/tir/buffer.cc
+++ b/src/script/printer/tir/buffer.cc
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/device_api.h> // For `kAllocAlignment`
#include "./utils.h"
@@ -121,73 +121,141 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map<String, ExprDoc>& attrs, Arr
ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
const ObjectPath& p, const Frame& frame, const IRDocsifier& d) {
- return BufferCall(/*prefix=*/TIR(d)->Attr(method),
+ return BufferCall(/*prefix=*/TIR(method),  // `<tir prefix>.<method>`, e.g. "match_buffer" or "buffer_decl"
/*attrs=*/BufferAttrs(buffer, p, frame, d),
/*args=*/args);
}
-Doc BufferIndex(const PrimExpr& index, const ObjectPath& p, const IRDocsifier& d) {
- if (const auto* ramp = index.as<tir::RampNode>()) {
- if (const auto* stride = ramp->stride.as<IntImmNode>()) {
- ExprDoc start = d->AsDoc<ExprDoc>(ramp->base, p->Attr("base"));
- ExprDoc stop = d->AsDoc<ExprDoc>(ramp->base + ramp->lanes * ramp->stride, p->Attr("lanes"));
- Optional<ExprDoc> step = NullOpt;
- if (stride->value != 1) {
- step = d->AsDoc<ExprDoc>(ramp->stride, p->Attr("stride"));
+/*!
+ * \brief Docsify the indices of a buffer/producer access, one doc per index.
+ * \param indices The index expressions.
+ * \param p The object path of the `indices` array itself. The callers in this file pass
+ *          `<node path>->Attr("indices")`, so only the element position is appended here.
+ * \param d The IRDocsifier.
+ * \return A Ramp with a constant stride becomes a slice `base:base+lanes*stride[:stride]`;
+ *         any other index is docsified as-is.
+ */
+Array<Doc> BufferIndices(const Array<PrimExpr>& indices, const ObjectPath& p,
+ const IRDocsifier& d) {
+ int n = indices.size();
+ Array<Doc> indices_doc;
+ indices_doc.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ ObjectPath index_p = p->ArrayIndex(i);
+ if (const auto* ramp = indices[i].as<tir::RampNode>()) {
+ if (const auto* stride = ramp->stride.as<IntImmNode>()) {
+ ExprDoc start = d->AsDoc<ExprDoc>(ramp->base, index_p->Attr("base"));
+ ExprDoc stop = d->AsDoc<ExprDoc>(ramp->base + ramp->lanes * ramp->stride, //
+ index_p->Attr("lanes"));
+ Optional<ExprDoc> step = NullOpt;
+ if (stride->value != 1) {  // a unit step is implicit in slice syntax
+ step = d->AsDoc<ExprDoc>(ramp->stride, index_p->Attr("stride"));
+ }
+ indices_doc.push_back(SliceDoc(start, stop, step));
+ continue;
}
- return SliceDoc(start, stop, step);
}
+ indices_doc.push_back(d->AsDoc<ExprDoc>(indices[i], index_p));
}
- return d->AsDoc<ExprDoc>(index, p);
+ return indices_doc;
}
-ExprDoc BufferIndices(const tir::Buffer& buffer, const Array<PrimExpr>& indices,
- const ObjectPath& p, const IRDocsifier& d) {
- int n = indices.size();
- Array<Doc> indices_doc;
- indices_doc.reserve(n);
+Array<Doc> BufferSlices(const Array<Range>& region, const ObjectPath& p, const IRDocsifier& d) {  // one index doc per Range; `p` is the path of the region array itself
+ int n = region.size();
+ Array<Doc> indices;
+ indices.reserve(n);
for (int i = 0; i < n; ++i) {
- indices_doc.push_back(BufferIndex(indices[i], p->Attr("indices")->ArrayIndex(i), d));
+ Range range = region[i];
+ ObjectPath range_p = p->ArrayIndex(i);
+ ExprDoc min = d->AsDoc<ExprDoc>(range->min, range_p->Attr("min"));
+ if (tir::is_one(range->extent)) {  // unit extent: print just `min`, not a slice
+ indices.push_back(min);
+ } else {
+ ExprDoc max = d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent"));  // exclusive slice end
+ indices.push_back(SliceDoc(min, max, NullOpt));
+ }
}
- return d->AsDoc<ExprDoc>(buffer, p->Attr("buffer"))[indices_doc];
+ return indices;
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRegion>(
"", [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc prefix = d->AsDoc<ExprDoc>(buffer_region->buffer, p->Attr("buffer"));
- p = p->Attr("region");
- Array<Range> region = buffer_region->region;
- int n = region.size();
- Array<Doc> indices;
- indices.reserve(n);
- for (int i = 0; i < n; ++i) {
- Range range = region[i];
- ExprDoc min = d->AsDoc<ExprDoc>(range->min, p->ArrayIndex(i)->Attr("min"));
- if (tir::is_one(range->extent)) {
- indices.push_back(min);
- } else {
- ExprDoc max =
- d->AsDoc<ExprDoc>(range->min + range->extent, p->ArrayIndex(i)->Attr("extent"));
- indices.push_back(SliceDoc(min, max, NullOpt));
- }
- }
- return prefix[indices];
+ return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)];  // printed as buffer[min:min+extent, ...]
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferStore>( //
"", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc {
- return AssignDoc(/*lhs=*/BufferIndices(store->buffer, store->indices, p, d),
+ ExprDoc buffer = d->AsDoc<ExprDoc>(store->buffer, p->Attr("buffer"));  // lhs is printed as buffer[indices]
+ return AssignDoc(/*lhs=*/buffer[BufferIndices(store->indices, p->Attr("indices"), d)],
/*rhs=*/d->AsDoc<ExprDoc>(store->value, p->Attr("value")), NullOpt);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferLoad>( //
"", [](tir::BufferLoad load, ObjectPath p, IRDocsifier d) -> Doc {
- return BufferIndices(load->buffer, load->indices, p, d);
+ ExprDoc buffer = d->AsDoc<ExprDoc>(load->buffer, p->Attr("buffer"));  // printed as buffer[indices]
+ return buffer[BufferIndices(load->indices, p->Attr("indices"), d)];
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
+ .set_dispatch<tir::Buffer>("", [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc {
+ // A buffer not yet defined is declared lazily in the lowest frame FindLowestVarDef reports.
+ if (!d->IsVarDefined(buffer)) {
+ Optional<Frame> opt_f = FindLowestVarDef(buffer, d);
+ if (opt_f.defined()) {
+ Frame def_frame = opt_f.value();
+ ExprDoc lhs = DefineBuffer(buffer, def_frame, d);
+ // TODO(@junrushao): name confusing
+ ExprDoc rhs = BufferDecl(buffer, "buffer_decl", {}, p, def_frame, d);
+ def_frame->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+ }
+ }
+ Optional<ExprDoc> var_doc = d->GetVarDoc(buffer);
+ if (!var_doc.defined()) {
+ LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer;
+ }
+ return var_doc.value();
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::MatchBufferRegion>(
+ "", [](tir::MatchBufferRegion stmt, ObjectPath p, IRDocsifier d) -> Doc {
+ Frame frame = d->frames.back();
+ ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d);  // the matched buffer is defined in the innermost frame
+ ExprDoc src_buffer = d->AsDoc<ExprDoc>(stmt->source, p->Attr("source"));
+ ExprDoc rhs = BufferDecl(stmt->buffer, "match_buffer", {src_buffer}, p->Attr("buffer"),  // lhs = T.match_buffer(source, ...)
+ d->frames.back(), d);
+ return AssignDoc(lhs, rhs, NullOpt);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::ProducerLoad>( //
+ "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc {
+ ExprDoc prefix = IdDoc(load->producer->GetNameHint());  // printed via the producer's name hint: name[indices]
+ return prefix[BufferIndices(load->indices, p->Attr("indices"), d)];
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::ProducerStore>( //
+ "", [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc {
+ ExprDoc prefix = IdDoc(store->producer->GetNameHint());  // printed as name[indices] = value
+ prefix = prefix[BufferIndices(store->indices, p->Attr("indices"), d)];
+ return AssignDoc(prefix, d->AsDoc<ExprDoc>(store->value, p->Attr("value")), NullOpt);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::ProducerRealize>( //
+ "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
+ ExprDoc prefix = IdDoc(stmt->producer->GetNameHint());  // scope header: T.ProducerRealize(name[bounds], condition)
+ prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)];
+ prefix = TIR("ProducerRealize")
+ ->Call({prefix, d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"))});
+ With<TIRFrame> f(d, stmt);  // body statements are collected in a fresh TIR frame
+ AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+ return ScopeDoc(NullOpt, prefix, (*f)->stmts);
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BufferRegionNode>(ReprPrint);  // ReprPrinter output for buffer-related nodes goes through ReprPrint (tir/utils.h)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BufferLoadNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BufferStoreNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BufferNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::MatchBufferRegionNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ProducerLoadNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ProducerStoreNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ProducerRealizeNode>(ReprPrint);
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index f9b4eb6214..317201fa3d 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -34,7 +34,7 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& p, const IRDocsifier& d) {
ExprDoc rhs = d->AsDoc<ExprDoc>(type, p->Attr("type_annotation"));
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
} else {
- ExprDoc rhs = TIR(d)->Attr("var")->Call({LiteralDoc::DataType(var->dtype)});
+ ExprDoc rhs = TIR("var")->Call({LiteralDoc::DataType(var->dtype)});
opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
}
}
@@ -57,8 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::IterVar>("", [](tir::IterVar var, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)
- ->Attr("iter_var")
+ return TIR("iter_var")
->Call({
d->AsDoc<ExprDoc>(var->var, p->Attr("var")),
d->AsDoc<ExprDoc>(var->dom, p->Attr("dom")),
@@ -67,27 +66,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
- .set_dispatch<tir::Buffer>("", [](tir::Buffer buffer, ObjectPath p, IRDocsifier d) -> Doc {
- if (!d->IsVarDefined(buffer)) {
- if (Optional<Frame> opt_f = FindLowestVarDef(buffer, d)) {
- ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d);
- ExprDoc rhs = BufferDecl(buffer, "buffer_decl", // TODO(@junrushao): name confusing
- {}, p, opt_f.value(), d);
- opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
- }
- }
- if (Optional<ExprDoc> doc = d->GetVarDoc(buffer)) {
- return doc.value();
- }
- LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer;
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Not>("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
if (a->IsInstance<LiteralDocNode>()) {
- return TIR(d)->Attr("Not")->Call({a});
+ return TIR("Not")->Call({a});  // literal operand: spell out T.Not(...), presumably so Python does not fold `not <literal>`
}
return OperationDoc(OperationDocNode::Kind::kNot, {a});
});
@@ -101,12 +84,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Cast>("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc dtype = LiteralDoc::DataType(cast->dtype);
ExprDoc value = d->AsDoc<ExprDoc>(cast->value, p->Attr("value"));
- return TIR(d)->Attr("Cast")->Call({dtype, value});
+ return TIR("Cast")->Call({dtype, value});
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Select>("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("Select")->Call({
+ return TIR("Select")->Call({
d->AsDoc<ExprDoc>(select->condition, p->Attr("condition")),
d->AsDoc<ExprDoc>(select->true_value, p->Attr("true_value")),
d->AsDoc<ExprDoc>(select->false_value, p->Attr("false_value")),
@@ -115,7 +98,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Ramp>("", [](tir::Ramp ramp, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("Ramp")->Call({
+ return TIR("Ramp")->Call({
d->AsDoc<ExprDoc>(ramp->base, p->Attr("base")),
d->AsDoc<ExprDoc>(ramp->stride, p->Attr("stride")),
LiteralDoc::Int(ramp->lanes),
@@ -124,8 +107,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Broadcast>("", [](tir::Broadcast bc, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)
- ->Attr("Broadcast")
+ return TIR("Broadcast")
->Call({
d->AsDoc<ExprDoc>(bc->value, p->Attr("value")),
LiteralDoc::Int(bc->lanes),
@@ -135,7 +117,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Shuffle>( //
"", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("Shuffle")->Call({
+ return TIR("Shuffle")->Call({
d->AsDoc<ExprDoc>(shuffle->vectors, p->Attr("vectors")),
d->AsDoc<ExprDoc>(shuffle->indices, p->Attr("indices")),
});
@@ -170,12 +152,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
ExprDoc id = d->AsDoc<ExprDoc>(r->identity_element, p->Attr("identity_element"));
- return TIR(d)->Attr("comm_reducer")->Call({lambda, id});
+ return TIR("comm_reducer")->Call({lambda, id});
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Let>("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("let")->Call({
+ return TIR("let")->Call({
d->AsDoc<ExprDoc>(let->var, p->Attr("var")),
d->AsDoc<ExprDoc>(let->value, p->Attr("value")),
d->AsDoc<ExprDoc>(let->body, p->Attr("body")),
@@ -209,7 +191,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ExprDoc prefix{nullptr};
if (const auto* op = call->op.as<OpNode>()) {
String name = op_names[GetRef<Op>(op)];
- prefix = TIR(d)->Attr(name);
+ prefix = TIR(name);
} else if (const auto* gv = call->op.as<GlobalVarNode>()) {
prefix = LiteralDoc::Str(gv->name_hint);
} else {
@@ -232,20 +214,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Any>("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("Any")->Call({});
+ return TIR("Any")->Call({});  // printed as a zero-argument call, T.Any()
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Reduce>("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc {
+ // Reduce is printed (not rejected) as
+ // T.reduce(combiner, source=..., init=..., axis=..., condition=..., value_index=...).
+ ExprDoc combiner = d->AsDoc<ExprDoc>(r->combiner, p->Attr("combiner"));
+ ExprDoc source = d->AsDoc<ExprDoc>(r->source, p->Attr("source"));
+ ExprDoc init = d->AsDoc<ExprDoc>(r->init, p->Attr("init"));
+ ExprDoc axis = d->AsDoc<ExprDoc>(r->axis, p->Attr("axis"));
+ ExprDoc condition = d->AsDoc<ExprDoc>(r->condition, p->Attr("condition"));
+ ExprDoc value_index = LiteralDoc::Int(r->value_index);
+ return TIR("reduce")->Call({combiner}, {"source", "init", "axis", "condition", "value_index"},
+ {source, init, axis, condition, value_index});
- LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r;
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::ProducerLoad>(
- "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc {
- LOG(FATAL) << "ValueError: ProducerLoad should never exist in TIR: " << load;
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Load>("", [](tir::Load load, ObjectPath p, IRDocsifier d) -> Doc {
LOG(FATAL) << "ValueError: Load has been deprecated for BufferLoad: " << load;
@@ -257,7 +241,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
[](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b")); \
- return TIR(d)->Attr(OpString)->Call({a, b}); \
+ return TIR(OpString)->Call({a, b}); \
});
#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \
@@ -267,7 +251,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a")); \
ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b")); \
if (a->IsInstance<LiteralDocNode>() && b->IsInstance<LiteralDocNode>()) { \
- return TIR(d)->Attr(OpString)->Call({a, b}); \
+ return TIR(OpString)->Call({a, b}); \
} \
return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \
});
@@ -294,6 +278,40 @@ TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max");
#undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR
#undef TVM_SCRIPT_PRINTER_DEF_BINARY
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::VarNode>(ReprPrint);  // ReprPrinter output for TIR expressions goes through ReprPrint (tir/utils.h)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::SizeVarNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::IterVarNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::StringImmNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::CastNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AddNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::SubNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::MulNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::DivNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ModNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::FloorDivNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::FloorModNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::MinNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::MaxNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::LTNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::LENode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::EQNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::NENode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::GTNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::GENode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AndNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::OrNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::NotNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::SelectNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::RampNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BroadcastNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::LetNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::CallNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ShuffleNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::CommReducerNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AnyNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ReduceNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::LoadNode>(ReprPrint);  // NOTE(review): the IRDocsifier dispatch for tir::Load LOG(FATAL)s as deprecated, so repr of a Load will raise — confirm intended
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc
index 6a375935bd..239b8e565f 100644
--- a/src/script/printer/tir/for_loop.cc
+++ b/src/script/printer/tir/for_loop.cc
@@ -59,7 +59,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
p = p->Attr("body");
}
AsDocBody(grid.back()->body, p, (*f).get(), d);
- return ForDoc(TupleDoc(lhs), TIR(d)->Attr("grid")->Call(rhs), (*f)->stmts);
+ return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts);
}
// Step 3. If not `T.grid`, print loop kind accordingly
IdDoc lhs = DefineVar(loop->loop_var, *f, d);
@@ -76,21 +76,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (!loop->annotations.empty()) {
annotations = d->AsDoc<ExprDoc>(loop->annotations, p->Attr("annotations"));
}
- ExprDoc prefix = TIR(d);
+ ExprDoc prefix{nullptr};
if (loop->kind == tir::ForKind::kSerial) {
if (loop->annotations.empty()) {
prefix = IdDoc("range");
} else {
- prefix = prefix->Attr("serial");
+ prefix = TIR("serial");
}
} else if (loop->kind == tir::ForKind::kParallel) {
- prefix = prefix->Attr("parallel");
+ prefix = TIR("parallel");
} else if (loop->kind == tir::ForKind::kUnrolled) {
- prefix = prefix->Attr("unroll");
+ prefix = TIR("unroll");
} else if (loop->kind == tir::ForKind::kVectorized) {
- prefix = prefix->Attr("vectorized");
+ prefix = TIR("vectorized");
} else if (loop->kind == tir::ForKind::kThreadBinding) {
- prefix = prefix->Attr("thread_binding");
+ prefix = TIR("thread_binding");
thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag);
} else {
LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind);
@@ -117,6 +117,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return ForDoc(lhs, rhs, (*f)->stmts);
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::ForNode>(ReprPrint);  // ReprPrinter output for For loops goes through ReprPrint (tir/utils.h)
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
index d47a60209e..55e8c075de 100644
--- a/src/script/printer/tir/function.cc
+++ b/src/script/printer/tir/function.cc
@@ -36,11 +36,7 @@ String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
- d->SetCommonPrefix(func, [](const ObjectRef& obj) {
- return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>();
- });
- With<TIRFrame> frame(d, func);
- (*frame)->AddDispatchToken(d, "tir");
+ With<TIRFrame> frame(MakeDispatchFrame(d, func, func));
int n_args = func->params.size();
// Step 1. Handle `func->params`
Array<AssignDoc> args;
@@ -54,8 +50,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
// Step 2. Handle `func->attrs`
if (func->attrs.defined() && !func->attrs->dict.empty()) {
(*frame)->stmts.push_back(
- ExprStmtDoc(TIR(d)
- ->Attr("func_attr") //
+ ExprStmtDoc(TIR("func_attr") //
->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
}
// Step 3. Handle `func->buffer_map`
@@ -76,11 +71,17 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return FunctionDoc(
/*name=*/IdDoc(FindFunctionName(d, func)),
/*args=*/args,
- /*decorators=*/{TIR(d)->Attr("prim_func")},
+ /*decorators=*/{TIR("prim_func")},
/*return_type=*/d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type")),
/*body=*/(*frame)->stmts);
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<tir::PrimFuncNode>([](const ObjectRef& obj, ReprPrinter* p) {
+ // A whole PrimFunc is docsified from the root path by a fresh IRDocsifier and rendered
+ // to TVMScript text, rather than going through the fragment printer ReprPrint.
+ IRDocsifier docsifier;
+ Doc doc = docsifier->AsDoc(obj, ObjectPath::Root());
+ std::string script = DocToPythonScript(doc);
+ p->stream << script;
+ });
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc
index f4e3762fc0..5fea278a44 100644
--- a/src/script/printer/tir/ir.cc
+++ b/src/script/printer/tir/ir.cc
@@ -34,8 +34,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
} else if (dtype == DataType::Bool()) {
return LiteralDoc::Boolean(imm->value);
} else {
- return TIR(d) //
- ->Attr(runtime::DLDataType2String(dtype))
+ return TIR(runtime::DLDataType2String(dtype)) //
->Call({LiteralDoc::Int(imm->value)});
}
});
@@ -46,15 +45,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (dtype == Default::FloatDType()) {
return LiteralDoc::Float(imm->value);
} else {
- return TIR(d)
- ->Attr(runtime::DLDataType2String(dtype))
+ return TIR(runtime::DLDataType2String(dtype)) //
->Call({LiteralDoc::Float(imm->value)});
}
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Range>("", [](Range range, ObjectPath p, IRDocsifier d) -> Doc {
- return TIR(d)->Attr("Range")->Call({
+ return TIR("Range")->Call({
d->AsDoc<ExprDoc>(range->min, p->Attr("min")),
d->AsDoc<ExprDoc>(range->extent, p->Attr("extent")),
});
@@ -63,16 +61,23 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PrimType>("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc {
std::string dtype = ty->dtype.is_void() ? "void" : runtime::DLDataType2String(ty->dtype);
- return TIR(d)->Attr(dtype);
+ return TIR(dtype);  // attribute of the TIR prefix named after the dtype, e.g. T.int32 or T.void
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<PointerType>("", [](PointerType ty, ObjectPath p, IRDocsifier d) -> Doc {
- ExprDoc element_type = d->AsDoc<ExprDoc>(ty->element_type, p->Attr("element_type"));
+ ExprDoc element_type{nullptr};
+ if (const auto* prim_type = ty->element_type.as<PrimTypeNode>()) {  // primitive element type: pass the dtype as a string, e.g. T.Ptr("float32")
+ std::string dtype =
+ prim_type->dtype.is_void() ? "void" : runtime::DLDataType2String(prim_type->dtype);
+ element_type = LiteralDoc::Str(dtype);
+ } else {
+ element_type = d->AsDoc<ExprDoc>(ty->element_type, p->Attr("element_type"));
+ }
if (ty->storage_scope == "") {
- return TIR(d)->Attr("Ptr")->Call({element_type});
+ return TIR("Ptr")->Call({element_type});  // empty storage scope is omitted from the call
} else {
- return TIR(d)->Attr("Ptr")->Call({element_type, LiteralDoc::Str(ty->storage_scope)});
+ return TIR("Ptr")->Call({element_type, LiteralDoc::Str(ty->storage_scope)});
}
});
@@ -81,17 +86,28 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (ty->fields.empty()) {
return LiteralDoc::None();
}
- return TIR(d) //
- ->Attr("Tuple")
- ->Call(d->AsDoc<ListDoc>(ty->fields, p->Attr("fields"))->elements);
+ return TIR("Tuple")->Call(d->AsDoc<ListDoc>(ty->fields, p->Attr("fields"))->elements);
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<IncompleteType>("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc {
+ return TIR("IncompleteType")->Call({});  // printed as a zero-argument call, T.IncompleteType()
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Target>("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc {
Map<String, ObjectRef> config = target->Export();
- return TIR(d)->Attr("target")->Call({d->AsDoc<ExprDoc>(config, p)});
+ return TIR("target")->Call({d->AsDoc<ExprDoc>(config, p)});  // T.target(<exported config dict>)
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<IntImmNode>(ReprPrint);  // ReprPrinter output for these IR nodes/types goes through ReprPrint (tir/utils.h)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<FloatImmNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RangeNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<PrimTypeNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<PointerTypeNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<TupleTypeNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<IncompleteTypeNode>(ReprPrint);
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index 03e5657d24..436f2b202d 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "../../../tir/transforms/ir_utils.h"
+#include "../../../tir/transforms/ir_utils.h" // For `GetPtrStorageScope`
#include "./utils.h"
namespace tvm {
@@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (eval->value->IsInstance<tir::CallNode>()) {
return ExprStmtDoc(value);
}
- return ExprStmtDoc(TIR(d)->Attr("evaluate")->Call({value}));
+ return ExprStmtDoc(TIR("evaluate")->Call({value}));
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
@@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
return StmtBlockDoc(*stmts);
} else {
- rhs = TIR(d)->Attr("let")->Call({lhs, rhs});
+ rhs = TIR("let")->Call({lhs, rhs});
return ScopeDoc(NullOpt, rhs, *stmts);
}
});
@@ -93,7 +93,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
stmts->insert(stmts->begin(), AssertDoc(cond, msg));
return StmtBlockDoc(*stmts);
}
- return ScopeDoc(NullOpt, TIR(d)->Attr("Assert")->Call({cond, msg}), (*f)->stmts);
+ return ScopeDoc(NullOpt, TIR("Assert")->Call({cond, msg}), (*f)->stmts);
});
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
@@ -137,7 +137,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::SeqStmt>("", [](tir::SeqStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
- // TODO(@junrushao): revisit for fragment printing
With<TIRFrame> f(d, stmt);
AsDocBody(stmt, p, f->get(), d);
return StmtBlockDoc((*f)->stmts);
@@ -146,8 +145,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Prefetch>( //
"", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc {
- return ExprStmtDoc(TIR(d)
- ->Attr("prefetch")
+ return ExprStmtDoc(TIR("prefetch")
->Call({
d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer")),
d->AsDoc<ExprDoc>(stmt->bounds, p->Attr("bounds")),
@@ -174,7 +172,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d);
With<TIRFrame> f(d, stmt);
- ExprDoc rhs = TIR(d)->Attr("allocate")->Call(args, kwargs_keys, kwargs_values);
+ ExprDoc rhs = TIR("allocate")->Call(args, kwargs_keys, kwargs_values);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
});
@@ -253,7 +251,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
args.push_back(data_doc);
args.push_back(LiteralDoc::DataType(stmt->dtype));
args.push_back(d->AsDoc<ExprDoc>(stmt->extents, p->Attr("extents")));
- ExprDoc rhs = TIR(d)->Attr("allocate_const")->Call(args, kwargs_keys, kwargs_values);
+ ExprDoc rhs = TIR("allocate_const")->Call(args, kwargs_keys, kwargs_values);
With<TIRFrame> f(d, stmt);
ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
@@ -286,7 +284,7 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
kwargs_keys.push_back("condition");
kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition")));
}
- return TIR(d)->Attr("realize")->Call(args, kwargs_keys, kwargs_values);
+ return TIR("realize")->Call(args, kwargs_keys, kwargs_values);
}
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
@@ -326,13 +324,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
DefineVar(iter_var->var, f, d);
f->stmts.push_back(
AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, p->Attr("node")->Attr("var")),
- TIR(d) //
- ->Attr("env_thread")
- ->Call({LiteralDoc::Str(iter_var->thread_tag)}), //
+ TIR("env_thread")->Call({LiteralDoc::Str(iter_var->thread_tag)}), //
NullOpt));
}
- rhs = TIR(d)
- ->Attr("launch_thread")
+ rhs = TIR("launch_thread")
->Call({
d->AsDoc<ExprDoc>(iter_var->var, p->Attr("node")),
d->AsDoc<ExprDoc>(stmt->value, p->Attr("value")),
@@ -340,7 +335,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
}
}
if (!rhs.defined()) {
- rhs = TIR(d)->Attr("attr")->Call({
+ rhs = TIR("attr")->Call({
d->AsDoc<ExprDoc>(stmt->node, p->Attr("node")),
LiteralDoc::Str(stmt->attr_key),
d->AsDoc<ExprDoc>(stmt->value, p->Attr("value")),
@@ -351,24 +346,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
});
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::ProducerRealize>( //
- "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
- LOG(FATAL) << "ValueError: ProducerRealize should never exist in TIR: " << stmt;
- });
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<tir::ProducerStore>( //
- "", [](tir::ProducerStore stmt, ObjectPath p, IRDocsifier d) -> Doc {
- LOG(FATAL) << "ValueError: ProducerStore should never exist in TIR: " << stmt;
- });
-
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Store>( //
"", [](tir::Store stmt, ObjectPath p, IRDocsifier d) -> Doc {
LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt;
});
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::LetStmtNode>(ReprPrint);  // ReprPrinter output for TIR statements goes through ReprPrint (tir/utils.h)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AttrStmtNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AssertStmtNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::WhileNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AllocateNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::AllocateConstNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::DeclBufferNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::PrefetchNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::SeqStmtNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::IfThenElseNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::EvaluateNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::BufferRealizeNode>(ReprPrint);
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<tir::StoreNode>(ReprPrint);  // NOTE(review): the Store dispatch above LOG(FATAL)s, so repr of a Store will raise — confirm intended
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
index 6cae378d0e..7f67c3a11c 100644
--- a/src/script/printer/tir/utils.h
+++ b/src/script/printer/tir/utils.h
@@ -28,6 +28,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
+#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
@@ -70,9 +71,7 @@ class TIRFrame : public Frame {
};
/*! \brief Creates the TIR common prefix, which is by default `T` */
-inline IdDoc TIR(const IRDocsifier& d) { //
- return IdDoc(d->ir_prefix.Get("tir").value_or("T"));
-}
+inline ExprDoc TIR(const String& attr) { return IdDoc(Default::Prefix("tir"))->Attr(attr); }  // returns `<prefix>.<attr>` (e.g. T.grid); prefix now comes from Default::Prefix("tir"), not the IRDocsifier
/*!
* \brief Defines a variable in the IRDocsifier at the given frame,
@@ -141,10 +140,15 @@ inline Optional<Frame> FindLowestVarDef(const ObjectRef& var, const IRDocsifier&
}
int n_frames = d->frames.size();
std::unordered_map<const Object*, const FrameNode*> tir_to_frame;
+ const FrameNode* fallback_frame = nullptr;
tir_to_frame.reserve(n_frames);
for (int i = n_frames - 1; i >= 0; --i) {
if (const auto* f = d->frames[i].as<TIRFrameNode>()) {
- tir_to_frame[f->tir.get()] = f;
+ if (f->tir.defined()) {
+ tir_to_frame[f->tir.get()] = f;
+ } else if (fallback_frame == nullptr) {
+ fallback_frame = f;
+ }
}
}
const std::vector<const Object*>& path = d->common_prefix.at(var.get());
@@ -153,9 +157,52 @@ inline Optional<Frame> FindLowestVarDef(const ObjectRef& var, const IRDocsifier&
return GetRef<Frame>(tir_to_frame.at(*it));
}
}
+ if (fallback_frame != nullptr) {
+ return GetRef<Frame>(fallback_frame);
+ }
return NullOpt;
}
+/*!
+ * \brief Build a TIR frame carrying the "tir" dispatch token, after computing the
+ *        common-prefix (LCA) information for the variables and buffers under `root`.
+ * \param d The IRDocsifier
+ * \param root The root of the TIR AST to analyze
+ * \param tir The TIR node stored in the new frame (may be undefined)
+ * \return The newly created frame
+ */
+inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root,
+                                  const ObjectRef& tir) {
+  // Only TIR variables and buffers take part in the common-prefix analysis.
+  auto is_var_or_buffer = [](const ObjectRef& node) -> bool {
+    return node->IsInstance<tir::VarNode>() || node->IsInstance<tir::BufferNode>();
+  };
+  d->SetCommonPrefix(root, is_var_or_buffer);
+  TIRFrame dispatch_frame(d, tir);
+  dispatch_frame->AddDispatchToken(d, "tir");
+  return dispatch_frame;
+}
+
+/*!
+ * \brief Redirected method for the ReprPrinter: renders a TIR node as TVMScript text.
+ * \param stmt The node to print; its doc may be an expression, a statement or a statement block.
+ * \param p The ReprPrinter whose stream receives the rendered text.
+ */
+inline void ReprPrint(const ObjectRef& stmt, ReprPrinter* p) {
+  IRDocsifier d;
+  // The frame stores no TIR node; FindLowestVarDef falls back to such a frame for
+  // variables that have no other defining frame.
+  With<TIRFrame> f(MakeDispatchFrame(d, stmt, ObjectRef(nullptr)));
+  Doc doc = d->AsDoc(stmt, ObjectPath::Root());
+  if (const auto* expr_doc = doc.as<ExprDocNode>()) {
+    // Unless verbose expression printing is enabled, discard whatever statements the frame
+    // accumulated during conversion and print the expression alone.
+    if (!Default::VerboseExpr()) {
+      (*f)->stmts.clear();
+    }
+    (*f)->stmts.push_back(ExprStmtDoc(GetRef<ExprDoc>(expr_doc)));
+  } else if (const auto* stmt_doc = doc.as<StmtDocNode>()) {
+    (*f)->stmts.push_back(GetRef<StmtDoc>(stmt_doc));
+  } else if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
+    // Flatten the block into the frame. The loop variable is `s` so it does not shadow
+    // the docsifier `d` declared above.
+    for (const StmtDoc& s : stmt_block->stmts) {
+      (*f)->stmts.push_back(s);
+    }
+  } else {
+    LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey();
+  }
+  std::string res = DocToPythonScript(StmtBlockDoc((*f)->stmts));
+  p->stream << res;
+}
+
/*!
* \brief Declare and define a buffer
* \param buffer The buffer to be defined
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 0dfda954b8..c2e6fad42d 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -612,12 +612,6 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
offset_factor, buffer_type);
}
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferNode*>(node.get());
- p->stream << "buffer(" << op->name << ", " << op << ")";
- });
-
TVM_REGISTER_NODE_TYPE(BufferNode);
TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) {
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index daae7eaf68..40606761f8 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -116,14 +116,6 @@ TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMA
TVM_REGISTER_NODE_TYPE(VarNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<VarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const VarNode*>(node.get());
- // omit the type
- // stream << op->name << "." << op->type;
- p->stream << op->name_hint;
- });
-
// SizeVar
SizeVar::SizeVar(String name_hint, DataType dtype, Span span) {
auto n = make_object<SizeVarNode>();
@@ -140,12 +132,6 @@ TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span
TVM_REGISTER_NODE_TYPE(SizeVarNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SizeVarNode*>(node.get());
- p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
- });
-
// IterVar
IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) {
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
@@ -171,22 +157,6 @@ TVM_REGISTER_GLOBAL("tir.IterVar")
return IterVar(dom, var, static_cast<IterVarType>(iter_type), thread_tag, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IterVarNode*>(node.get());
- p->stream << "iter_var(";
- if (op->var->name_hint.length() != 0) {
- p->stream << op->var->name_hint << ", ";
- }
- if (op->dom.defined()) {
- p->stream << op->dom;
- }
- if (op->thread_tag.length() != 0) {
- p->stream << ", " << op->thread_tag;
- }
- p->stream << ")";
- });
-
TVM_REGISTER_NODE_TYPE(IterVarNode);
// StringImm
@@ -204,12 +174,6 @@ TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span)
TVM_REGISTER_NODE_TYPE(StringImmNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StringImmNode*>(node.get());
- p->stream << '\"' << support::StrEscape(op->value) << '\"';
- });
-
// Cast
Cast::Cast(DataType t, PrimExpr value, Span span) {
ICHECK(value.defined());
@@ -227,14 +191,6 @@ TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value
TVM_REGISTER_NODE_TYPE(CastNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CastNode*>(node.get());
- p->stream << op->dtype << '(';
- p->Print(op->value);
- p->stream << ')';
- });
-
// Add
TVM_DEFINE_BINOP_CONSTRUCTOR(Add);
@@ -244,16 +200,6 @@ TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(AddNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AddNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AddNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " + ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Sub
TVM_DEFINE_BINOP_CONSTRUCTOR(Sub);
@@ -263,16 +209,6 @@ TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(SubNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<SubNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SubNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " - ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Mul
TVM_DEFINE_BINOP_CONSTRUCTOR(Mul);
@@ -282,16 +218,6 @@ TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(MulNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<MulNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MulNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << "*";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Div
TVM_DEFINE_BINOP_CONSTRUCTOR(Div);
@@ -301,16 +227,6 @@ TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(DivNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<DivNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const DivNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << "/";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Mod
TVM_DEFINE_BINOP_CONSTRUCTOR(Mod);
@@ -320,16 +236,6 @@ TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(ModNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ModNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ModNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " % ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// FloorDiv
TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv);
@@ -339,12 +245,6 @@ TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b, Sp
TVM_REGISTER_NODE_TYPE(FloorDivNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloorDivNode*>(node.get());
- p->stream << "floordiv(" << op->a << ", " << op->b << ")";
- });
-
// FloorMod
TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod);
@@ -354,12 +254,6 @@ TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b, Sp
TVM_REGISTER_NODE_TYPE(FloorModNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<FloorModNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const FloorModNode*>(node.get());
- p->stream << "floormod(" << op->a << ", " << op->b << ")";
- });
-
// Min
TVM_DEFINE_BINOP_CONSTRUCTOR(Min);
@@ -369,16 +263,6 @@ TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(MinNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<MinNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MinNode*>(node.get());
- p->stream << "min(";
- p->Print(op->a);
- p->stream << ", ";
- p->Print(op->b);
- p->stream << ")";
- });
-
// Max
TVM_DEFINE_BINOP_CONSTRUCTOR(Max);
@@ -388,16 +272,6 @@ TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(MaxNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<MaxNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MaxNode*>(node.get());
- p->stream << "max(";
- p->Print(op->a);
- p->stream << ", ";
- p->Print(op->b);
- p->stream << ")";
- });
-
// EQ
TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ);
@@ -407,16 +281,6 @@ TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(EQNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<EQNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const EQNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " == ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// NE
TVM_DEFINE_CMPOP_CONSTRUCTOR(NE);
@@ -426,16 +290,6 @@ TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(NENode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<NENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const NENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " != ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// LT
TVM_DEFINE_CMPOP_CONSTRUCTOR(LT);
@@ -445,16 +299,6 @@ TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(LTNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LTNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LTNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " < ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// LE
TVM_DEFINE_CMPOP_CONSTRUCTOR(LE);
@@ -464,16 +308,6 @@ TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(LENode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " <= ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// GT
TVM_DEFINE_CMPOP_CONSTRUCTOR(GT);
@@ -483,16 +317,6 @@ TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(GTNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<GTNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const GTNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " > ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// GE
TVM_DEFINE_CMPOP_CONSTRUCTOR(GE);
@@ -502,16 +326,6 @@ TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(GENode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<GENode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const GENode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " >= ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// And
And::And(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.defined()) << "ValueError: a is undefined";
@@ -534,16 +348,6 @@ TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b, Span sp
TVM_REGISTER_NODE_TYPE(AndNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AndNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AndNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " && ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Or
Or::Or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.defined()) << "ValueError: a is undefined";
@@ -566,16 +370,6 @@ TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b, Span spa
TVM_REGISTER_NODE_TYPE(OrNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<OrNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const OrNode*>(node.get());
- p->stream << '(';
- p->Print(op->a);
- p->stream << " || ";
- p->Print(op->b);
- p->stream << ')';
- });
-
// Not
Not::Not(PrimExpr a, Span span) {
ICHECK(a.defined()) << "ValueError: a is undefined";
@@ -592,13 +386,6 @@ TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a, Span span) { return
TVM_REGISTER_NODE_TYPE(NotNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<NotNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const NotNode*>(node.get());
- p->stream << '!';
- p->Print(op->a);
- });
-
// Select
Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) {
ICHECK(condition.defined()) << "ValueError: condition is undefined";
@@ -624,18 +411,6 @@ TVM_REGISTER_GLOBAL("tir.Select")
TVM_REGISTER_NODE_TYPE(SelectNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<SelectNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SelectNode*>(node.get());
- p->stream << "select(";
- p->Print(op->condition);
- p->stream << ", ";
- p->Print(op->true_value);
- p->stream << ", ";
- p->Print(op->false_value);
- p->stream << ")";
- });
-
// Load
Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span) {
LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint
@@ -703,18 +478,6 @@ TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(LoadNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LoadNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LoadNode*>(node.get());
- p->stream << op->buffer_var << "[";
- p->Print(op->index);
- p->stream << "]";
- if (!is_one(op->predicate)) {
- p->stream << " if ";
- p->Print(op->predicate);
- }
- });
-
// Ramp
Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
ICHECK(base.defined());
@@ -740,16 +503,6 @@ TVM_REGISTER_GLOBAL("tir.Ramp")
TVM_REGISTER_NODE_TYPE(RampNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<RampNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const RampNode*>(node.get());
- p->stream << "ramp(";
- p->Print(op->base);
- p->stream << ", ";
- p->Print(op->stride);
- p->stream << ", " << op->lanes << ")";
- });
-
// Broadcast
Broadcast::Broadcast(PrimExpr value, int lanes, Span span) {
ICHECK(value.defined());
@@ -770,14 +523,6 @@ TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes
TVM_REGISTER_NODE_TYPE(BroadcastNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BroadcastNode*>(node.get());
- p->stream << "x" << op->lanes << "(";
- p->Print(op->value);
- p->stream << ")";
- });
-
// Let
Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) {
ICHECK(value.defined());
@@ -800,16 +545,6 @@ TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimEx
TVM_REGISTER_NODE_TYPE(LetNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LetNode*>(node.get());
- p->stream << "(let " << op->var << " = ";
- p->Print(op->value);
- p->stream << " in ";
- p->Print(op->body);
- p->stream << ")";
- });
-
// Call
Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) {
for (size_t i = 0; i < args.size(); ++i) {
@@ -857,25 +592,6 @@ TVM_REGISTER_GLOBAL("tir.Call")
TVM_REGISTER_NODE_TYPE(CallNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CallNode*>(node.get());
- if (auto* ptr_op = op->op.as<OpNode>()) {
- p->stream << ptr_op->name << "(";
- } else {
- auto* ptr_gvar = op->op.as<GlobalVarNode>();
- ICHECK(ptr_gvar != nullptr);
- p->stream << "@" << ptr_gvar->name_hint << "(";
- }
- for (size_t i = 0; i < op->args.size(); ++i) {
- p->Print(op->args[i]);
- if (i < op->args.size() - 1) {
- p->stream << ", ";
- }
- }
- p->stream << ")";
- });
-
// Shuffle
Shuffle::Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices, Span span) {
ICHECK_NE(vectors.size(), 0U);
@@ -924,26 +640,6 @@ TVM_REGISTER_GLOBAL("tir.Shuffle")
TVM_REGISTER_NODE_TYPE(ShuffleNode);
-template <typename T>
-void PrintList(const Array<T>& exprs, ReprPrinter* p) {
- for (size_t i = 0; i < exprs.size(); ++i) {
- p->Print(exprs[i]);
- if (i < exprs.size() - 1) {
- p->stream << ", ";
- }
- }
-}
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ShuffleNode*>(node.get());
- p->stream << "shuffle(";
- PrintList(op->vectors, p);
- p->stream << ", ";
- PrintList(op->indices, p);
- p->stream << ")";
- });
-
// CommReducer
CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element, Span span) {
@@ -1009,13 +705,6 @@ TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
TVM_REGISTER_NODE_TYPE(CommReducerNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const CommReducerNode*>(node.get());
- p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs
- << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")";
- });
-
// Reduce
Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
PrimExpr condition, int value_index, Array<PrimExpr> init, Span span) {
@@ -1057,18 +746,6 @@ TVM_REGISTER_GLOBAL("tir.Reduce")
TVM_REGISTER_NODE_TYPE(ReduceNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ReduceNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ReduceNode*>(node.get());
- p->stream << "reduce(combiner=" << op->combiner;
- p->stream << ", source=" << op->source;
- p->stream << ", init=" << op->init;
- p->stream << ", axis=" << op->axis;
- p->stream << ", where=" << op->condition;
- p->stream << ", value_index=" << op->value_index;
- p->stream << ")";
- });
-
// Any
Any::Any(Span span) {
auto n = make_object<AnyNode>();
@@ -1081,9 +758,6 @@ TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([](Span span) { return Any(span);
TVM_REGISTER_NODE_TYPE(AnyNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; });
-
// BufferLoad
void BufferLoadNode::LegalizeDType() {
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
@@ -1118,19 +792,6 @@ TVM_REGISTER_GLOBAL("tir.BufferLoad")
TVM_REGISTER_NODE_TYPE(BufferLoadNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferLoadNode*>(node.get());
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) {
- p->stream << ", ";
- }
- }
- p->stream << "]";
- });
-
// ProducerLoad
ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span) {
ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
@@ -1148,17 +809,5 @@ TVM_REGISTER_GLOBAL("tir.ProducerLoad")
TVM_REGISTER_NODE_TYPE(ProducerLoadNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ProducerLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ProducerLoadNode*>(node.get());
- p->stream << op->producer->GetNameHint() << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) {
- p->stream << ", ";
- }
- }
- p->stream << "]";
- });
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index d4802e2876..5067d90838 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -109,21 +109,6 @@ Optional<TensorIntrin> TensorIntrin::Get(String name, bool allow_missing) {
TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
- // TODO(tvm-team) redirect to Text printer once we have a good text format.
- auto* node = static_cast<const PrimFuncNode*>(ref.get());
- p->stream << "PrimFunc(" << node->params << ") ";
- if (node->attrs.defined()) {
- p->stream << "attrs=" << node->attrs;
- }
- p->stream << " {\n";
- p->indent += 2;
- p->Print(node->body);
- p->indent -= 2;
- p->stream << "}\n";
- });
-
TVM_REGISTER_GLOBAL("tir.PrimFunc")
.set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 03a2f29bd1..ee7e493b61 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -21,11 +21,10 @@
* \file index_map.cc
*/
-#include "tvm/tir/index_map.h"
-
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_set.h>
#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
diff --git a/src/tir/ir/legacy_printer.cc b/src/tir/ir/legacy_printer.cc
new file mode 100644
index 0000000000..4c2fd5037b
--- /dev/null
+++ b/src/tir/ir/legacy_printer.cc
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <sstream>
+
+#include "../../support/str_escape.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Print a TIR object in the legacy (pre-TVMScript) text format.
+ *
+ * Kept for backward compatibility with callers that depend on the legacy output.
+ * \param obj The object to print: a CommReducer, IterVar, Range, Op, or any PrimExpr.
+ *        Any other object type fails the Downcast to PrimExpr.
+ * \return The printed text.
+ */
+std::string LegacyTIRPrint(const ObjectRef& obj) {
+  class LegacyTIRPrinter : private tir::ExprVisitor {
+   public:
+    explicit LegacyTIRPrinter(std::ostream& os) : stream(os) {}
+
+    /*! \brief Print `obj`, handling non-PrimExpr node kinds before visiting expressions. */
+    void Print(const ObjectRef& obj) {
+      if (const auto* op = obj.as<CommReducerNode>()) {
+        Print_(op);
+      } else if (const auto* op = obj.as<IterVarNode>()) {
+        Print_(op);
+      } else if (const auto* op = obj.as<RangeNode>()) {
+        Print_(op);
+      } else if (const auto* op = obj.as<OpNode>()) {
+        Print_(op);
+      } else {
+        VisitExpr(Downcast<PrimExpr>(obj));
+      }
+    }
+
+   private:
+    void VisitExpr_(const VarNode* op) final { stream << op->name_hint; }
+
+    void VisitExpr_(const SizeVarNode* op) final {
+      stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
+    }
+
+    // int32 literals print bare; other integer types are prefixed with their dtype.
+    void VisitExpr_(const IntImmNode* op) final {
+      if (op->dtype == DataType::Int(32)) {
+        stream << op->value;
+      } else {
+        stream << "(" << op->dtype << ")" << op->value;
+      }
+    }
+
+    // Float literals carry a width suffix: none for 64-bit, 'f' for 32-bit, 'h' for 16-bit.
+    void VisitExpr_(const FloatImmNode* op) final {
+      switch (op->dtype.bits()) {
+        case 64:
+          stream << op->value;
+          break;
+        case 32:
+          stream << op->value << 'f';
+          break;
+        case 16:
+          stream << op->value << 'h';
+          break;
+        default:
+          LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
+      }
+    }
+
+    void VisitExpr_(const StringImmNode* op) final {
+      stream << '\"' << support::StrEscape(op->value) << '\"';
+    }
+
+    void VisitExpr_(const CastNode* op) final {
+      stream << op->dtype << '(';
+      VisitExpr(op->value);
+      stream << ')';
+    }
+
+    void VisitExpr_(const AddNode* op) final { PrintBinary(op->a, op->b, " + "); }
+    void VisitExpr_(const SubNode* op) final { PrintBinary(op->a, op->b, " - "); }
+    void VisitExpr_(const MulNode* op) final { PrintBinary(op->a, op->b, "*"); }
+    void VisitExpr_(const DivNode* op) final { PrintBinary(op->a, op->b, "/"); }
+    void VisitExpr_(const ModNode* op) final { PrintBinary(op->a, op->b, " % "); }
+    void VisitExpr_(const FloorDivNode* op) final { PrintCall("floordiv", op->a, op->b); }
+    void VisitExpr_(const FloorModNode* op) final { PrintCall("floormod", op->a, op->b); }
+    void VisitExpr_(const MinNode* op) final { PrintCall("min", op->a, op->b); }
+    void VisitExpr_(const MaxNode* op) final { PrintCall("max", op->a, op->b); }
+    void VisitExpr_(const EQNode* op) final { PrintBinary(op->a, op->b, " == "); }
+    void VisitExpr_(const NENode* op) final { PrintBinary(op->a, op->b, " != "); }
+    void VisitExpr_(const LTNode* op) final { PrintBinary(op->a, op->b, " < "); }
+    void VisitExpr_(const LENode* op) final { PrintBinary(op->a, op->b, " <= "); }
+    void VisitExpr_(const GTNode* op) final { PrintBinary(op->a, op->b, " > "); }
+    void VisitExpr_(const GENode* op) final { PrintBinary(op->a, op->b, " >= "); }
+    void VisitExpr_(const AndNode* op) final { PrintBinary(op->a, op->b, " && "); }
+    void VisitExpr_(const OrNode* op) final { PrintBinary(op->a, op->b, " || "); }
+
+    void VisitExpr_(const NotNode* op) final {
+      stream << "!";
+      VisitExpr(op->a);
+    }
+
+    void VisitExpr_(const SelectNode* op) final {
+      stream << "select(";
+      VisitExpr(op->condition);
+      stream << ", ";
+      VisitExpr(op->true_value);
+      stream << ", ";
+      VisitExpr(op->false_value);
+      stream << ')';
+    }
+
+    void VisitExpr_(const RampNode* op) final {
+      stream << "ramp(";
+      VisitExpr(op->base);
+      stream << ", ";
+      VisitExpr(op->stride);
+      stream << ", " << op->lanes << ')';
+    }
+
+    void VisitExpr_(const BroadcastNode* op) final {
+      stream << "x" << op->lanes << "(";
+      VisitExpr(op->value);
+      stream << ")";
+    }
+
+    void VisitExpr_(const LetNode* op) final {
+      // Visit the bound variable rather than streaming it with operator<<, so it is printed
+      // by this legacy printer instead of the global ReprPrinter.
+      stream << "(let ";
+      VisitExpr(op->var);
+      stream << " = ";
+      VisitExpr(op->value);
+      stream << " in ";
+      VisitExpr(op->body);
+      stream << ")";
+    }
+
+    // Calls print as `op_name(args...)` for Ops and `@global_name(args...)` for GlobalVars.
+    void VisitExpr_(const CallNode* op) final {
+      if (auto* ptr_op = op->op.as<OpNode>()) {
+        stream << ptr_op->name << "(";
+      } else {
+        auto* p = op->op.as<GlobalVarNode>();
+        ICHECK(p != nullptr);
+        stream << "@" << p->name_hint << "(";
+      }
+      for (size_t i = 0; i < op->args.size(); ++i) {
+        VisitExpr(op->args[i]);
+        if (i < op->args.size() - 1) {
+          stream << ", ";
+        }
+      }
+      stream << ")";
+    }
+
+    void VisitExpr_(const ShuffleNode* op) final {
+      stream << "shuffle(";
+      PrintList(op->vectors.GetArrayNode());
+      stream << ", ";
+      PrintList(op->indices.GetArrayNode());
+      stream << ")";
+    }
+
+    void VisitExpr_(const ReduceNode* op) final {
+      stream << "reduce(combiner=";
+      Print_(op->combiner.get());
+      stream << ", source=";
+      PrintList(op->source.GetArrayNode());
+      stream << ", init=";
+      PrintList(op->init.GetArrayNode());
+      stream << ", axis=";
+      // `axis` holds IterVars, which PrintList routes through Print() (see below).
+      PrintList(op->axis.GetArrayNode());
+      stream << ", where=";
+      VisitExpr(op->condition);
+      stream << ", value_index=" << op->value_index;
+      stream << ")";
+    }
+
+    void VisitExpr_(const AnyNode*) final { stream << "?"; }
+
+    void VisitExpr_(const BufferLoadNode* op) final {
+      stream << op->buffer->name << "[";
+      for (size_t i = 0; i < op->indices.size(); ++i) {
+        VisitExpr(op->indices[i]);
+        if (i < op->indices.size() - 1) {
+          stream << ", ";
+        }
+      }
+      stream << "]";
+    }
+
+    void VisitExpr_(const ProducerLoadNode* op) final {
+      stream << op->producer->GetNameHint() << "[";
+      for (size_t i = 0; i < op->indices.size(); ++i) {
+        VisitExpr(op->indices[i]);
+        if (i < op->indices.size() - 1) {
+          stream << ", ";
+        }
+      }
+      stream << "]";
+    }
+
+   private:
+    void Print_(const CommReducerNode* op) {
+      stream << "comm_reducer(result=";
+      PrintList(op->result.GetArrayNode());
+      stream << ", lhs=";
+      PrintList(op->lhs.GetArrayNode());
+      stream << ", rhs=";
+      PrintList(op->rhs.GetArrayNode());
+      stream << ", identity_element=";
+      PrintList(op->identity_element.GetArrayNode());
+      stream << ")";
+    }
+
+    void Print_(const IterVarNode* op) {
+      stream << "{" << op->var->name_hint << "|" << op->var->name_hint;
+      // `dom` may be undefined; only dereference it when present.
+      if (op->dom.defined()) {
+        stream << " in [";
+        VisitExpr(op->dom->min);
+        stream << ", ";
+        VisitExpr(op->dom->extent);
+        stream << ")";
+      }
+      stream << "}";
+    }
+
+    void Print_(const RangeNode* op) {
+      // Visit the bounds directly so they use the legacy format rather than ReprPrinter.
+      stream << "range(min=";
+      VisitExpr(op->min);
+      stream << ", ext=";
+      VisitExpr(op->extent);
+      stream << ')';
+    }
+
+    void Print_(const OpNode* op) { stream << "Op(" << op->name << ")"; }
+
+   private:
+    /*! \brief Print `(a<sign>b)`. */
+    void PrintBinary(const PrimExpr& a, const PrimExpr& b, const std::string& sign) {
+      stream << '(';
+      VisitExpr(a);
+      stream << sign;
+      VisitExpr(b);
+      stream << ')';
+    }
+
+    /*! \brief Print `call(a, b)`. */
+    void PrintCall(const std::string& call, const PrimExpr& a, const PrimExpr& b) {
+      stream << call << '(';
+      VisitExpr(a);
+      stream << ", ";
+      VisitExpr(b);
+      stream << ')';
+    }
+
+    /*! \brief Print the elements of an array separated by ", ". */
+    void PrintList(const ArrayNode* exprs) {
+      int n = static_cast<int>(exprs->size());
+      for (int i = 0; i < n; ++i) {
+        // Dispatch through Print() instead of Downcast<PrimExpr>: some lists
+        // (e.g. Reduce::axis) contain IterVars, which are not PrimExprs.
+        Print(exprs->at(i));
+        if (i < n - 1) {
+          stream << ", ";
+        }
+      }
+    }
+
+    std::ostream& stream;
+  };
+  std::ostringstream os;
+  LegacyTIRPrinter(os).Print(obj);
+  return os.str();
+}
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index af6997a72a..355a3b16b8 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -59,16 +59,6 @@ TVM_REGISTER_GLOBAL("tir.LetStmt")
TVM_REGISTER_NODE_TYPE(LetStmtNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const LetStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "let " << op->var << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
-
// AttrStmt
AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
auto n = make_object<AttrStmtNode>();
@@ -87,18 +77,6 @@ TVM_REGISTER_GLOBAL("tir.AttrStmt")
TVM_REGISTER_NODE_TYPE(AttrStmtNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AttrStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "// attr [";
- p->Print(op->node);
- p->stream << "] " << op->attr_key << " = ";
- p->Print(op->value);
- p->stream << '\n';
- p->Print(op->body);
- });
-
// AssertStmt
AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) {
ICHECK(condition.defined());
@@ -125,18 +103,6 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt")
}
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AssertStmtNode*>(node.get());
- p->PrintIndent();
- p->stream << "assert(";
- p->Print(op->condition);
- p->stream << ", ";
- p->Print(op->message);
- p->stream << ")\n";
- p->Print(op->body);
- });
-
// For
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
@@ -209,24 +175,6 @@ std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*)
return out;
}
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ForNode*>(node.get());
- p->PrintIndent();
- p->stream << op->kind << " (" << op->loop_var << ", ";
- p->Print(op->min);
- p->stream << ", ";
- p->Print(op->extent);
- p->stream << ") {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// While
While::While(PrimExpr condition, Stmt body, Span span) {
ICHECK(condition.defined());
@@ -247,18 +195,6 @@ TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt body
TVM_REGISTER_NODE_TYPE(WhileNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<WhileNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const WhileNode*>(node.get());
- p->PrintIndent();
- p->stream << "while(" << op->condition << ") {\n";
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// Store
Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate, Span span) {
LOG(FATAL) << "Unexpected use of deprecated Store node for buffer " << buffer_var->name_hint
@@ -312,21 +248,6 @@ TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_NODE_TYPE(StoreNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<StoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const StoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer_var << "[";
- p->Print(op->index);
- p->stream << "] = ";
- p->Print(op->value);
- if (!is_one(op->predicate)) {
- p->stream << " if ";
- p->Print(op->predicate);
- }
- p->stream << '\n';
- });
-
// ProducerStore
ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
Span span) {
@@ -345,21 +266,6 @@ TVM_REGISTER_GLOBAL("tir.ProducerStore")
TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ProducerStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ProducerStoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->producer->GetNameHint() << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) p->stream << ", ";
- }
- p->stream << "]";
- p->stream << " =";
- p->Print(op->value);
- p->stream << '\n';
- });
-
// Allocate
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
Stmt body, Map<String, ObjectRef> annotations, Span span) {
@@ -412,26 +318,6 @@ TVM_REGISTER_GLOBAL("tir.Allocate")
TVM_REGISTER_NODE_TYPE(AllocateNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AllocateNode*>(node.get());
- const auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>();
- ICHECK(ptr_type) << "The provided variable is not of pointer type";
- p->PrintIndent();
- p->stream << "allocate " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- p->stream << " * ";
- p->Print(op->extents[i]);
- }
- p->stream << "], storage_scope = " << ptr_type->storage_scope;
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << "\n";
- p->Print(op->body);
- });
-
// Const
// The constructor to create a IRNode with constant data
// depending on the type of ObjectRef, it will either
@@ -494,20 +380,6 @@ TVM_REGISTER_GLOBAL("tir.AllocateConst")
TVM_REGISTER_NODE_TYPE(AllocateConstNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<AllocateConstNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const AllocateConstNode*>(node.get());
- p->PrintIndent();
- p->stream << "constant " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- p->stream << " * ";
- p->Print(op->extents[i]);
- }
- p->stream << "]";
- p->stream << "\n";
- p->Print(op->body);
- });
-
// DeclBuffer
DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
ObjectPtr<DeclBufferNode> node = make_object<DeclBufferNode>();
@@ -523,14 +395,6 @@ TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt body
TVM_REGISTER_NODE_TYPE(DeclBufferNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const DeclBufferNode*>(node.get());
- p->PrintIndent();
- p->stream << "decl_buffer " << op->buffer << "\n";
- p->stream << op->body;
- });
-
// ProducerRealize
ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
Stmt body, String storage_scope, Span span) {
@@ -562,34 +426,6 @@ TVM_REGISTER_GLOBAL("tir.ProducerRealize")
TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<ProducerRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const ProducerRealizeNode*>(node.get());
- p->PrintIndent();
- p->stream << "producer_realize " << op->producer->GetNameHint() << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// Prefetch
Prefetch::Prefetch(Buffer buffer, Array<Range> bounds, Span span) {
data_ = make_object<PrefetchNode>(buffer, bounds, span);
@@ -602,22 +438,6 @@ TVM_REGISTER_GLOBAL("tir.Prefetch")
TVM_REGISTER_NODE_TYPE(PrefetchNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const PrefetchNode*>(node.get());
- p->PrintIndent();
- p->stream << "prefetch " << op->buffer << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- });
-
// SeqStmt
SeqStmt::SeqStmt(Array<Stmt> seq, Span span) {
auto node = make_object<SeqStmtNode>();
@@ -632,14 +452,6 @@ TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq, Span span)
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const SeqStmtNode*>(node.get());
- for (Stmt stmt : op->seq) {
- p->Print(stmt);
- }
- });
-
// IfThenElse
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) {
ICHECK(condition.defined());
@@ -660,37 +472,6 @@ TVM_REGISTER_GLOBAL("tir.IfThenElse")
return IfThenElse(condition, then_case, else_case, span);
});
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const IfThenElseNode*>(node.get());
- p->PrintIndent();
- while (true) {
- p->stream << "if (" << op->condition << ") {\n";
- p->indent += 2;
- p->Print(op->then_case);
- p->indent -= 2;
-
- if (!op->else_case) {
- break;
- }
-
- if (const IfThenElseNode* nested_if = op->else_case.as<IfThenElseNode>()) {
- p->PrintIndent();
- p->stream << "} else ";
- op = nested_if;
- } else {
- p->PrintIndent();
- p->stream << "} else {\n";
- p->indent += 2;
- p->Print(op->else_case);
- p->indent -= 2;
- break;
- }
- }
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// Evaluate
Evaluate::Evaluate(PrimExpr value, Span span) {
ICHECK(value.defined());
@@ -707,14 +488,6 @@ TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value, Span span)
TVM_REGISTER_NODE_TYPE(EvaluateNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const EvaluateNode*>(node.get());
- p->PrintIndent();
- p->Print(op->value);
- p->stream << "\n";
- });
-
// BufferStore
BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices, Span span) {
ICHECK_EQ(buffer->shape.size(), indices.size())
@@ -750,21 +523,6 @@ TVM_REGISTER_GLOBAL("tir.BufferStore")
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferStoreNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) p->stream << ", ";
- }
- p->stream << "]";
- p->stream << " = ";
- p->Print(op->value);
- p->stream << '\n';
- });
-
// BufferRealize
BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
Span span) {
@@ -777,34 +535,6 @@ TVM_REGISTER_GLOBAL("tir.BufferRealize")
TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferRealizeNode*>(node.get());
- p->PrintIndent();
- p->stream << "buffer_realize " << op->buffer->name << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- p->stream << "[";
- p->Print(op->bounds[i]->min);
- p->stream << ", ";
- p->Print(op->bounds[i]->extent);
- p->stream << "]";
- if (i < op->bounds.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
- if (!is_one(op->condition)) {
- p->stream << " if ";
- p->Print(op->condition);
- }
- p->stream << " {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// BufferRegion
BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
CHECK_EQ(buffer->shape.size(), region.size())
@@ -843,23 +573,6 @@ TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array<R
TVM_REGISTER_NODE_TYPE(BufferRegionNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BufferRegionNode*>(node.get());
- p->stream << op->buffer->name;
- p->stream << "[";
- for (size_t i = 0; i < op->region.size(); ++i) {
- const auto& range = op->region[i];
- p->Print(range->min);
- if (!is_one(range->extent)) {
- p->stream << ":";
- p->Print(range->min + range->extent);
- }
- if (i != op->region.size() - 1) p->stream << ", ";
- }
- p->stream << "]";
- });
-
// MatchBufferRegion
MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
const Buffer& source_buffer = source->buffer;
@@ -917,15 +630,6 @@ TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, Bu
TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
- p->PrintIndent();
- p->stream << op->buffer->name << " = match_buffer(";
- p->Print(op->source);
- p->stream << ")\n";
- });
-
// Block
Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
@@ -956,78 +660,6 @@ TVM_REGISTER_GLOBAL("tir.Block")
TVM_REGISTER_NODE_TYPE(BlockNode);
-void PrintBlockTitle(const BlockNode* op, ReprPrinter* p) {
- p->stream << "block " << op->name_hint << "(";
- for (size_t i = 0; i < op->iter_vars.size(); i++) {
- p->Print(op->iter_vars[i]);
- if (i < op->iter_vars.size() - 1) p->stream << ", ";
- }
- p->stream << ")";
-}
-
-void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
- // print read/write regions
- p->PrintIndent();
- p->stream << "reads(";
- p->Print(op->reads);
- p->stream << ")\n";
- p->PrintIndent();
- p->stream << "writes(";
- p->Print(op->writes);
- p->stream << ")\n";
- // Print alloc_buffers
- for (const auto& alloc_buf : op->alloc_buffers) {
- p->PrintIndent();
- p->stream << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "[";
- for (size_t i = 0; i < alloc_buf->shape.size(); ++i) {
- if (i > 0) p->stream << ", ";
- p->Print(alloc_buf->shape[i]);
- }
- p->stream << "])\n";
- }
- // Print match_buffer_regions
- for (const auto& match_buf : op->match_buffers) {
- p->Print(match_buf);
- }
- if (!op->annotations.empty()) {
- p->PrintIndent();
- p->stream << "annotations(" << op->annotations << ")\n";
- }
-}
-
-void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
- // Print init
- if (op->init.defined()) {
- p->PrintIndent();
- p->stream << "with init() {\n";
- p->indent += 2;
- p->Print(op->init.value());
- p->indent -= 2;
- p->PrintIndent();
- p->stream << "}\n";
- }
- // Print body
- p->Print(op->body);
-}
-
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BlockNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BlockNode*>(node.get());
- p->PrintIndent();
- PrintBlockTitle(op, p);
- p->stream << " {\n";
- p->indent += 2;
-
- // Print block elements (e.g. reads/writes, etc)
- PrintBlockSignature(op, p);
- // Print block init and body
- PrintBlockBody(op, p);
-
- p->indent -= 2;
- p->PrintIndent();
- p->stream << "}\n";
- });
-
// BlockRealize
BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
CHECK_EQ(block->iter_vars.size(), values.size())
@@ -1048,41 +680,6 @@ TVM_REGISTER_GLOBAL("tir.BlockRealize")
TVM_REGISTER_NODE_TYPE(BlockRealizeNode);
-TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
- .set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
- auto* op = static_cast<const BlockRealizeNode*>(node.get());
- auto* block_op = op->block.get();
- p->PrintIndent();
- PrintBlockTitle(block_op, p);
- p->stream << " {\n";
- p->indent += 2;
-
- // Print binding iter_values
- for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
- p->PrintIndent();
- p->stream << "bind(";
- p->Print(block_op->iter_vars[i]->var);
- p->stream << ", ";
- p->Print(op->iter_values[i]);
- p->stream << ")\n";
- }
- // Print predicate
- if (!is_one(op->predicate)) {
- p->PrintIndent();
- p->stream << "where(";
- p->Print(op->predicate);
- p->stream << ")\n";
- }
- // Print block elements (e.g. reads/writes, etc)
- PrintBlockSignature(block_op, p);
- // Print block init and body
- PrintBlockBody(block_op, p);
-
- p->indent -= 2;
- p->PrintIndent();
- p->stream << "}\n";
- });
-
PrimExpr TypeAnnotation(DataType dtype, Span span) {
static auto op = Op::Get("tir.type_annotation");
return tir::Call(dtype, op, {}, span);
diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc
index 447d85370c..5cf6f231dd 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -151,8 +151,8 @@ bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr,
// as we need a deterministic order
std::stringstream a_stream;
std::stringstream b_stream;
- a_stream << a.first;
- b_stream << b.first;
+ a_stream << LegacyTIRPrint(a.first);
+ b_stream << LegacyTIRPrint(b.first);
return (a_stream.str().compare(b_stream.str()) < 0);
}
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc
index c118d1db7d..c6b0b457c0 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -817,8 +817,8 @@ std::vector<std::pair<PrimExpr, size_t>> SyntacticToSemanticComputations(
[](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
std::stringstream a_stream;
std::stringstream b_stream;
- a_stream << a.first;
- b_stream << b.first;
+ a_stream << LegacyTIRPrint(a.first);
+ b_stream << LegacyTIRPrint(b.first);
return a_stream.str().compare(b_stream.str()) < 0;
});
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index f10d99eb1f..82de46616c 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -32,7 +32,7 @@ TEST(Expr, Basic) {
std::ostringstream os;
os << z;
ICHECK(zz.same_as(z));
- ICHECK(os.str() == "max(((x + 1) + 2), 100)");
+ ICHECK(os.str() == "T.max(x + 1 + 2, 100)");
}
TEST(Expr, VarTypeAnnotation) {
diff --git a/tests/python/driver/tvmc/test_shape_parser.py b/tests/python/driver/tvmc/test_shape_parser.py
index 1e3cde1292..b7b96ae4ef 100644
--- a/tests/python/driver/tvmc/test_shape_parser.py
+++ b/tests/python/driver/tvmc/test_shape_parser.py
@@ -18,7 +18,6 @@
import argparse
import pytest
-
from tvm.driver.tvmc.shape_parser import parse_shape_string
@@ -53,14 +52,14 @@ def test_negative_dimensions():
shape_string = "input:[-1,3,224,224]"
shape_dict = parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
- assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"
+ assert str(shape_dict) == "{'input': [T.Any(), 3, 224, 224]}"
def test_multiple_valid_gpu_inputs():
# Check that multiple valid gpu inputs are parsed correctly.
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
shape_dict = parse_shape_string(shape_string)
- expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
+ expected = "{'gpu_0/data_0': [1, T.Any(), 224, 224], 'gpu_1/data_1': [7, 7]}"
assert str(shape_dict) == expected
diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py
index ea5ea4920c..247b22eac4 100644
--- a/tests/python/relay/aot/test_c_device_api.py
+++ b/tests/python/relay/aot/test_c_device_api.py
@@ -21,12 +21,11 @@ from collections import OrderedDict
import numpy as np
import pytest
-
import tvm.testing
from tvm import relay
from tvm.ir.module import IRModule
-from tvm.testing.aot import AOTTestModel, generate_ref_data, compile_models
from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER
+from tvm.testing.aot import AOTTestModel, compile_models, generate_ref_data
@pytest.fixture(name="device_api_main_func")
@@ -40,10 +39,13 @@ def fixture_device_api_main_func():
# pylint: disable=import-outside-toplevel
import tensorflow as tf
import tflite.Model
-
- from tests.python.contrib.test_ethosu.infra import create_test_runner, generate_ref_data_tflite
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+ from tests.python.contrib.test_ethosu.infra import (
+ create_test_runner,
+ generate_ref_data_tflite,
+ )
+
# pylint: enable=import-outside-toplevel
tf.config.run_functions_eagerly(True)
@@ -236,11 +238,12 @@ def test_without_device_api_unpacked_api(non_device_api_main_func):
"""Test a graph without the Device API with the unpacked internal calls"""
main_func = non_device_api_main_func(interface_api="c", use_unpacked_api=True)
+ body = main_func.body.seq[1].seq[0].seq[0].value
assert (
- str(main_func.body)
- == "tir.tvm_check_return(0, -1, tir.call_extern("
+ repr(body)
+ == 'T.tvm_check_return(0, -1, T.call_extern("int32", '
+ '"tvmgen_default_fused_multiply",'
- + " x_buffer_var, y_buffer_var, output_buffer_var))\n"
+ + " x_buffer_var, y_buffer_var, output_buffer_var))"
)
@@ -249,12 +252,16 @@ def test_without_device_api_packed_api(non_device_api_main_func):
main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False)
- assert str(main_func.body) == (
- 'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", '
- "tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
- "tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
- "tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
- "tir.reinterpret((uint64)0))\n"
+ body = main_func.body.seq[1].seq[0].seq[0].value
+ assert repr(body) == (
+ 'T.call_cpacked("tvmgen_default_fused_multiply", '
+ "T.tvm_stack_make_array(x_buffer_var, T.tvm_stack_make_shape(10, 10), "
+ 'T.reinterpret("handle", T.uint64(0)), T.uint32(2), T.Cast("float32", 0), 0), '
+ "T.tvm_stack_make_array(y_buffer_var, T.tvm_stack_make_shape(1, 10), "
+ 'T.reinterpret("handle", T.uint64(0)), T.uint32(2), T.Cast("float32", 0), 0), '
+ "T.tvm_stack_make_array(output_buffer_var, T.tvm_stack_make_shape(10, 10), "
+ 'T.reinterpret("handle", T.uint64(0)), T.uint32(2), T.Cast("float32", 0), 0), '
+ 'T.reinterpret("handle", T.uint64(0)))'
)
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index b3db410156..2e7e23ead6 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -16,35 +16,34 @@
# under the License.
"""AOT with C Runtime Tests"""
-from collections import OrderedDict
-import re
import os
-import tarfile
import pathlib
+import re
+import tarfile
+from collections import OrderedDict
import numpy as np
import pytest
-
import tvm
-from tvm import relay, TVMError
+from tvm import TVMError, relay
from tvm.contrib import utils
+from tvm.ir.instrument import pass_instrument
from tvm.ir.module import IRModule
+from tvm.micro import export_model_library_format
+from tvm.micro import model_library_format as mlf
+from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER, parametrize_aot_options
+from tvm.micro.testing.utils import get_conv2d_relay_module
from tvm.relay import testing, transform
-from tvm.relay.testing import byoc
-from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.backend import Executor, Runtime
-from tvm.micro import model_library_format as mlf
-from tvm.micro import export_model_library_format
-from tvm.ir.instrument import pass_instrument
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+from tvm.relay.testing import byoc
from tvm.testing.aot import (
AOTTestModel,
- generate_ref_data,
compile_and_run,
compile_models,
create_relay_module_and_inputs_from_tflite_file,
+ generate_ref_data,
)
-from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER, parametrize_aot_options
-from tvm.micro.testing.utils import get_conv2d_relay_module
def test_error_c_interface_with_packed_api():
@@ -985,8 +984,8 @@ def test_workspace_calculation_cmsis_nn():
pytest.importorskip("tflite")
# pylint: disable=import-outside-toplevel
- from tvm.relay.op.contrib import cmsisnn
from tvm.contrib.download import download_testdata
+ from tvm.relay.op.contrib import cmsisnn
# pylint: enable=import-outside-toplevel
@@ -1040,11 +1039,11 @@ def test_aot_codegen_checks_returns():
main_func = main_ir_module["__tvm_main__"]
# Check operator call is wrapped properly
+ body = main_func.body[1].seq[0].seq[0].value
assert (
- str(main_func.body[1])
- == "tir.tvm_check_return(0, -1, tir.call_extern("
- + '"tvmgen_default_fused_add",'
- + " x_buffer_var, y_buffer_var, output_buffer_var))\n"
+ repr(body)
+ == 'T.tvm_check_return(0, -1, T.call_extern("int32", "tvmgen_default_fused_add",'
+ + " x_buffer_var, y_buffer_var, output_buffer_var))"
)
# TODO(Mousius) - Create a better place for C codegen tests
assert (
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index a426befd4b..d87f9ec69e 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -17,7 +17,6 @@
import itertools
import pytest
-
import tvm
from tvm.script.printer.doc import (
AssertDoc,
@@ -62,7 +61,7 @@ def format_script(s: str) -> str:
cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines())
if not cleaned_lines.endswith("\n"):
cleaned_lines += "\n"
- return cleaned_lines
+ return cleaned_lines.strip()
@pytest.mark.parametrize(
diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py
new file mode 100644
index 0000000000..fd3bb3788c
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_tir.py
@@ -0,0 +1,638 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+from contextlib import contextmanager
+
+from tvm import ir, tir
+from tvm.ir import Range
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import tir as T
+from tvm.script.printer import default
+
+
+@contextmanager
+def verbose_expr():
+ try:
+ default.verbose_expr(True)
+ yield
+ finally:
+ default.verbose_expr(False)
+
+
+def _assert_print(obj, expected):
+ with verbose_expr():
+ assert repr(obj).strip() == expected.strip()
+
+
+def test_prim_func():
+ a = tir.Var("a", "handle")
+ b = tir.Var("b", "handle")
+ func = tir.PrimFunc(
+ params=[a, b],
+ ret_type=None,
+ buffer_map={
+ a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"),
+ b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"),
+ },
+ body=tir.Evaluate(0),
+ )
+ _assert_print(
+ func,
+ expected="""
+@T.prim_func
+def main(a: T.handle, b: T.handle) -> None:
+ A = T.match_buffer(a, (128, 128))
+ B = T.match_buffer(b, (256, 256))
+ T.evaluate(0)""",
+ )
+
+
+def test_block_realize():
+ i = tir.Var("i", "int32")
+ j = tir.Var("j", "int32")
+ k = tir.Var("k", "int32")
+ with IRBuilder() as ib:
+ with T.block(name="block", no_realize=False):
+ vi = ib.name("vi", T.axis.spatial(128, i))
+ vj = ib.name("vj", T.axis.spatial(64, j))
+ vk = ib.name("vk", T.axis.reduce(32, k))
+ T.reads()
+ T.writes()
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+i = T.var("int32")
+j = T.var("int32")
+k = T.var("int32")
+with T.block("block"):
+ vi = T.axis.spatial(128, i)
+ vj = T.axis.spatial(64, j)
+ vk = T.axis.reduce(32, k)
+ T.reads()
+ T.writes()
+ T.evaluate(0)""",
+ )
+
+
+def test_block():
+ i = tir.Var("i", "int32")
+ j = tir.Var("j", "int32")
+ k = tir.Var("k", "int32")
+ with IRBuilder() as ib:
+ with T.block(name="block", no_realize=False):
+ vi = ib.name("vi", T.axis.spatial(128, i))
+ vj = ib.name("vj", T.axis.spatial(64, j))
+ vk = ib.name("vk", T.axis.reduce(32, k))
+ T.reads()
+ T.writes()
+ T.evaluate(0)
+ obj = ib.get().block
+ _assert_print(
+ obj,
+ """
+with T.block("block", no_realize=True):
+ vi = T.axis.spatial(128)
+ vj = T.axis.spatial(64)
+ vk = T.axis.reduce(32)
+ T.reads()
+ T.writes()
+ T.evaluate(0)""",
+ )
+
+
+def test_match_buffer_region():
+ src = tir.decl_buffer((128, 128), "float32", name="src")
+ tgt = tir.decl_buffer((64, 64), "float32", name="tgt")
+ obj = tir.MatchBufferRegion(
+ tgt,
+ tir.BufferRegion(
+ src,
+ [
+ Range(64, 128),
+ Range(64, 128),
+ ],
+ ),
+ )
+ _assert_print(
+ obj,
+ """
+src = T.buffer_decl((128, 128))
+tgt = T.match_buffer(src[64:128, 64:128], (64, 64))
+""",
+ )
+
+
+def test_buffer():
+ a = tir.decl_buffer((128, 128), "float16", name="A")
+ _assert_print(
+ a,
+ """A = T.buffer_decl((128, 128), "float16")
+A""",
+ )
+
+
+def test_buffer_region():
+ src = tir.decl_buffer((128, 128), "float32", name="src")
+ obj = tir.BufferRegion(
+ src,
+ [
+ Range(64, 128),
+ Range(64, 128),
+ ],
+ )
+ _assert_print(
+ obj,
+ """
+src = T.buffer_decl((128, 128))
+src[64:128, 64:128]
+""",
+ )
+
+
+def test_buffer_load():
+ a = tir.decl_buffer((128, 128), "float16", name="A")
+ obj = tir.BufferLoad(a, [128, 128])
+ _assert_print(
+ obj,
+ """
+A = T.buffer_decl((128, 128), "float16")
+A[128, 128]
+""",
+ )
+
+
+def test_buffer_store():
+ a = tir.decl_buffer((128, 128), "float16", name="A")
+ with IRBuilder() as ib:
+ T.buffer_store(a, a[128, 128] + 1, [128, 128])
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+A = T.buffer_decl((128, 128), "float16")
+A[128, 128] = A[128, 128] + T.float16(1)
+""",
+ )
+
+
+def test_for():
+ with IRBuilder() as ib:
+ with T.grid(128, 128, 128) as (i, j, k):
+ ib.name_many(["i", "j", "k"], [i, j, k])
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+for i, j, k in T.grid(128, 128, 128):
+ T.evaluate(0)
+""",
+ )
+
+
+def test_let_stmt():
+ with IRBuilder() as ib:
+ with T.let(T.var("float32"), T.float32(10)):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+with T.let(v, T.float32(10)):
+ T.evaluate(0)
+""",
+ )
+
+
+def test_attr_stmt():
+ with IRBuilder() as ib:
+ with T.attr("pragma", "unroll", 1):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+with T.attr("pragma", "unroll", 1):
+ T.evaluate(0)
+""",
+ )
+
+
+def test_assert_stmt():
+ with IRBuilder() as ib:
+ with T.Assert(1, "assertion"):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+with T.Assert(1, "assertion"):
+ T.evaluate(0)
+""",
+ )
+
+
+def test_while():
+ with IRBuilder() as ib:
+ x = T.var("int32")
+ with T.While(x < 10):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+v = T.var("int32")
+while v < 10:
+ T.evaluate(0)
+""",
+ )
+
+
+def test_allocate():
+ with IRBuilder() as ib:
+ with T.allocate([128, 128], "float32"):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+with T.allocate([128, 128], "float32", "global") as v:
+ T.evaluate(0)
+""",
+ )
+
+
+def test_decl_buffer():
+ with IRBuilder() as ib:
+ with T.decl_buffer((10, 10), data=T.ptr("float32")):
+ T.evaluate(0)
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+with T.decl_buffer((10, 10)) as buffer:
+ T.evaluate(0)
+""",
+ )
+
+
+def test_prefetch():
+ a = tir.decl_buffer((128, 128), "float16", name="A")
+ with IRBuilder() as ib:
+ T.prefetch(a, [Range(0, 64), Range(0, 64)])
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+A = T.buffer_decl((128, 128), "float16")
+T.prefetch(A, [T.Range(0, 64), T.Range(0, 64)])
+""",
+ )
+
+
+def test_seq_stmt():
+ with IRBuilder() as ib:
+ with T.serial(10):
+ T.evaluate(0)
+ T.evaluate(1)
+ obj = ib.get().body
+ _assert_print(
+ obj,
+ """
+T.evaluate(0)
+T.evaluate(1)
+""",
+ )
+
+
+def test_if_then_else():
+ with IRBuilder() as ib:
+ with T.If(T.var("int32") == 1):
+ with T.Then():
+ T.evaluate(0)
+
+ obj = ib.get()
+ _assert_print(
+ obj,
+ """
+v = T.var("int32")
+if v == 1:
+ T.evaluate(0)
+""",
+ )
+
+
+def test_evaluate():
+    """A bare Evaluate statement prints as ``T.evaluate(...)``."""
+    with IRBuilder() as builder:
+        T.evaluate(0)
+    expected = """
+T.evaluate(0)
+"""
+    _assert_print(builder.get(), expected)
+
+
+def test_buffer_realize():
+    """A BufferRealize prints after the buffer's ``T.buffer_decl``.
+
+    The ``True`` condition passed at construction is omitted from the expected text.
+    """
+    with IRBuilder() as builder:
+        buffer_a = tir.decl_buffer((128, 128), "float32", name="A")
+        region = buffer_a[0:128, 0:128]
+        with T.realize(region, "test_storage_scope", True):
+            T.evaluate(0)
+    realize = builder.get()
+    _assert_print(
+        realize,
+        """
+A = T.buffer_decl((128, 128))
+with T.realize(A[0:128, 0:128], "test_storage_scope"):
+    T.evaluate(0)
+""",
+    )
+
+
+def test_var():
+    """A free Var prints its ``T.var`` declaration followed by its name."""
+    var_a = tir.Var("a", "float32")
+    expected = """
+a = T.var("float32")
+a"""
+    _assert_print(var_a, expected)
+
+
+def test_size_var():
+    """A SizeVar is expected to print exactly like a plain Var."""
+    size_var = tir.SizeVar("a", "float32")
+    expected = """
+a = T.var("float32")
+a"""
+    _assert_print(size_var, expected)
+
+
+def test_iter_var():
+    """An IterVar prints as ``T.iter_var`` with its range, iteration type and empty thread tag."""
+    iter_var = tir.IterVar((0, 8), "a", iter_type=tir.IterVar.DataPar)
+    expected = """
+a = T.var("int32")
+T.iter_var(a, T.Range(0, 8), "DataPar", "")
+"""
+    _assert_print(iter_var, expected)
+
+
+def test_string_imm():
+    """A StringImm prints as a double-quoted Python string literal."""
+    _assert_print(tir.StringImm("str"), '"str"')
+
+
+def test_cast():
+    """A Cast prints as ``T.Cast(target_dtype, value)`` after declaring its operand."""
+    source = tir.Var("a", "float32")
+    expected = """
+a = T.var("float32")
+T.Cast("float64", a)
+"""
+    _assert_print(tir.Cast("float64", source), expected)
+
+
+def test_binary_arith():
+    """Binary arithmetic/comparison nodes print as infix operators or ``T.<name>`` calls.
+
+    Operators whose spelling is alphabetic (e.g. ``truncmod``) are expected in call form;
+    all others are expected as infix expressions.
+    """
+    lhs = tir.Var("a", "float32")
+    rhs = tir.Var("b", "float32")
+    call_form = """
+a = T.var("float32")
+b = T.var("float32")
+T.{}(a, b)"""
+    infix_form = """
+a = T.var("float32")
+b = T.var("float32")
+a {} b"""
+    cases = [
+        (tir.Add, "+"),
+        (tir.Sub, "-"),
+        (tir.Mul, "*"),
+        (tir.Div, "/"),
+        (tir.Mod, "truncmod"),
+        (tir.FloorDiv, "//"),
+        (tir.FloorMod, "%"),
+        (tir.LT, "<"),
+        (tir.LE, "<="),
+        (tir.EQ, "=="),
+        (tir.NE, "!="),
+        (tir.GT, ">"),
+        (tir.GE, ">="),
+    ]
+    for node_type, spelling in cases:
+        template = call_form if spelling.isalpha() else infix_form
+        _assert_print(node_type(lhs, rhs), template.format(spelling))
+
+
+def test_logical():
+    """And / Or / Not print using Python's ``and`` / ``or`` / ``not`` keywords."""
+    a = T.var("bool", "a")
+    b = T.var("bool", "b")
+    cases = [
+        (
+            tir.And(a, b),
+            """
+a = T.var("bool")
+b = T.var("bool")
+a and b
+""",
+        ),
+        (
+            tir.Or(a, b),
+            """
+a = T.var("bool")
+b = T.var("bool")
+a or b
+""",
+        ),
+        (
+            tir.Not(a),
+            """
+a = T.var("bool")
+not a
+""",
+        ),
+    ]
+    for node, expected in cases:
+        _assert_print(node, expected)
+
+
+def test_select():
+    """A Select with constant operands prints as ``T.Select(cond, true_value, false_value)``."""
+    expected = """T.Select(True, 0, 2)
+"""
+    _assert_print(tir.Select(True, 0, 2), expected)
+
+
+def test_ramp():
+    """A Ramp prints as ``T.Ramp(base, stride, lanes)`` after declaring its base variable."""
+    base = tir.Var("a", "int32")
+    expected = """
+a = T.var("int32")
+T.Ramp(a, 1, 32)
+"""
+    _assert_print(tir.Ramp(base, 1, 32), expected)
+
+
+def test_broadcast():
+    """A Broadcast of a constant prints as ``T.Broadcast(value, lanes)``."""
+    expected = """
+T.Broadcast(0, 4)
+"""
+    _assert_print(tir.Broadcast(0, 4), expected)
+
+
+def test_let_expr():
+    """A Let expression prints as ``T.let(var, value, body)`` after declaring the bound var."""
+    bound = tir.Var("x", "int32")
+    let_expr = tir.Let(bound, 1, bound + 1)
+    expected = """
+x = T.var("int32")
+T.let(x, 1, x + 1)
+"""
+    _assert_print(let_expr, expected)
+
+
+def test_call():
+    """An intrinsic call prints through the ``T.`` namespace; ``1.0`` is expected as ``1``."""
+    expected = """
+T.atan(T.float32(1))
+"""
+    _assert_print(tir.atan(T.float32(1.0)), expected)
+
+
+def test_comm_reducer():
+    """A CommReducer prints as ``T.comm_reducer(lambda ..., identity)``, identity positional."""
+    reducer = T.comm_reducer(lambda x, y: x + y, identity=[T.float32(0)])
+    expected = """
+T.comm_reducer(lambda x, y: x + y, [T.float32(0)])
+"""
+    _assert_print(reducer, expected)
+
+
+def test_any():
+    """The Any placeholder prints as ``T.Any()``."""
+    expected = """
+T.Any()
+"""
+    _assert_print(tir.Any(), expected)
+
+
+def test_int_imm():
+    """A non-default-width IntImm prints with its dtype constructor, e.g. ``T.int16(1)``."""
+    expected = """
+T.int16(1)
+"""
+    _assert_print(T.int16(1), expected)
+
+
+def test_float_imm():
+    """A FloatImm prints with its dtype constructor, e.g. ``T.float16(1)``."""
+    expected = """
+T.float16(1)
+"""
+    _assert_print(T.float16(1), expected)
+
+
+def test_range():
+    """A Range prints as ``T.Range(begin, end)``."""
+    expected = """
+T.Range(0, 10)
+"""
+    _assert_print(Range(0, 10), expected)
+
+
+def test_prim_type():
+    """A PrimType prints as the bare dtype attribute ``T.<dtype>``."""
+    prim = ir.PrimType("float32")
+    _assert_print(prim, "T.float32")
+
+
+def test_pointer_type():
+    """A PointerType prints as ``T.Ptr(element_dtype, storage_scope)``."""
+    element = ir.PrimType("int32")
+    _assert_print(ir.PointerType(element, "global"), 'T.Ptr("int32", "global")')
+
+
+def test_tuple_type():
+    """A TupleType prints as ``T.Tuple(...)`` over its printed field types."""
+    fields = [ir.PrimType("float32"), ir.PrimType("int32")]
+    _assert_print(ir.TupleType(fields), "T.Tuple(T.float32, T.int32)")
+
+
+if __name__ == "__main__":
+    # Run this file through pytest via TVM's standard test entry point. The previous
+    # hand-written list of test calls had to be updated for every new test, and any
+    # test missing from it was silently not run when the file was executed directly.
+    import tvm.testing
+
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py
index a7e7dffb8b..467aad2df5 100644
--- a/tests/python/unittest/test_tvmscript_printer_underlining.py
+++ b/tests/python/unittest/test_tvmscript_printer_underlining.py
@@ -18,14 +18,13 @@
from typing import Optional
import pytest
-
from tvm.runtime import ObjectPath
from tvm.script.printer.doc import (
- StmtBlockDoc,
ExprStmtDoc,
IdDoc,
OperationDoc,
OperationKind,
+ StmtBlockDoc,
)
from tvm.script.printer.doc_printer import to_python_script
@@ -59,7 +58,7 @@ def format_script(s: str) -> str:
cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines())
if not cleaned_lines.endswith("\n"):
cleaned_lines += "\n"
- return cleaned_lines
+ return cleaned_lines.strip()
def test_underline_basic():
@@ -290,8 +289,10 @@ def test_print_two_context_lines(to_underline, expected_text):
def test_underline_and_print_line_numbers():
doc = StmtBlockDoc([ExprStmtDoc(make_id_doc(f"line{i + 1}")) for i in range(12)])
result = to_python_script(doc, print_line_numbers=True, path_to_underline=make_path("line6"))
- assert result == format_script(
- """
+ assert (
+ result.strip()
+ == format_script(
+ """
1 line1
2 line2
3 line3
@@ -306,6 +307,7 @@ def test_underline_and_print_line_numbers():
11 line11
12 line12
"""
+ ).strip()
)
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 38d58179c4..b1135c0eb0 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -729,7 +729,7 @@ def InjectConv2DTransposeSkip():
def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"):
- is_init = ".init" in str(op)
+ is_init = "_init" in str(op)
tvm.tir.stmt_functor.post_order_visit(op, _find_basics)
if is_init: