You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/01/08 18:57:40 UTC
[tvm] branch main updated: [TVMScript] Refactor IRDocsifier (#13593)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a99f0c1545 [TVMScript] Refactor IRDocsifier (#13593)
a99f0c1545 is described below
commit a99f0c15458653896c0bbe00ebf91d144c37aff2
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sun Jan 8 10:57:33 2023 -0800
[TVMScript] Refactor IRDocsifier (#13593)
This PR refactors the TVMScript printer and includes the following
changes:
- Consolidate the logic of VarTable into IRDocsifier
- Replace TracedObject with separate Object and ObjectPath arguments to
reduce syntactic noise
- Restructure the folder layout so that related logic is grouped consistently
Some tests were removed because the APIs they exercised no longer exist
after the consolidation.
---
include/tvm/script/printer.h | 56 ---
include/tvm/script/printer/doc.h | 70 +--
include/tvm/script/printer/doc_printer.h | 48 --
include/tvm/script/printer/frame.h | 140 ------
include/tvm/script/printer/ir_docsifier.h | 308 +++++++------
include/tvm/script/printer/ir_docsifier_functor.h | 163 +++++++
include/tvm/script/printer/printer.h | 86 ++++
include/tvm/script/printer/traced_object.h | 484 ---------------------
include/tvm/script/printer/traced_object_functor.h | 175 --------
include/tvm/script/printer/var_table.h | 155 -------
include/tvm/support/with.h | 29 --
include/tvm/tir/op.h | 3 +
include/tvm/tir/op_attr_types.h | 5 +
python/tvm/script/__init__.py | 5 +-
python/tvm/script/ir_builder/tir/ir.py | 121 +++---
python/tvm/script/printer/__init__.py | 7 +-
python/tvm/script/printer/entry.py | 71 ---
python/tvm/script/printer/frame.py | 81 ----
python/tvm/script/printer/ir_docsifier.py | 245 -----------
python/tvm/script/printer/printer.py | 64 +++
python/tvm/script/printer/var_table.py | 118 -----
src/script/printer/doc.cc | 10 +-
.../printer/{ => doc_printer}/base_doc_printer.cc | 0
.../printer/{ => doc_printer}/base_doc_printer.h | 7 +-
.../{ => doc_printer}/python_doc_printer.cc | 11 +-
src/script/printer/frame.cc | 50 ---
src/script/printer/ir/ir.cc | 74 ++++
src/script/printer/ir/misc.cc | 77 ++++
src/script/{printer.cc => printer/ir/utils.h} | 49 ++-
src/script/printer/ir_docsifier.cc | 184 +++++---
src/script/{ => printer}/printer.cc | 34 +-
src/script/printer/tir/block.cc | 150 +++++++
src/script/printer/tir/buffer.cc | 193 ++++++++
src/script/printer/tir/expr.cc | 299 +++++++++++++
src/script/printer/tir/for_loop.cc | 122 ++++++
src/script/printer/tir/function.cc | 86 ++++
src/script/printer/tir/ir.cc | 97 +++++
src/script/printer/tir/stmt.cc | 374 ++++++++++++++++
src/script/printer/tir/utils.h | 176 ++++++++
src/script/printer/traced_object_functor.cc | 85 ----
src/script/printer/utils.h | 93 ----
src/script/printer/var_table.cc | 109 -----
src/tir/ir/stmt.cc | 2 +-
src/tir/op/builtin.cc | 16 +-
src/tir/op/op.cc | 92 ++--
src/tir/op/runtime.cc | 41 --
tests/cpp/traced_object_test.cc | 268 ------------
tests/cpp/tvmscript_printer_irdocsifier_test.cc | 117 -----
...tvmscript_printer_traced_object_functor_test.cc | 188 --------
tests/cpp/tvmscript_printer_var_table_test.cc | 158 -------
.../unittest/test_tvmscript_printer_entry_point.py | 30 --
.../unittest/test_tvmscript_printer_frame.py | 60 ---
.../unittest/test_tvmscript_printer_irdocsifier.py | 123 ------
.../unittest/test_tvmscript_printer_var_table.py | 89 ----
54 files changed, 2466 insertions(+), 3432 deletions(-)
diff --git a/include/tvm/script/printer.h b/include/tvm/script/printer.h
deleted file mode 100644
index b0fc54108c..0000000000
--- a/include/tvm/script/printer.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_H_
-#define TVM_SCRIPT_PRINTER_H_
-
-#include <tvm/node/node.h>
-#include <tvm/node/object_path.h>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-/*!
- * \brief Print IR graph as TVMScript code
- *
- * \param root_node The root node to print.
- * \param ir_name The dispatch token of the target IR, e.g., "tir", "relax".
- * \param ir_prefix The symbol name for TVMScript IR namespaces. For example, {"tir": "T"}.
- * \param indent_spaces Number of spaces used for indentation
- * \param print_line_numbers Whether to print line numbers
- * \param num_context_lines Number of context lines to print around the underlined text
- * \param path_to_underline Object path to be underlined
- *
- * \return the TVMScript code as string.
- */
-String Script( //
- const ObjectRef& root_node, //
- String ir_name, //
- Map<String, String> ir_prefix, //
- int indent_spaces = 4, //
- bool print_line_numbers = false, //
- int num_context_lines = -1, //
- Optional<ObjectPath> path_to_underline = NullOpt //
-);
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_H_
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 1ee7fd6a7f..094d3fdf51 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -22,12 +22,13 @@
#include <tvm/ir/expr.h>
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
-#include <tvm/script/printer/traced_object.h>
namespace tvm {
namespace script {
namespace printer {
+class Doc;
+
/*!
* \brief The base class of all Doc.
*
@@ -88,15 +89,6 @@ class ExprDocNode : public DocNode {
*/
ExprDoc Attr(String attr) const;
- /*!
- * \brief Create a doc representing attribute access on the current ExprDoc
- * \param attr The attribute to access.
- *
- * The ObjectPath of attr will be pushed to the source_path of the returned
- * doc.
- */
- ExprDoc Attr(TracedObject<String> attr) const;
-
/*!
* \brief Create a doc representing index access on the current ExprDoc
* \param indices The indices to access.
@@ -259,83 +251,33 @@ class LiteralDoc : public ExprDoc {
* \brief Create a LiteralDoc to represent None/null/empty value.
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
-
- /*!
- * \brief Create a LiteralDoc to represent None/null/empty value.
- * \param object_path The source path of the returned Doc.
- */
- static LiteralDoc None(ObjectPath object_path) {
- return LiteralDoc(ObjectRef(nullptr), object_path);
- }
-
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
*/
- static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
-
- /*!
- * \brief Create a LiteralDoc to represent integer.
- * \param v The integer value.
- *
- * The ObjectPath of v will be pushed to the source_path of the returned doc.
- */
- static LiteralDoc Int(const TracedObject<IntImm>& v) { return LiteralDoc(v.Get(), v.GetPath()); }
-
- /*!
- * \brief Create a LiteralDoc to represent integer.
- * \param v The integer value.
- *
- * The ObjectPath of v will be pushed to the source_path of the returned doc.
- */
- static LiteralDoc Int(const TracedBasicValue<int>& v) {
- return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath());
- }
+ static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
-
- /*!
- * \brief Create a LiteralDoc to represent boolean.
- * \param v The boolean value.
- *
- * The ObjectPath of v will be pushed to the source_path of the returned doc.
- */
- static LiteralDoc Boolean(const TracedBasicValue<bool>& v) {
- return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath());
- }
-
/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
-
- /*!
- * \brief Create a LiteralDoc to represent float.
- * \param v The float value.
- *
- * The ObjectPath of v will be pushed to the source_path of the returned doc.
- */
- static LiteralDoc Float(const TracedObject<FloatImm>& v) {
- return LiteralDoc(v.Get(), v.GetPath());
- }
-
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
-
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
- *
- * The ObjectPath of v will be pushed to the source_path of the returned doc.
*/
- static LiteralDoc Str(const TracedObject<String>& v) { return LiteralDoc(v.Get(), v.GetPath()); }
+ static LiteralDoc DataType(const DLDataType& v) {
+ return LiteralDoc::Str(runtime::DLDataType2String(v));
+ }
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
};
diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h
deleted file mode 100644
index 04a67a9b82..0000000000
--- a/include/tvm/script/printer/doc_printer.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
-#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
-
-#include <tvm/script/printer/doc.h>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-/*!
- * \brief Convert Doc into Python script.
- *
- * This function unpacks the DocPrinterOptions into function arguments
- * to be FFI friendly.
- *
- * \param doc Doc to be converted
- * \param indent_spaces Number of spaces used for indentation
- * \param print_line_numbers Whether to print line numbers
- * \param num_context_lines Number of context lines to print around the underlined text
- * \param path_to_underline Object path to be underlined
- */
-String DocToPythonScript(Doc doc, int indent_spaces = 4, bool print_line_numbers = false,
- int num_context_lines = -1,
- Optional<ObjectPath> path_to_underline = NullOpt);
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
diff --git a/include/tvm/script/printer/frame.h b/include/tvm/script/printer/frame.h
deleted file mode 100644
index 407ad16007..0000000000
--- a/include/tvm/script/printer/frame.h
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_FRAME_H_
-#define TVM_SCRIPT_PRINTER_FRAME_H_
-
-#include <tvm/node/node.h>
-#include <tvm/script/printer/doc.h>
-
-#include <utility>
-#include <vector>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-/*!
- * Frame is the core data structure for semantic information
- * when printing IR graph into TVMScript code.
- */
-class FrameNode : public Object {
- public:
- void VisitAttrs(tvm::AttrVisitor* v) {}
-
- virtual ~FrameNode() = default;
-
- /*!
- * \brief Add a callback function to be called when this frame exits.
- * \param cb The callback function. It should have signature void().
- */
- template <typename TCallback>
- void AddExitCallback(TCallback&& cb) {
- callbacks_.emplace_back(std::forward<TCallback>(cb));
- }
-
- /*!
- * \brief Method that's called when Frame enters the scope.
- */
- virtual void EnterWithScope() {}
-
- /*!
- * \brief Method that's called when Frame exits the scope.
- */
- virtual void ExitWithScope() {
- for (const std::function<void()>& callback : callbacks_) {
- callback();
- }
- callbacks_.clear();
- }
-
- static constexpr const char* _type_key = "script.printer.Frame";
- TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);
-
- private:
- std::vector<std::function<void()>> callbacks_;
-};
-
-/*!
- * \brief Reference type of FrameNode
- */
-class Frame : public ObjectRef {
- protected:
- Frame() = default;
-
- public:
- virtual ~Frame() = default;
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
-};
-
-/*!
- * \brief MetadataFrame contains information like contant parameter array.
- */
-class MetadataFrameNode : public FrameNode {
- public:
- Array<ObjectRef> metadata;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- FrameNode::VisitAttrs(v);
- v->Visit("metadata", &metadata);
- }
-
- static constexpr const char* _type_key = "script.printer.MetadataFrame";
- TVM_DECLARE_FINAL_OBJECT_INFO(MetadataFrameNode, FrameNode);
-};
-
-/*!
- * \brief Reference type of MetadataFrameNode
- */
-class MetadataFrame : public Frame {
- public:
- MetadataFrame();
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(MetadataFrame, Frame, MetadataFrameNode);
-};
-
-/*!
- * \brief VarDefFrame contains information about the free variables that needs to be defined
- * at the beginning of the printed snippet.
- */
-class VarDefFrameNode : public FrameNode {
- public:
- Array<StmtDoc> stmts;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- FrameNode::VisitAttrs(v);
- v->Visit("stmts", &stmts);
- }
-
- static constexpr const char* _type_key = "script.printer.VarDefFrame";
- TVM_DECLARE_FINAL_OBJECT_INFO(VarDefFrameNode, FrameNode);
-};
-
-/*!
- * \brief Reference type of VarDefFrameNode
- */
-class VarDefFrame : public Frame {
- public:
- VarDefFrame();
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarDefFrame, Frame, VarDefFrameNode);
-};
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_FRAME_H_
diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h
index 8945bd6e7a..e97ddc0234 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -19,45 +19,117 @@
#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_H_
+#include <tvm/ir/module.h>
#include <tvm/node/node.h>
-#include <tvm/runtime/logging.h>
#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/frame.h>
-#include <tvm/script/printer/traced_object.h>
-#include <tvm/script/printer/traced_object_functor.h>
-#include <tvm/script/printer/var_table.h>
-#include <tvm/support/with.h>
+#include <tvm/script/printer/ir_docsifier_functor.h>
+
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
namespace tvm {
namespace script {
namespace printer {
-using WithCtx = With<ContextManager>;
+//////////////////////// Frame ////////////////////////
+
+class IRDocsifier;
+class IRDocsifierNode;
+
+/*!
+ * Frame is the core data structure for semantic information
+ * when printing IR graph into TVMScript code.
+ */
+class FrameNode : public Object {
+ public:
+ /*! The docs generated in the frame */
+ Array<StmtDoc> stmts;
+ /*! The corresponding IRDocsifier */
+ IRDocsifierNode* d;
+ /*! The callbacks that are going to be invoked when the frame exits */
+ std::vector<std::function<void()>> callbacks;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("stmts", &stmts);
+ // `d` is not visited
+ // `callbacks` is not visited
+ }
+
+ static constexpr const char* _type_key = "script.printer.Frame";
+ TVM_DECLARE_BASE_OBJECT_INFO(FrameNode, Object);
+
+ public:
+ virtual ~FrameNode() = default;
+
+ /*!
+ * \brief Add a callback function to be called when this frame exits.
+ * \param cb The callback function. It should have signature void().
+ */
+ template <typename TCallback>
+ void AddExitCallback(TCallback&& cb) {
+ callbacks.emplace_back(std::forward<TCallback>(cb));
+ }
+ /*!
+ * \brief Add a dispatch token to the docsifier, and a callback that pops the token when this
+ * frame exits.
+ * \param d The docsifier.
+ * \param token The token to be added.
+ */
+ void AddDispatchToken(const IRDocsifier& d, const String& token);
+ /*!
+ * \brief Method that's called when Frame enters the scope.
+ */
+ virtual void EnterWithScope();
+ /*!
+ * \brief Method that's called when Frame exits the scope.
+ */
+ virtual void ExitWithScope();
+};
+
+/*!
+ * \brief Reference type of FrameNode
+ */
+class Frame : public ObjectRef {
+ protected:
+ Frame() = default;
+
+ public:
+ virtual ~Frame() = default;
+
+ /*! \brief Method that's called when Frame enters the scope. */
+ void EnterWithScope() { get()->EnterWithScope(); }
+
+ /*! \brief Method that's called when Frame exits the scope. */
+ void ExitWithScope() { get()->ExitWithScope(); }
+
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(Frame, ObjectRef, FrameNode);
+};
+
+//////////////////////// IRDocsifier ////////////////////////
/*!
* \brief IRDocsifier is the top-level interface in the IR->Doc process.
*
* It provides methods to convert IR node object to Doc, operate on Frame
* objects and change dispatch tokens.
- *
- * Example usage:
- * \code
- * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- * .set_dispatch([](TracedObject<tir::Var> obj, IRDocsifier p) { return IdDoc("x"); });
- *
- * TracedObject<tir::Var> var = ...;
- * IRDocsifier p;
- * p->AsDoc(var); // returns an IdDoc("x")
- * \endcode
- *
*/
class IRDocsifierNode : public Object {
public:
+ /*! \brief A function that creates the doc for a variable */
+ using DocCreator = std::function<ExprDoc()>;
+ /*! \brief Information about a variable, including its optional name and its doc creator */
+ struct VariableInfo {
+ /*! \brief The creator */
+ DocCreator creator;
+ /*! \brief The name of the variable */
+ Optional<String> name;
+ };
/*!
- * \brief The var table to use during the printing process.
- * \sa VarTableNode
+ * \brief This map connects IR dispatch token to the name of identifier.
*/
- VarTable vars;
+ Map<String, String> ir_prefix;
/*!
* \brief The stack of frames.
* \sa FrameNode
@@ -70,16 +142,23 @@ class IRDocsifierNode : public Object {
* when converting IR node object to Doc.
*/
Array<String> dispatch_tokens;
- /*!
- * \brief This map connects IR dipatch token to the name of identifier.
- */
- Map<String, String> ir_prefix;
+ /*! \brief The IRModule to be docsifier is handling */
+ Optional<IRModule> mod;
+ /*! \brief Mapping from a var to its info */
+ std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
+ /*! \brief The variable names used already */
+ std::unordered_set<String> defined_names;
+ /*! \brief Common prefixes of variable usages */
+ std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("vars", &vars);
+ v->Visit("ir_prefix", &ir_prefix);
v->Visit("frames", &frames);
v->Visit("dispatch_tokens", &dispatch_tokens);
- v->Visit("ir_prefix", &ir_prefix);
+ v->Visit("mod", &mod);
+ // `obj2info` is not visited
+ // `defined_names` is not visited
+ // `common_prefix` is not visited
}
static constexpr const char* _type_key = "script.printer.IRDocsifier";
@@ -87,79 +166,68 @@ class IRDocsifierNode : public Object {
public:
/*!
- * \brief Transform the input object into TDoc.
- * \param obj The object to be transformed.
+ * \brief Define variable by name.
+ * \param obj The variable object.
+ * \param frame The frame that this variable is defined in.
+ * \param name_hint The hint for variable name.
*
- * \return The Doc object.
+ * \return The id doc for this variable.
+ *
+ * This function will rename the variable to avoid name conflict with other variables
+ * in the table.
*/
- template <class TDoc>
- TDoc AsDoc(const TracedObject<ObjectRef>& obj) const {
- auto result = Downcast<TDoc>(AsDocImpl(obj));
- result->source_paths.push_back(obj.GetPath());
- return result;
- }
+ IdDoc Define(const ObjectRef& obj, const Frame& frame, const String& name_hint);
/*!
- * \brief Helper method to transform object into ExprDoc.
- * \param obj The object to be transformed.
+ * \brief Define variable by doc factory.
+ * \param obj The variable object.
+ * \param frame The frame that this variable is defined in.
+ * \param doc_factory The function to return an ExprDoc object for this variable.
*
- * \return The ExprDoc object.
+ * This function is a special form of `Define`. Variable is mapped to ExprDoc rather
+ * than IdDoc. It's useful when a variable is implicitly defined without a name, like
+ * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
+ *
+ * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
+ * return a new Doc object every time it's called, as the returned doc will have
+ * different `source_path`. Currently there isn't a good way to deep copy a TVMObject
+ * so VarTable needs to call a factory function to get a freshly-constructed Doc object
+ * every time GetVarDoc is called.
*/
- ExprDoc AsExprDoc(const TracedObject<ObjectRef>& obj) { return AsDoc<ExprDoc>(obj); }
+ void Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory);
/*!
- * \brief Push a new dispatch token into the stack
- * \details The top dispatch token decides which dispatch table to use
- * when printing Object. This method returns a RAII guard which
- * pops the token when going out of the scope.
- *
- * \param token The dispatch token to push.
+ * \brief Get the doc for variable.
+ * \param obj The variable object.
*
- * \return A RAII guard to pop dispatch token when going out of scope.
+ * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
*/
- WithCtx WithDispatchToken(const String& token) {
- this->dispatch_tokens.push_back(token);
- return WithCtx(nullptr, [this]() { this->dispatch_tokens.pop_back(); });
- }
+ Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
/*!
- * \brief Push a new frame the stack
- * \details Frame contains the contextual information that's needed during printing,
- * for example, variables in the scope. This method returns a RAII guard which
- * pops the frame and call the cleanup method of frame when going out of the scope.
- *
- * \param frame The frame to push.
+ * \brief Check if a variable exists in the table.
+ * \param obj The variable object.
*
- * \return A RAII guard to pop frame and call the exit method of frame
- * when going out of scope
+ * \return a boolean for whether variable exists.
*/
- WithCtx WithFrame(const Frame& frame) {
- frame->EnterWithScope();
- this->frames.push_back(frame);
- return WithCtx(nullptr, [this, pushed_frame = frame]() {
- Frame last_frame = this->frames.back();
- ICHECK_EQ(last_frame, pushed_frame);
- this->frames.pop_back();
- last_frame->ExitWithScope();
- });
- }
-
+ bool IsVarDefined(const ObjectRef& obj) const;
+ /*! \brief Remove the variable defined */
+ void RemoveVar(const ObjectRef& obj);
/*!
- * \brief Get the top frame with type FrameType
- * \tparam FrameType The type of frame to get.
+ * \brief Set the common prefix information of variable usage.
+ * \param root The root of the AST.
+ * \param is_var A function that returns true if the given object is considered a variable.
*/
- template <typename FrameType>
- Optional<FrameType> GetFrame() const {
- for (auto it = frames.rbegin(); it != frames.rend(); ++it) {
- if (const auto* f = (*it).as<typename FrameType::ContainerType>()) {
- return GetRef<FrameType>(f);
- }
- }
- return NullOpt;
- }
-
- private:
- Doc AsDocImpl(const TracedObject<ObjectRef>& obj) const;
+ void SetCommonPrefix(const ObjectRef& root, runtime::TypedPackedFunc<bool(ObjectRef)> is_var);
+ /*!
+ * \brief Transform the input object into TDoc.
+ * \param obj The object to be transformed.
+ * \param path The path to this object.
+ *
+ * \return The Doc object.
+ */
+ template <class TDoc = Doc>
+ inline TDoc AsDoc(const ObjectRef& obj, const ObjectPath& path) const;
};
/*!
@@ -167,61 +235,49 @@ class IRDocsifierNode : public Object {
*/
class IRDocsifier : public ObjectRef {
public:
+ using FType = IRDocsifierFunctor<printer::Doc, ObjectPath, IRDocsifier>;
/*!
* \brief Create a IRDocsifier.
* \param ir_prefix The ir_prefix to use for this IRDocsifier.
*/
explicit IRDocsifier(Map<String, String> ir_prefix);
-
- using FType = TracedObjectFunctor<printer::Doc, IRDocsifier>;
- /*!
- * \brief The registration table for IRDocsifier.
- */
+ /*! \brief The registration table for IRDocsifier. */
TVM_DLL static FType& vtable();
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRDocsifier, ObjectRef, IRDocsifierNode);
};
-/*!
- * \brief A wrapper object to provide injection point for printer of each IR.
- *
- * For any IR node to be transformed by IRDocsifier, it will be wrapped by RootNodeContainer
- * and be dispatched to the corresponding function first. This provides an injection point for
- * each IR's printer implemention to add specialized logic, for example, pushing a special
- * Frame to the IRDocsifier before doing any IR->Doc transformation.
- *
- * \code
- * TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- * .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
- * const ObjectRef& root_node = obj.Get()->root_node;
- * // For example, relax printer can create a Frame specialized to Relax here
- * RelaxGeneralFrame frame;
- * auto ctx = p->WithFrame(frame);
- * // More specialized logic for your IR.
- * return p->AsDoc<Doc>(MakeTraced(root_node));
- * });
- * \endcode
- */
-class RootNodeContainerNode : public Object {
- public:
- /*! \brief The root node to print. */
- ObjectRef root_node;
+//////////////////////// Implementation ////////////////////////
- void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("root_node", &root_node); }
+inline void FrameNode::EnterWithScope() {
+ if (d != nullptr) {
+ d->frames.push_back(GetRef<Frame>(this));
+ }
+}
- static constexpr const char* _type_key = "script.printer.RootNodeContainer";
- TVM_DECLARE_FINAL_OBJECT_INFO(RootNodeContainerNode, Object);
-};
+inline void FrameNode::ExitWithScope() {
+ for (const std::function<void()>& callback : callbacks) {
+ callback();
+ }
+ callbacks.clear();
+ if (d != nullptr) {
+ d->frames.pop_back();
+ }
+}
-class RootNodeContainer : public ObjectRef {
- public:
- /*!
- * \brief Constructor of RootNodeContainer.
- * \param root_node The root node to print.
- * */
- explicit RootNodeContainer(ObjectRef root_node);
- TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(RootNodeContainer, ObjectRef, RootNodeContainerNode);
-};
+template <class TDoc>
+inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const {
+ if (!obj.defined()) {
+ return Downcast<TDoc>(LiteralDoc::None());
+ }
+ return Downcast<TDoc>(
+ IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this)));
+}
+
+inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) {
+ d->dispatch_tokens.push_back(token);
+ this->AddExitCallback([doc = d.get()]() { doc->dispatch_tokens.pop_back(); });
+}
} // namespace printer
} // namespace script
diff --git a/include/tvm/script/printer/ir_docsifier_functor.h b/include/tvm/script/printer/ir_docsifier_functor.h
new file mode 100644
index 0000000000..d04d8c4d02
--- /dev/null
+++ b/include/tvm/script/printer/ir_docsifier_functor.h
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
+#define TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
+
+#include <tvm/node/node.h>
+#include <tvm/runtime/logging.h>
+#include <tvm/runtime/packed_func.h>
+
+#include <functional>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Dynamic dispatch functor keyed on an object's runtime type index and a dispatch token.
+ *
+ * Dispatch functions are stored per dispatch token, in a table indexed by object type index.
+ * A lookup tries the requested token first, then falls back to the default token (the empty
+ * string).
+ */
+template <typename R, typename... Args>
+class IRDocsifierFunctor {
+ private:
+  using TSelf = IRDocsifierFunctor<R, Args...>;
+
+  /*! \brief True iff `TCallable` converts to `std::function<R(TObjectRef, Args...)>`. */
+  template <class TObjectRef, class TCallable>
+  using IsDispatchFunction =
+      typename std::is_convertible<TCallable, std::function<R(TObjectRef, Args...)>>;
+
+ public:
+  /*!
+   * \brief Call the dispatch function.
+   * \param token The dispatch token.
+   * \param obj The object. An undefined object is dispatched on type index 0.
+   * \param args Other args.
+   *
+   * \return The return value of the dispatch function
+   *
+   * If the TObjectRef isn't registered with the token, it will try to find
+   * dispatch function for TObjectRef with the default dispatch token (empty string).
+   * If neither lookup succeeds, a fatal error is raised.
+   */
+  template <class TObjectRef>
+  R operator()(const String& token, TObjectRef obj, Args... args) const {
+    uint32_t type_index = obj.defined() ? obj->type_index() : 0;
+    const runtime::PackedFunc* pf = nullptr;
+    if ((pf = LookupDispatchTable(token, type_index)) != nullptr) {
+      return (*pf)(obj, args...);
+    }
+    if ((pf = LookupDispatchTable("", type_index)) != nullptr) {
+      return (*pf)(obj, args...);
+    }
+    // `obj` may be undefined here (type index 0), so never dereference it unguarded.
+    std::string obj_type_key = obj.defined() ? obj->GetTypeKey() : "None";
+    ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
+                  << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"
+                  << ". ObjectType: " << obj_type_key << ". Object: " << obj;
+    throw;  // Unreachable: ICHECK(false) throws. Ensures no path falls off a non-void function.
+  }
+
+  /*!
+   * \brief Set the dispatch function
+   * \param token The dispatch token.
+   * \param type_index The TVM object type index for this dispatch function.
+   * \param f The dispatch function.
+   * \return A reference to this functor, for chaining.
+   *
+   * This takes a type-erased packed function as input. It should be used
+   * through FFI boundary, for example, registering dispatch function from Python.
+   * Registering twice for the same (token, type_index) pair is a fatal error.
+   */
+  TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
+    std::vector<runtime::PackedFunc>* table = &dispatch_table_[token];
+    if (table->size() <= type_index) {
+      table->resize(type_index + 1, nullptr);
+    }
+    runtime::PackedFunc& slot = (*table)[type_index];
+    ICHECK(slot == nullptr) << "Dispatch for type is already registered: "
+                            << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token
+                            << ")";
+    slot = std::move(f);
+    return *this;
+  }
+
+  /*!
+   * \brief Set the dispatch function
+   * \param token The dispatch token.
+   * \param f The dispatch function, callable as `R(TObjectRef, Args...)`.
+   * \return A reference to this functor, for chaining.
+   */
+  template <typename TObjectRef, typename TCallable,
+            typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
+  TSelf& set_dispatch(String token, TCallable f) {
+    return set_dispatch(token, TObjectRef::ContainerType::RuntimeTypeIndex(),
+                        runtime::TypedPackedFunc<R(TObjectRef, Args...)>(f));
+  }
+
+  /*!
+   * \brief Remove dispatch function
+   * \param token The dispatch token.
+   * \param type_index The TVM object type index for the dispatch function to be removed.
+   *
+   * This is useful when dispatch function comes from other language's runtime, and
+   * those function should be removed before that language runtime shuts down.
+   * Removing an entry that was never registered is a no-op.
+   */
+  void remove_dispatch(String token, uint32_t type_index) {
+    // Use find() rather than operator[] so an unknown token does not create an empty table.
+    auto it = dispatch_table_.find(token);
+    if (it == dispatch_table_.end() || it->second.size() <= type_index) {
+      return;
+    }
+    it->second[type_index] = nullptr;
+  }
+
+ private:
+  /*!
+   * \brief Look up the dispatch table for the given token and type_index.
+   * \param token The dispatch token.
+   * \param type_index The TVM object type index.
+   * \return Returns the functor if the lookup succeeds, nullptr otherwise.
+   */
+  const runtime::PackedFunc* LookupDispatchTable(const String& token, uint32_t type_index) const {
+    auto it = dispatch_table_.find(token);
+    if (it == dispatch_table_.end()) {
+      return nullptr;
+    }
+    const std::vector<runtime::PackedFunc>& tab = it->second;
+    if (type_index >= tab.size()) {
+      return nullptr;
+    }
+    const runtime::PackedFunc* f = &tab[type_index];
+    return f->defined() ? f : nullptr;
+  }
+
+  /*! \brief Maps each dispatch token to its dispatch functions, indexed by object type index. */
+  using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
+  /*! \brief The dispatch table. */
+  DispatchTable dispatch_table_;
+};
+
+}  // namespace printer
+}  // namespace script
+}  // namespace tvm
+#endif  // TVM_SCRIPT_PRINTER_IR_DOCSIFIER_FUNCTOR_H_
diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h
new file mode 100644
index 0000000000..31abd7d9ec
--- /dev/null
+++ b/include/tvm/script/printer/printer.h
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_
+#define TVM_SCRIPT_PRINTER_PRINTER_H_
+
+#include <tvm/node/node.h>
+#include <tvm/script/printer/ir_docsifier.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*! \brief Default values in the TVMScript printer */
+struct Default {
+ /*! \brief Default data type of TIR buffer */
+ DataType buffer_dtype = DataType::Float(32);
+ /*! \brief Default data type of integer literals */
+ DataType int_dtype = DataType::Int(32);
+ /*!
+ * \brief Default data type of float literals. It is Void because no float type is treated as
+ * the default: floating point values are always printed with an explicit type, i.e. wrapped
+ * in T.float32/T.float64.
+ */
+ DataType float_dtype = DataType::Void();
+ /*! \brief Returns a singleton of the configuration */
+ static Default* Instance();
+ static DataType& BufferDType() { return Instance()->buffer_dtype; }  // mutable reference into the singleton
+ static DataType& IntDType() { return Instance()->int_dtype; }  // mutable reference into the singleton
+ static DataType& FloatDType() { return Instance()->float_dtype; }  // mutable reference into the singleton
+};
+
+/*!
+ * \brief The entry method for TVMScript printing
+ * \param obj The object to be printed
+ * \param ir_prefix Map from IR name to the prefix printed for its nodes (default: "ir" -> "I", "tir" -> "T")
+ * \param indent_spaces Number of spaces used for indentation
+ * \param print_line_numbers Whether to print line numbers
+ * \param num_context_lines Number of context lines to print around the underlined text (NOTE(review): default -1 presumably means no limit)
+ * \param path_to_underline Object path to be underlined
+ * \return The TVMScript text format
+ */
+String Script(ObjectRef obj, //
+ Map<String, String> ir_prefix = {{"ir", "I"}, {"tir", "T"}}, //
+ int indent_spaces = 4, //
+ bool print_line_numbers = false, //
+ int num_context_lines = -1, //
+ Optional<ObjectPath> path_to_underline = NullOpt);
+
+/*!
+ * \brief Convert Doc into Python script and return the script text.
+ * \param doc Doc to be converted
+ * \param indent_spaces Number of spaces used for indentation
+ * \param print_line_numbers Whether to print line numbers
+ * \param num_context_lines Number of context lines to print around the underlined text
+ * \param path_to_underline Object path to be underlined
+ */
+String DocToPythonScript(Doc doc, //
+ int indent_spaces = 4, //
+ bool print_line_numbers = false, //
+ int num_context_lines = -1, //
+ Optional<ObjectPath> path_to_underline = NullOpt);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_PRINTER_H_
diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h
deleted file mode 100644
index cb63c31cd4..0000000000
--- a/include/tvm/script/printer/traced_object.h
+++ /dev/null
@@ -1,484 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/script/printer/traced_object.h
- * Wrappers around TVM objects that also store an ObjectPath from some "root" object
- * to the wrapper object.
- */
-
-#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
-#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
-
-#include <tvm/node/object_path.h>
-#include <tvm/node/reflection.h>
-#include <tvm/runtime/object.h>
-
-#include <string>
-#include <utility>
-
-namespace tvm {
-
-template <typename RefT>
-class TracedObject;
-template <typename K, typename V>
-class TracedMap;
-template <typename T>
-class TracedArray;
-template <typename T>
-class TracedOptional;
-template <typename T>
-class TracedBasicValue;
-
-namespace detail {
-
-template <typename T, bool IsObject = std::is_base_of<ObjectRef, T>::value>
-struct TracedObjectWrapperSelector;
-
-template <typename T>
-struct TracedObjectWrapperSelector<T, false> {
- using Type = TracedBasicValue<T>;
-};
-
-template <typename T>
-struct TracedObjectWrapperSelector<T, true> {
- using Type = TracedObject<T>;
-};
-
-template <typename K, typename V>
-struct TracedObjectWrapperSelector<Map<K, V>, true> {
- using Type = TracedMap<K, V>;
-};
-
-template <typename T>
-struct TracedObjectWrapperSelector<Array<T>, true> {
- using Type = TracedArray<T>;
-};
-
-template <typename T>
-struct TracedObjectWrapperSelector<Optional<T>, true> {
- using Type = TracedOptional<T>;
-};
-
-} // namespace detail
-
-/*!
- * \brief Traced wrapper for regular (non-container) TVM objects.
- */
-template <typename RefT>
-class TracedObject {
- using ObjectType = typename RefT::ContainerType;
-
- public:
- using ObjectRefType = RefT;
-
- // Don't use this direcly. For convenience, call MakeTraced() instead.
- explicit TracedObject(const RefT& object_ref, ObjectPath path)
- : ref_(object_ref), path_(std::move(path)) {}
-
- // Implicit conversion from a derived reference class
- template <typename DerivedRef>
- TracedObject(const TracedObject<DerivedRef>& derived)
- : ref_(derived.Get()), path_(derived.GetPath()) {}
-
- /*!
- * \brief Get a traced wrapper for an attribute of the wrapped object.
- */
- template <typename T, typename BaseType>
- typename detail::TracedObjectWrapperSelector<T>::Type GetAttr(T BaseType::*member_ptr) const {
- using WrapperType = typename detail::TracedObjectWrapperSelector<T>::Type;
- const ObjectType* node = static_cast<const ObjectType*>(ref_.get());
- const T& attr = node->*member_ptr;
- Optional<String> attr_key = ICHECK_NOTNULL(GetAttrKeyByAddress(node, &attr));
- return WrapperType(attr, path_->Attr(attr_key));
- }
-
- /*!
- * \brief Access the wrapped object.
- */
- const RefT& Get() const { return ref_; }
-
- /*!
- * \brief Check if the reference to the wrapped object can be converted to `RefU`.
- */
- template <typename RefU>
- bool IsInstance() const {
- return ref_->template IsInstance<typename RefU::ContainerType>();
- }
-
- /*!
- * \brief Same as Get().defined().
- */
- bool defined() const { return ref_.defined(); }
-
- /*!
- * \brief Convert the wrapped reference type to a subtype.
- *
- * Throws an exception if IsInstance<RefU>() is false.
- */
- template <typename RefU>
- TracedObject<RefU> Downcast() const {
- return TracedObject<RefU>(tvm::runtime::Downcast<RefU>(ref_), path_);
- }
-
- /*!
- * \brief Convert the wrapped reference type to a subtype.
- *
- * Returns an empty optional if IsInstance<RefU>() is false.
- */
- template <typename RefU>
- TracedOptional<RefU> TryDowncast() const {
- if (ref_->template IsInstance<typename RefU::ContainerType>()) {
- return Downcast<RefU>();
- } else {
- return TracedOptional<RefU>(NullOpt, path_);
- }
- }
-
- /*!
- * \brief Get the path of the wrapped object.
- */
- const ObjectPath& GetPath() const { return path_; }
-
- private:
- RefT ref_;
- ObjectPath path_;
-};
-
-/*!
- * \brief Iterator class for TracedMap<K, V>
- */
-template <typename K, typename V>
-class TracedMapIterator {
- public:
- using WrappedV = typename detail::TracedObjectWrapperSelector<V>::Type;
- using MapIter = typename Map<K, V>::iterator;
-
- using iterator_category = std::bidirectional_iterator_tag;
- using difference_type = ptrdiff_t;
- using value_type = const std::pair<K, WrappedV>;
- using pointer = value_type*;
- using reference = value_type;
-
- explicit TracedMapIterator(MapIter iter, ObjectPath map_path)
- : iter_(iter), map_path_(std::move(map_path)) {}
-
- bool operator==(const TracedMapIterator& other) const { return iter_ == other.iter_; }
-
- bool operator!=(const TracedMapIterator& other) const { return iter_ != other.iter_; }
-
- pointer operator->() const = delete;
-
- reference operator*() const {
- auto kv = *iter_;
- return std::make_pair(kv.first, WrappedV(kv.second, map_path_->MapValue(kv.first)));
- }
-
- TracedMapIterator& operator++() {
- ++iter_;
- return *this;
- }
-
- TracedMapIterator operator++(int) {
- TracedMapIterator copy = *this;
- ++(*this);
- return copy;
- }
-
- private:
- MapIter iter_;
- ObjectPath map_path_;
-};
-
-/*!
- * \brief Traced wrapper for Map objects.
- */
-template <typename K, typename V>
-class TracedMap {
- public:
- using WrappedV = typename detail::TracedObjectWrapperSelector<V>::Type;
-
- using iterator = TracedMapIterator<K, V>;
-
- // Don't use this direcly. For convenience, call MakeTraced() instead.
- explicit TracedMap(Map<K, V> map, ObjectPath path)
- : map_(std::move(map)), path_(std::move(path)) {}
-
- /*!
- * \brief Get a value by its key, wrapped in a traced wrapper.
- */
- WrappedV at(const K& key) const {
- auto it = map_.find(key);
- ICHECK(it != map_.end()) << "No such key in Map";
- auto kv = *it;
- return WrappedV(kv.second, path_->MapValue(kv.first));
- }
-
- /*!
- * \brief Access the wrapped map object.
- */
- const Map<K, V>& Get() const { return map_; }
-
- /*!
- * \brief Get the path of the wrapped object.
- */
- const ObjectPath& GetPath() const { return path_; }
-
- /*!
- * \brief Get an iterator to the first item of the map.
- */
- iterator begin() const { return iterator(map_.begin(), path_); }
-
- /*!
- * \brief Get an iterator to the end of the map.
- */
- iterator end() const { return iterator(map_.end(), path_); }
-
- /*!
- * \brief Returns true iff the wrapped map is empty.
- */
- bool empty() const { return map_.empty(); }
-
- private:
- Map<K, V> map_;
- ObjectPath path_;
-};
-
-/*!
- * \brief Iterator class for TracedArray<T>
- */
-template <typename T>
-class TracedArrayIterator {
- public:
- using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
-
- using difference_type = ptrdiff_t;
- using value_type = WrappedT;
- using pointer = WrappedT*;
- using reference = WrappedT&;
- using iterator_category = std::random_access_iterator_tag;
-
- explicit TracedArrayIterator(Array<T> array, size_t index, ObjectPath array_path)
- : array_(array), index_(index), array_path_(array_path) {}
-
- TracedArrayIterator& operator++() {
- ++index_;
- return *this;
- }
- TracedArrayIterator& operator--() {
- --index_;
- return *this;
- }
- TracedArrayIterator operator++(int) {
- TracedArrayIterator copy = *this;
- ++index_;
- return copy;
- }
- TracedArrayIterator operator--(int) {
- TracedArrayIterator copy = *this;
- --index_;
- return copy;
- }
-
- TracedArrayIterator operator+(difference_type offset) const {
- return TracedArrayIterator(array_, index_ + offset, array_path_);
- }
-
- TracedArrayIterator operator-(difference_type offset) const {
- return TracedArrayIterator(array_, index_ - offset, array_path_);
- }
-
- difference_type operator-(const TracedArrayIterator& rhs) const { return index_ - rhs.index_; }
-
- bool operator==(TracedArrayIterator other) const {
- return array_.get() == other.array_.get() && index_ == other.index_;
- }
- bool operator!=(TracedArrayIterator other) const { return !(*this == other); }
- value_type operator*() const { return WrappedT(array_[index_], array_path_->ArrayIndex(index_)); }
-
- private:
- Array<T> array_;
- size_t index_;
- ObjectPath array_path_;
-};
-
-/*!
- * \brief Traced wrapper for Array objects.
- */
-template <typename T>
-class TracedArray {
- public:
- using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
-
- using iterator = TracedArrayIterator<T>;
-
- // Don't use this direcly. For convenience, call MakeTraced() instead.
- explicit TracedArray(Array<T> array, ObjectPath path)
- : array_(std::move(array)), path_(std::move(path)) {}
-
- /*!
- * \brief Access the wrapped array object.
- */
- const Array<T>& Get() const { return array_; }
-
- /*!
- * \brief Get the path of the wrapped array object.
- */
- const ObjectPath& GetPath() const { return path_; }
-
- /*!
- * \brief Get an element by index, wrapped in a traced wrapper.
- */
- WrappedT operator[](size_t index) const {
- return WrappedT(array_[index], path_->ArrayIndex(index));
- }
-
- /*!
- * \brief Get an iterator to the first array element.
- *
- * The iterator's dereference operator will automatically wrap each element in a traced wrapper.
- */
- iterator begin() const { return iterator(array_, 0, path_); }
-
- /*!
- * \brief Get an iterator to the end of the array.
- *
- * The iterator's dereference operator will automatically wrap each element in a traced wrapper.
- */
- iterator end() const { return iterator(array_, array_.size(), path_); }
-
- /*!
- * \brief Returns true iff the wrapped array is empty.
- */
- bool empty() const { return array_.empty(); }
-
- /*!
- * \brief Get the size of the wrapped array.
- */
- size_t size() const { return array_.size(); }
-
- private:
- Array<T> array_;
- ObjectPath path_;
-};
-
-/*!
- * \brief Traced wrapper for Optional objects.
- */
-template <typename T>
-class TracedOptional {
- public:
- using WrappedT = typename detail::TracedObjectWrapperSelector<T>::Type;
-
- /*!
- * \brief Implicit conversion from the corresponding non-optional traced wrapper.
- */
- TracedOptional(const WrappedT& value) // NOLINT(runtime/explicit)
- : optional_(value.Get().defined() ? value.Get() : Optional<T>(NullOpt)),
- path_(value.GetPath()) {}
-
- // Don't use this direcly. For convenience, call MakeTraced() instead.
- explicit TracedOptional(Optional<T> optional, ObjectPath path)
- : optional_(std::move(optional)), path_(std::move(path)) {}
-
- /*!
- * \brief Access the wrapped optional object.
- */
- const Optional<T>& Get() const { return optional_; }
-
- /*!
- * \brief Get the path of the wrapped optional object.
- */
- const ObjectPath& GetPath() const { return path_; }
-
- /*!
- * \brief Returns true iff the object is present.
- */
- bool defined() const { return optional_.defined(); }
-
- /*!
- * \brief Returns a non-optional traced wrapper, throws if defined() is false.
- */
- WrappedT value() const { return WrappedT(optional_.value(), path_); }
-
- /*!
- * \brief Same as defined().
- */
- explicit operator bool() const { return optional_.defined(); }
-
- private:
- Optional<T> optional_;
- ObjectPath path_;
-};
-
-/*!
- * \brief Traced wrapper for basic values (i.e. non-TVM objects)
- */
-template <typename T>
-class TracedBasicValue {
- public:
- explicit TracedBasicValue(const T& value, ObjectPath path)
- : value_(value), path_(std::move(path)) {}
-
- /*!
- * \brief Access the wrapped value.
- */
- const T& Get() const { return value_; }
-
- /*!
- * \brief Get the path of the wrapped value.
- */
- const ObjectPath& GetPath() const { return path_; }
-
- /*!
- * \brief Transform the wrapped value without changing its path.
- */
- template <typename F>
- typename detail::TracedObjectWrapperSelector<typename std::invoke_result<F, const T&>::type>::Type
- ApplyFunc(F&& f) const {
- return MakeTraced(f(value_), path_);
- }
-
- private:
- T value_;
- ObjectPath path_;
-};
-
-/*!
- * \brief Wrap the given root object in an appropriate traced wrapper class.
- */
-template <typename RefT>
-typename detail::TracedObjectWrapperSelector<RefT>::Type MakeTraced(const RefT& object) {
- using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
- return WrappedT(object, ObjectPath::Root());
-}
-
-/*!
- * \brief Wrap the given object with the given path in an appropriate traced wrapper class.
- */
-template <typename RefT>
-typename detail::TracedObjectWrapperSelector<RefT>::Type MakeTraced(const RefT& object,
- ObjectPath path) {
- using WrappedT = typename detail::TracedObjectWrapperSelector<RefT>::Type;
- return WrappedT(object, std::move(path));
-}
-
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_H_
diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h
deleted file mode 100644
index 8f72d139a5..0000000000
--- a/include/tvm/script/printer/traced_object_functor.h
+++ /dev/null
@@ -1,175 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
-#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
-
-#include <tvm/node/node.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/script/printer/traced_object.h>
-
-#include <string>
-#include <type_traits>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-/*
- * This type alias and the following free functions are created to reduce the binary bloat
- * from template and also hide implementation details from this header
- */
-using DispatchTable = std::unordered_map<std::string, std::vector<runtime::PackedFunc>>;
-
-/*!
- * \brief Get function from dispatch table.
- * \param dispatch_table The dispatch table.
- * \param token The dispatch token.
- * \param type_index The type index of the Object type to be dispatched.
- *
- * \return The dispatch function.
- */
-const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
- const String& token, uint32_t type_index);
-
-/*!
- * \brief Set function in dispatch table.
- * \param dispatch_table The dispatch table.
- * \param token The dispatch token.
- * \param type_index The type index of the Object type to be dispatched.
- * \param f The dispatch function.
- */
-void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
- runtime::PackedFunc f);
-
-/*!
- * \brief Remove function from dispatch table.
- * \param dispatch_table The dispatch table.
- * \param token The dispatch token.
- * \param type_index The TVM object type index for the dispatch function to be removed.
- */
-void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token,
- uint32_t type_index);
-
-constexpr const char* kDefaultDispatchToken = "";
-
-/*!
- * \brief Dynamic dispatch functor based on TracedObject.
- *
- * This functor dispatches based on the type of object ref inside the input TracedObject,
- * and the input dispatch token.
- */
-template <typename R, typename... Args>
-class TracedObjectFunctor {
- private:
- using TSelf = TracedObjectFunctor<R, Args...>;
-
- template <class TObjectRef, class TCallable>
- using IsDispatchFunction =
- typename std::is_convertible<TCallable, std::function<R(TracedObject<TObjectRef>, Args...)>>;
-
- public:
- /*!
- * \brief Call the dispatch function.
- * \param token The dispatch token.
- * \param traced_object The traced object.
- * \param args Other args.
- *
- * \return The return value of the dispatch function
- *
- * If the TObjectRef isn't registered with the token, it will try to find
- * dispatch function for TObjectRef with kDefaultDispatchToken.
- */
- template <class TObjectRef>
- R operator()(const String& token, TracedObject<TObjectRef> traced_object, Args... args) const {
- const runtime::PackedFunc& dispatch_function =
- GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index());
- return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...);
- }
-
- /*!
- * \brief Set the dispatch function
- * \param token The dispatch token.
- * \param type_index The TVM object type index for this dispatch function.
- * \param f The dispatch function.
- *
- * This takes a type-erased packed function as input. It should be used
- * through FFI boundary, for example, registering dispatch function from Python.
- */
- TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) {
- SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f));
- return *this;
- }
-
- /*!
- * \brief Set the dispatch function
- * \param token The dispatch token.
- * \param f The dispatch function.
- *
- * The diaptch function should have signature `R(TracedObject<TObjectRef>, Args...)`.
- */
- template <typename TObjectRef, typename TCallable,
- typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
- TSelf& set_dispatch(String token, TCallable f) {
- return set_dispatch(
- token, //
- TObjectRef::ContainerType::RuntimeTypeIndex(), //
- runtime::TypedPackedFunc<R(TObjectRef, ObjectPath, Args...)>(
- [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R {
- return f(MakeTraced(object, path), args...);
- }));
- }
- /*!
- * \brief Set the default dispatch function
- * \param f The dispatch function.
- *
- * Default dispatch function will be used if there is no function registered
- * with the requested dispatch token.
- *
- * Default dispatch function has an empty string as dispatch token.
- */
- template <typename TObjectRef, typename TCallable,
- typename = std::enable_if_t<IsDispatchFunction<TObjectRef, TCallable>::value>>
- TSelf& set_dispatch(TCallable&& f) {
- return set_dispatch<TObjectRef>(kDefaultDispatchToken, std::forward<TCallable>(f));
- }
-
- /*!
- * \brief Remove dispatch function
- * \param token The dispatch token.
- * \param type_index The TVM object type index for the dispatch function to be removed.
- *
- * This is useful when dispatch function comes from other language's runtime, and
- * those function should be removed before that language runtime shuts down.
- */
- void remove_dispatch(String token, uint32_t type_index) {
- RemoveDispatchFunction(&dispatch_table_, token, type_index);
- }
-
- private:
- DispatchTable dispatch_table_;
-};
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_
diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h
deleted file mode 100644
index 2cd9335213..0000000000
--- a/include/tvm/script/printer/var_table.h
+++ /dev/null
@@ -1,155 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_VAR_TABLE_H_
-#define TVM_SCRIPT_PRINTER_VAR_TABLE_H_
-
-#include <tvm/node/node.h>
-#include <tvm/node/object_path.h>
-#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/frame.h>
-#include <tvm/script/printer/traced_object.h>
-
-#include <unordered_map>
-#include <unordered_set>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-/*!
- * \brief Variable Table manages mapping from variable object to ExprDoc during
- * the process of printing TVMScript.
- *
- * The value type of this map is ExprDoc rather than IdDoc or String. It's
- * because variables can be implicitly defined. For example in TIR buffer (tir::Buffer),
- * `buf->data` is a variable, while its representation in TVMScript should be an
- * expression `x.data`, where `x` is the variable for the buffer itself.
- */
-class VarTableNode : public Object {
- public:
- void VisitAttrs(AttrVisitor*) {}
-
- /*!
- * \brief Define variable by name.
- * \param obj The variable object.
- * \param name_hint The hint for variable name.
- * \param object_path The object_path for the returned ExprDoc.
- * \param frame The frame that this variable is defined in.
- *
- * \return The id doc for this variable.
- *
- * This function will rename the variable to avoid name conflict with other variables
- * in the table.
- */
- IdDoc Define(const ObjectRef& obj, const String& name_hint, const ObjectPath& object_path,
- const Frame& frame);
-
- /*!
- * \brief Define variable by name.
- * \param obj The variable object.
- * \param name_hint The hint for variable name.
- * \param frame The frame that this variable is defined in.
- *
- * \return The id doc for this variable.
- *
- * This is a shortcut version of `Define` which accepts a traced string.
- */
- IdDoc Define(const ObjectRef& obj, const TracedObject<String>& name_hint, const Frame& frame) {
- return Define(obj, name_hint.Get(), name_hint.GetPath(), frame);
- }
-
- using DocFactory = std::function<ExprDoc()>;
-
- /*!
- * \brief Define variable by doc factory.
- * \param obj The variable object.
- * \param doc_factory The function to return an ExprDoc object for this variable.
- * \param frame The frame that this variable is defined in.
- *
- * This function is a special form of `Define`. Variable is mapped to ExprDoc rather
- * than IdDoc. It's useful when a variable is implicitly defined without a name, like
- * the buf->data in TIR, which should be mapped to `AttrDoc(IdDoc("<buffer_name>"), "data")`.
- *
- * This function takes a DocFactory instead of Doc. It's because GetVarDoc needs to
- * return a new Doc object every time it's called, as the returned doc will have
- * different `soruce_path`. Currently there isn't a good way to deep copy a TVMObject
- * so VarTable needs to call a factory function to get a freshly-constructed Doc object
- * every time GetVarDoc is called.
- */
- void DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame);
-
- /*!
- * \brief Get the doc for variable.
- * \param obj The variable object.
- * \param object_path The object path for the variable.
- *
- * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
- */
- Optional<ExprDoc> GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const;
-
- /*!
- * \brief Get the doc for variable.
- * \param obj The traced variable object.
- *
- * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt.
- */
- template <typename TObjectRef>
- Optional<ExprDoc> GetVarDoc(const TracedObject<TObjectRef> obj) const {
- return GetVarDoc(obj.Get(), obj.GetPath());
- }
-
- /*!
- * \brief Check if a variable exists in the table.
- * \param obj The variable object.
- *
- * \return a boolean for whether variable exists.
- */
- bool IsVarDefined(const ObjectRef& obj) const;
-
- static constexpr const char* _type_key = "script.printer.VarTable";
- TVM_DECLARE_FINAL_OBJECT_INFO(VarTableNode, Object);
-
- private:
- void RemoveVar(const ObjectRef& obj);
-
- struct VariableInfo {
- DocFactory doc_factory;
- Optional<String> name;
- };
- std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
- std::unordered_set<String> defined_names;
-};
-
-/*!
- * \brief Reference type of VarTableNode.
- */
-class VarTable : public ObjectRef {
- public:
- /*!
- * \brief Create an empty VarTable.
- */
- VarTable();
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(VarTable, ObjectRef, VarTableNode);
-};
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_VAR_TABLE_H_
diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h
index 5959affafd..8333adc9e6 100644
--- a/include/tvm/support/with.h
+++ b/include/tvm/support/with.h
@@ -92,34 +92,5 @@ class With {
ContextType ctx_;
};
-/*!
- * \brief A context type that delegates EnterWithScope and ExitWithScope
- * to user-provided functions.
- */
-class ContextManager {
- public:
- /*!
- * \brief Constructor of ContextManager.
- * \param f_enter The function to call when entering scope. If it's nullptr, do nothing when
- * entering.
- * \param f_exit The function to call when exiting scope. If it's nullptr, do nothing
- * when exiting.
- */
- template <class FEnter, class FExit>
- explicit ContextManager(FEnter f_enter, FExit f_exit) : f_enter_(f_enter), f_exit_(f_exit) {}
-
- private:
- void EnterWithScope() {
- if (f_enter_) f_enter_();
- }
- void ExitWithScope() {
- if (f_exit_) f_exit_();
- }
- std::function<void()> f_enter_;
- std::function<void()> f_exit_;
- template <typename>
- friend class With;
-};
-
} // namespace tvm
#endif // TVM_SUPPORT_WITH_H_
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 9b48b0cceb..21bc7e7a50 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -40,6 +40,9 @@
namespace tvm {
+#define TVM_TIR_REGISTER_OP(OpName) \
+ TVM_REGISTER_OP("tir." OpName).set_attr<TScriptPrinterName>("TScriptPrinterName", OpName)
+
// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
//
diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h
index 2dc174f7d2..858d89c2d5 100644
--- a/include/tvm/tir/op_attr_types.h
+++ b/include/tvm/tir/op_attr_types.h
@@ -56,6 +56,11 @@ using FLowerIntrinsic = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
*/
using FLegalize = runtime::TypedPackedFunc<PrimExpr(PrimExpr)>;
+/*!
+ * \brief The operator's name in TVMScript printer
+ */
+using TScriptPrinterName = String;
+
/*!
* \brief The effect type of the call.
*/
diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py
index 21bdfa6f16..82bb698f27 100644
--- a/python/tvm/script/__init__.py
+++ b/python/tvm/script/__init__.py
@@ -15,4 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script APIs of TVM Python Package"""
-from .parser import ir, ir_module, parse as from_source, tir
+from .parser import ir, ir_module
+from .parser import parse as from_source
+from .parser import tir
+from .printer import script
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 48b2834479..06a85fa340 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -27,6 +27,7 @@ from typing_extensions import Literal
# isort: on
import numpy as np # type: ignore
+
from tvm.ir import Range, Type
from tvm.runtime import convert, ndarray
from tvm.target import Target
@@ -508,7 +509,9 @@ class axis: # pylint: disable=invalid-name
@staticmethod
def spatial(
- dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+ binding: PrimExpr,
+ dtype: str = "int32",
) -> Var:
"""The spatial block axis defining function.
@@ -534,7 +537,9 @@ class axis: # pylint: disable=invalid-name
@staticmethod
def reduce(
- dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+ binding: PrimExpr,
+ dtype: str = "int32",
) -> Var:
"""The reduced block axis defining function.
@@ -560,7 +565,9 @@ class axis: # pylint: disable=invalid-name
@staticmethod
def scan(
- dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+ binding: PrimExpr,
+ dtype: str = "int32",
) -> Var:
"""The scanning block axis defining function.
@@ -586,7 +593,9 @@ class axis: # pylint: disable=invalid-name
@staticmethod
def opaque(
- dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
+ dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]],
+ binding: PrimExpr,
+ dtype: str = "int32",
) -> Var:
"""The opaque block axis defining function.
@@ -1534,34 +1543,41 @@ def target(target_config: Union[Dict, str]) -> Target:
return Target(target_config)
-def _op_wrapper(func):
- @functools.wraps(func)
- def wrapped(*args, **kwargs):
- if "dtype" in kwargs:
- kwargs.pop("dtype")
- return func(*args, **kwargs)
+class meta_var: # pylint: disable=invalid-name
+ """A meta variable used in TVMScript metaprogramming. It means that the value of the variable
+ does not appear in the final TIR, but only stays in the parser.
- return wrapped
+ Parameters
+ ----------
+ value: Any
+ The meta variable.
+ """
+ def __init__(self, value: Any) -> None:
+ self.value = value
-def _dtype_forward(func):
+ def __iter__(self):
+ def f():
+ for i in self.value:
+ yield meta_var(i)
+
+ return f()
+
+
+# pylint: disable=invalid-name
+
+
+def _op_wrapper(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
if "dtype" in kwargs:
- args = (kwargs.pop("dtype"),) + args
+ kwargs.pop("dtype")
return func(*args, **kwargs)
return wrapped
-# pylint: disable=invalid-name
-
-broadcast = Broadcast
-ramp = Ramp
-
-buffer_var = ptr
abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin
-fabs = abs
acos = _op_wrapper(_tir_op.acos)
acosh = _op_wrapper(_tir_op.acosh)
address_of = _op_wrapper(_tir_op.address_of)
@@ -1607,7 +1623,6 @@ pow = _op_wrapper(_tir_op.pow) # pylint: disable=redefined-builtin
q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift)
q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis)
ret = _op_wrapper(_tir_op.ret)
-reinterpret = _dtype_forward(_tir_op.reinterpret)
round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin
rsqrt = _op_wrapper(_tir_op.rsqrt)
shift_left = _op_wrapper(_tir_op.shift_left)
@@ -1631,11 +1646,6 @@ call_packed = _op_wrapper(_tir_op.call_packed)
call_cpacked = _op_wrapper(_tir_op.call_cpacked)
call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered)
call_cpacked_lowered = _op_wrapper(_tir_op.call_cpacked_lowered)
-call_extern = _dtype_forward(_tir_op.call_extern)
-call_intrin = _dtype_forward(_tir_op.call_intrin)
-call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
-call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
-call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
tvm_tuple = _op_wrapper(_tir_op.tvm_tuple)
tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set)
tvm_struct_get = _tir_op.tvm_struct_get
@@ -1645,48 +1655,51 @@ tvm_mma_sync = _op_wrapper(_tir_op.tvm_mma_sync)
tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync)
tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment)
tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync)
-ptx_mma = _dtype_forward(_tir_op.ptx_mma)
-ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
-ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
-ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group)
ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group)
-mma_store = _dtype_forward(_tir_op.mma_store)
-mma_fill = _dtype_forward(_tir_op.mma_fill)
-vectorlow = _dtype_forward(_tir_op.vectorlow)
-vectorhigh = _dtype_forward(_tir_op.vectorhigh)
-vectorcombine = _dtype_forward(_tir_op.vectorcombine)
assume = _op_wrapper(_tir_op.assume)
undef = _op_wrapper(_tir_op.undef)
-tvm_call_packed = call_packed
-tvm_call_cpacked = call_cpacked
-tvm_call_packed_lowered = call_packed_lowered
-tvm_call_cpacked_lowered = call_cpacked_lowered
TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace)
TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace)
start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic)
end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic)
-class meta_var:
- """A meta variable used in TVMScript metaprogramming. It means that the value of the variable
- does not appear in the final TIR, but only stays in the parser.
+def _dtype_forward(func):
+ @functools.wraps(func)
+ def wrapped(*args, **kwargs):
+ if "dtype" in kwargs:
+ args = (kwargs.pop("dtype"),) + args
+ return func(*args, **kwargs)
- Parameters
- ----------
- value: Any
- The meta variable.
- """
+ return wrapped
- def __init__(self, value: Any) -> None:
- self.value = value
- def __iter__(self):
- def f():
- for i in self.value:
- yield meta_var(i)
+reinterpret = _dtype_forward(_tir_op.reinterpret)
+call_extern = _dtype_forward(_tir_op.call_extern)
+call_intrin = _dtype_forward(_tir_op.call_intrin)
+call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
+call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin)
+call_pure_extern = _dtype_forward(_tir_op.call_pure_extern)
+ptx_mma = _dtype_forward(_tir_op.ptx_mma)
+ptx_mma_sp = _dtype_forward(_tir_op.ptx_mma_sp)
+ptx_ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix)
+ptx_cp_async = _dtype_forward(_tir_op.ptx_cp_async)
+mma_store = _dtype_forward(_tir_op.mma_store)
+mma_fill = _dtype_forward(_tir_op.mma_fill)
+vectorlow = _dtype_forward(_tir_op.vectorlow)
+vectorhigh = _dtype_forward(_tir_op.vectorhigh)
+vectorcombine = _dtype_forward(_tir_op.vectorcombine)
- return f()
+
+broadcast = Broadcast
+ramp = Ramp
+buffer_var = ptr
+fabs = abs
+tvm_call_packed = call_packed
+tvm_call_cpacked = call_cpacked
+tvm_call_packed_lowered = call_packed_lowered
+tvm_call_cpacked_lowered = call_cpacked_lowered
# pylint: enable=invalid-name
diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py
index d49614db0f..25ea619a41 100644
--- a/python/tvm/script/printer/__init__.py
+++ b/python/tvm/script/printer/__init__.py
@@ -16,12 +16,7 @@
# under the License.
"""
TVMScript Unified Printer
-
This package provides a set of APIs to print supported TVM IR into TVMScript
in a roundtrippable way.
-
-https://github.com/apache/tvm-rfcs/blob/main/rfcs/0074-tvmscript-unified-printer.md
"""
-
-from . import _ffi_api
-from .entry import script
+from .printer import script
diff --git a/python/tvm/script/printer/entry.py b/python/tvm/script/printer/entry.py
deleted file mode 100644
index c015702af0..0000000000
--- a/python/tvm/script/printer/entry.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This file contains the entry point of TVMScript Unified Printer.
-"""
-
-from typing import Dict, Optional
-
-from tvm.runtime import Object, ObjectPath
-
-from . import _ffi_api
-
-
-def script( # pylint: disable=too-many-arguments
- root_node: Object,
- ir_name: str,
- ir_prefix: Dict[str, str],
- indent_spaces: int = 4,
- print_line_numbers: bool = False,
- num_context_lines: int = -1,
- path_to_underline: Optional[ObjectPath] = None,
-) -> str:
- """
- Print IR graph as TVMScript code
-
- Parameters
- ----------
- root_node : Object
- The root node to print.
- ir_name : str
- The dispatch token of the target IR, e.g., "tir", "relax".
- ir_prefix : Dict[str, str]
- The symbol name for TVMScript IR namespaces. For example,
- {"tir": "T"}.
- indent_spaces : int
- The number of indent spaces to use in the output
- print_line_numbers: bool
- Whether to print line numbers
- num_context_lines : Optional[int]
- Number of context lines to print around the underlined text
- path_to_underline : Optional[ObjectPath]
- Object path to be underlined
-
- Returns
- -------
- script : str
- The TVMScript code of the root_node
- """
- return _ffi_api.Script( # type: ignore # pylint: disable=no-member
- root_node,
- ir_name,
- ir_prefix,
- indent_spaces,
- print_line_numbers,
- num_context_lines,
- path_to_underline,
- )
diff --git a/python/tvm/script/printer/frame.py b/python/tvm/script/printer/frame.py
deleted file mode 100644
index c967382b8b..0000000000
--- a/python/tvm/script/printer/frame.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Frame is the core data structure for semantic information when printing
-IR graph into TVMScript code.
-"""
-
-from typing import Callable, Sequence
-
-from tvm._ffi import register_object
-from tvm.runtime import Object
-from tvm.script.printer.doc import StmtDoc
-
-from . import _ffi_api
-
-
-class Frame(Object):
- """
- Frame is the core data structure for semantic information
- when printing IR graph into TVMScript code.
-
- Frame base class manages a list of callbacks to be executed
- when frame goes out of scope.
- """
-
- def add_exit_callback(self, callback: Callable[[], None]) -> None:
- """
- Adds a callback function to be executed when frame goes out of scope.
-
- Parameters
- ----------
- callback : Callable[[], None]
- The callback function.
- """
- _ffi_api.FrameAddExitCallback(self, callback) # type: ignore # pylint: disable=no-member
-
- def __enter__(self):
- _ffi_api.FrameEnterWithScope(self) # type: ignore # pylint: disable=no-member
- return self
-
- def __exit__(self, *exception_info):
- _ffi_api.FrameExitWithScope(self) # type: ignore # pylint: disable=no-member
-
-
-@register_object("script.printer.MetadataFrame")
-class MetadataFrame(Frame):
- """
- MetadataFrame contains information like contant parameter array.
- """
-
- metadata: Sequence[Object]
-
- def __init__(self):
- self.__init_handle_by_constructor__(_ffi_api.MetadataFrame) # type: ignore # pylint: disable=no-member
-
-
-@register_object("script.printer.VarDefFrame")
-class VarDefFrame(Frame):
- """
- VarDefFrame contains information about the free variables that needs to
- be defined at the beginning of the printed snippet.
- """
-
- stmts: Sequence[StmtDoc]
-
- def __init__(self):
- self.__init_handle_by_constructor__(_ffi_api.VarDefFrame) # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/script/printer/ir_docsifier.py b/python/tvm/script/printer/ir_docsifier.py
deleted file mode 100644
index c5ba8a498b..0000000000
--- a/python/tvm/script/printer/ir_docsifier.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-IRDocsifier is the top-level interface in the process of transforming
-IR graph into Doc tree, during printing IR graph as TVMScript code.
-"""
-
-import atexit
-from contextlib import ExitStack, contextmanager
-from typing import Callable, Dict, Generator, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar
-
-from tvm._ffi import get_object_type_index, register_object
-from tvm.runtime import Object, ObjectPath
-
-from . import _ffi_api
-from .doc import Doc
-from .frame import Frame
-from .var_table import VarTable
-
-_REGISTERED_TYPES: Set[Tuple[str, int]] = set() # {(dispatch_token, type_index)}
-
-
-def _cleanup_dispatch_function():
- for dispatch_token, type_index in _REGISTERED_TYPES:
- _ffi_api.IRDocsifierRemoveDispatch(dispatch_token, type_index) # type: ignore # pylint: disable=no-member
-
-
-_CLEANUP_REGISTERED = False
-
-
-def _ensure_cleanup_function_registered():
- """
- Add a cleanup function to be called on interpreter termination,
- to remove all dispatch functions registered on the Python side.
-
- Without cleaning up those dispatch functions, program will segfault
- on termination. It's because dispatch functions are referenced from the
- static memory of libtvm, thus they will be cleaned up at the very end,
- making calls to Py_DecRef after Python interpreter terminates.
- """
- global _CLEANUP_REGISTERED # pylint: disable=global-statement
-
- if not _CLEANUP_REGISTERED:
- atexit.register(_cleanup_dispatch_function)
- _CLEANUP_REGISTERED = True
-
-
-@register_object("script.printer.RootNodeContainer")
-class RootNodeContainer(Object):
- """
- A wrapper object to provide injection point for printer of each IR.
-
- This class shouldn't be used directly. `IRDocsifier.set_root_dispatch`
- should be used instead.
- """
-
- root_node: Object
-
- def __init__(self, root_node: Object):
- self.__init_handle_by_constructor__(_ffi_api.RootNodeContainer, root_node) # type: ignore # pylint: disable=no-member
-
-
-@register_object("script.printer.IRDocsifier")
-class IRDocsifier(Object):
- """
- IRDocsifier is the top-level interface in the IR->Doc process.
-
- It provides methods to convert IR node object to Doc, operate on Frame
- objects and change dispatch tokens.
- """
-
- ir_prefix: Mapping[str, str]
- vars: VarTable
- frames: Sequence[Frame]
- dispatch_tokens: Sequence[str]
-
- def __init__(self, ir_prefix: Dict[str, str]):
- """
- Create a new IRDocsifier.
-
- Parameters
- ----------
- ir_prefix : Dict[str, str]
- The ir prefix to use. Key is the IR dispatch token and
- value is the name of identifier for this IR's namespace in TVMScript.
- """
- self.__init_handle_by_constructor__(_ffi_api.IRDocsifier, ir_prefix) # type: ignore # pylint: disable=no-member
-
- _TObject = TypeVar("_TObject", bound=Object)
-
- @classmethod
- def set_dispatch(
- cls,
- node_type: Type[_TObject],
- dispatch_function: Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc],
- dispatch_token: str = "",
- ) -> None:
- """
- Set the dispatch function to transform a particular IR node type to Doc
-
- Parameters
- ----------
- node_type : Type[_TObject]
- The type of object to dispatch on.
- dispatch_function : Callable[[_TObject, ObjectPath, "IRDocsifier"], Doc]
- The dispatch function. It's called to transform IR node object to Doc.
- dispatch_token : str
- Function will only be called when this dispatch_token is the same as the one
- on the top of IRDocsifier's dispatch_tokens stack. An empty dispatch token
- means registering as default dispatch function, which will be called when
- there is no dispatch function registered with the current dispatch token.
- """
- type_index = get_object_type_index(node_type)
- if type_index is None:
- raise TypeError(f"{type(node_type)} is not a registered TVM object type.")
-
- _ensure_cleanup_function_registered()
- _ffi_api.IRDocsifierSetDispatch( # type: ignore # pylint: disable=no-member
- dispatch_token, type_index, dispatch_function
- )
- _REGISTERED_TYPES.add((dispatch_token, type_index))
-
- @classmethod
- def set_root_dispatch(
- cls, dispatch_token: str, root_dispatch_function: Callable[[Object, "IRDocsifier"], Doc]
- ) -> None:
- """
- Set the root dispatch function for an IR.
-
- The root dispatch function will be called with the root node of an IR graph
- that's being transformed to Doc. This provides an injection point for
- each IR's printer implemention to add specialized logic, for example,
- pushing a special Frame to the IRDocsifier before doing actual IR->Doc
- transformation.
-
- The simplest root dispatch function is
- ```
- def f(obj, ir_docsifier)
- return ir_docsifier.as_doc(obj, ObjectPath.root())
- ```
-
- Parameters
- ----------
- root_dispatch_function : Callable[[_TObject, "IRDocsifier"], Doc]
- The root dispatch function. It's called with the root node to be printed.
- dispatch_token : str
- The dispatch token of the IR that root_dispatch_funnction applies to.
- """
-
- def dispatch_function(obj: RootNodeContainer, _, ir_docsifier):
- return root_dispatch_function(obj.root_node, ir_docsifier)
-
- cls.set_dispatch(RootNodeContainer, dispatch_function, dispatch_token)
-
- def as_doc(self, obj: Object, object_path: ObjectPath) -> Doc:
- """
- Transform the input object into Doc.
-
- Parameters
- ----------
- obj : Object
- The IR node object.
- object_path : ObjectPath
- The object path of this object. It's used for locating diagnostic message.
-
- Returns
- -------
- doc : Doc
- The doc for this object.
- """
- return _ffi_api.IRDocsifierAsDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member
-
- def get_frame(self, frame_type: Type[Frame]) -> Optional[Frame]:
- """
- Get the top frame with type `frame_type`.
-
- Parameters
- ----------
- frame_type : Type[Frame]
- The target frame type.
-
- Returns
- -------
- frame : Optional[Frame]
- The frame if found, otherwise None.
- """
- for i in range(len(self.frames) - 1, -1, -1):
- if isinstance(self.frames[i], frame_type):
- return self.frames[i]
- return None
-
- @contextmanager
- def dispatch_token(self, token: str):
- """
- Push a new dispatch token to the stack.
-
- Parameters
- ----------
- token : str
- The token to push.
-
- Returns
- -------
- A context manager that pops this dispatch token when exits.
- """
- with ExitStack() as stack:
- _ffi_api.IRDocsifierPushDispatchToken(self, token) # type: ignore # pylint: disable=no-member
- stack.callback(_ffi_api.IRDocsifierPopDispatchToken, self) # type: ignore # pylint: disable=no-member
- yield
-
- _TFrame = TypeVar("_TFrame", bound=Frame)
-
- @contextmanager
- def frame(self, frame: _TFrame) -> Generator[_TFrame, None, None]:
- """
- Push a new frame to the stack.
-
- Parameters
- ----------
- frame : Frame
- The frame to push.
-
- Returns
- -------
- A context manager that pops this frame when exits.
- """
- with ExitStack() as stack:
- stack.enter_context(frame)
- _ffi_api.IRDocsifierPushFrame(self, frame) # type: ignore # pylint: disable=no-member
- stack.callback(_ffi_api.IRDocsifierPopFrame, self) # type: ignore # pylint: disable=no-member
- yield frame
diff --git a/python/tvm/script/printer/printer.py b/python/tvm/script/printer/printer.py
new file mode 100644
index 0000000000..120ef03f57
--- /dev/null
+++ b/python/tvm/script/printer/printer.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The printer interface"""
+
+from typing import Mapping, Optional
+
+from tvm.runtime.object_path import ObjectPath
+
+from . import _ffi_api
+
+
+def script(
+    obj,
+    ir_prefix: Optional[Mapping[str, str]] = None,
+    indent_space: int = 4,
+    print_line_number: bool = False,
+    num_context_lines: int = -1,
+    path_to_underline: Optional[ObjectPath] = None,
+) -> str:
+    """Print a TVM IR object in TVMScript text format.
+
+    Parameters
+    ----------
+    obj : object
+        A TVM object representing TVM IR
+    ir_prefix : Optional[Mapping[str, str]]
+        A mapping from IR type to the prefix of the script.
+        Defaults to {"ir": "I", "tir": "T"} when None is given.
+    indent_space : int = 4
+        The number of spaces to indent
+    print_line_number : bool = False
+        Whether to print line numbers
+    num_context_lines : int = -1
+        The number of context lines to print. -1 means all lines.
+    path_to_underline : Optional[ObjectPath]
+        The path to underline in the script.
+
+    Returns
+    -------
+    script : str
+        The TVMScript text of `obj`
+    """
+    if ir_prefix is None:
+        ir_prefix = {
+            "ir": "I",
+            "tir": "T",
+        }
+    return _ffi_api.Script(  # type: ignore # pylint: disable=no-member
+        obj, ir_prefix, indent_space, print_line_number, num_context_lines, path_to_underline
+    )
diff --git a/python/tvm/script/printer/var_table.py b/python/tvm/script/printer/var_table.py
deleted file mode 100644
index ea1fa41b32..0000000000
--- a/python/tvm/script/printer/var_table.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Functions to print doc into text format"""
-
-from typing import Callable, Optional
-
-from tvm._ffi import register_object
-from tvm.runtime import Object, ObjectPath
-
-from . import _ffi_api
-from .doc import ExprDoc, IdDoc
-from .frame import Frame
-
-
-@register_object("script.printer.VarTable")
-class VarTable(Object):
- """
- Variable Table manages mapping from variable object to ExprDoc during
- the process of printing TVMScript.
- """
-
- def __init__(self):
- """
- Create an empty VarTable.
- """
- self.__init_handle_by_constructor__(_ffi_api.VarTable) # type: ignore # pylint: disable=no-member
-
- def define(self, obj: Object, name_hint: str, object_path: ObjectPath, frame: Frame) -> IdDoc:
- """
- Define a variable by name.
-
- Parameters
- ----------
- obj : Object
- The variable object.
- name_hint : str
- The hint for variable name.
- object_path : ObjectPath
- The object path to be associated with the returned ExprDoc.
- frame : Frame
- Then frame that this variable is defined in.
-
- Returns
- -------
- doc : IdDoc
- The doc for this variable.
- """
- return _ffi_api.VarTableDefine(self, obj, name_hint, object_path, frame) # type: ignore # pylint: disable=no-member
-
- def define_by_doc(self, obj: Object, doc_factory: Callable[[], ExprDoc], frame: Frame) -> None:
- """
- Define a variable by ExprDoc.
-
- Parameters
- ----------
- obj : Object
- The variable object.
- doc_factory : Callable[[], ExprDoc]
- The hint for variable name.
- frame : Frame
- Then frame that this variable is defined in.
-
- Returns
- -------
- None
- """
- _ffi_api.VarTableDefineByDoc(self, obj, doc_factory, frame) # type: ignore # pylint: disable=no-member
-
- def get_var_doc(self, obj: Object, object_path: ObjectPath) -> Optional[ExprDoc]:
- """
- Get the doc for a variable.
-
- Parameters
- ----------
- obj : Object
- The variable object.
- object_path : ObjectPath
- The object path to be associated with the returned ExprDoc.
-
- Returns
- -------
- doc : ExprDoc
- The doc for this variable.
- """
- return _ffi_api.VarTableGetVarDoc(self, obj, object_path) # type: ignore # pylint: disable=no-member
-
- def is_var_defined(self, obj: Object) -> bool:
- """
- Check whether a variable is defined.
-
- Parameters
- ----------
- obj : Object
- The variable object.
-
- Returns
- -------
- is_defined : bool
- Whether the variable is defined.
- """
- return _ffi_api.VarTableIsVarDefined(self, obj) # type: ignore # pylint: disable=no-member
-
- def __contains__(self, obj: Object) -> bool:
- return self.is_var_defined(obj)
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index 1ca7ced8e8..f41b40c92c 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -27,18 +27,12 @@ namespace printer {
ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }
-ExprDoc ExprDocNode::Attr(TracedObject<String> attr) const {
- auto doc = AttrAccessDoc(GetRef<ExprDoc>(this), attr.Get());
- doc->source_paths.push_back(attr.GetPath());
- return std::move(doc);
-}
-
ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
return IndexDoc(GetRef<ExprDoc>(this), indices);
}
ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args) const {
- return CallDoc(GetRef<ExprDoc>(this), args, {}, {});
+ return CallDoc(GetRef<ExprDoc>(this), args, Array<String>(), Array<ExprDoc>());
}
ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_keys,
@@ -258,7 +252,7 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
- .set_body_typed<LiteralDoc(int)>(LiteralDoc::Int);
+ .set_body_typed<LiteralDoc(int64_t)>(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
.set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc
similarity index 100%
rename from src/script/printer/base_doc_printer.cc
rename to src/script/printer/doc_printer/base_doc_printer.cc
diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h
similarity index 97%
rename from src/script/printer/base_doc_printer.h
rename to src/script/printer/doc_printer/base_doc_printer.h
index f3fb24d946..db1d733d96 100644
--- a/src/script/printer/base_doc_printer.h
+++ b/src/script/printer/doc_printer/base_doc_printer.h
@@ -16,11 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
-#define TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
+#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_
+#define TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_
#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/doc_printer.h>
#include <limits>
#include <memory>
@@ -287,4 +286,4 @@ class DocPrinter {
} // namespace script
} // namespace tvm
-#endif // TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
+#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_BASE_DOC_PRINTER_H_
diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc
similarity index 98%
rename from src/script/printer/python_doc_printer.cc
rename to src/script/printer/doc_printer/python_doc_printer.cc
index 753f907c42..6851baf638 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/doc_printer/python_doc_printer.cc
@@ -21,10 +21,11 @@
#include <tvm/script/printer/doc.h>
#include <algorithm>
+#include <cmath>
#include <string>
-#include "../../support/str_escape.h"
-#include "../../support/utils.h"
+#include "../../../support/str_escape.h"
+#include "../../../support/utils.h"
#include "./base_doc_printer.h"
namespace tvm {
@@ -294,7 +295,11 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
} else if (const auto* float_imm = value.as<FloatImmNode>()) {
// TODO(yelite): Make float number printing roundtrippable
output_.precision(17);
- output_ << float_imm->value;
+ if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) {
+ output_ << '"' << float_imm->value << '"';
+ } else {
+ output_ << float_imm->value;
+ }
} else if (const auto* string_obj = value.as<StringObj>()) {
output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\"";
} else {
diff --git a/src/script/printer/frame.cc b/src/script/printer/frame.cc
deleted file mode 100644
index b342c7c886..0000000000
--- a/src/script/printer/frame.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/runtime/registry.h>
-#include <tvm/script/printer/frame.h>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-MetadataFrame::MetadataFrame() : MetadataFrame(make_object<MetadataFrameNode>()) {}
-
-VarDefFrame::VarDefFrame() : VarDefFrame(make_object<VarDefFrameNode>()) {}
-
-TVM_REGISTER_NODE_TYPE(FrameNode);
-TVM_REGISTER_GLOBAL("script.printer.FrameAddExitCallback")
- .set_body_typed([](Frame frame, runtime::TypedPackedFunc<void()> callback) {
- frame->AddExitCallback(callback);
- });
-TVM_REGISTER_GLOBAL("script.printer.FrameEnterWithScope")
- .set_body_method<Frame>(&FrameNode::EnterWithScope);
-TVM_REGISTER_GLOBAL("script.printer.FrameExitWithScope")
- .set_body_method<Frame>(&FrameNode::ExitWithScope);
-
-TVM_REGISTER_NODE_TYPE(MetadataFrameNode);
-TVM_REGISTER_GLOBAL("script.printer.MetadataFrame").set_body_typed([]() {
- return MetadataFrame();
-});
-
-TVM_REGISTER_NODE_TYPE(VarDefFrameNode);
-TVM_REGISTER_GLOBAL("script.printer.VarDefFrame").set_body_typed([]() { return VarDefFrame(); });
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
new file mode 100644
index 0000000000..c4ecf92e91
--- /dev/null
+++ b/src/script/printer/ir/ir.cc
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+TVM_REGISTER_NODE_TYPE(IRFrameNode);
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<IRModule>("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc {
+      std::vector<std::pair<GlobalVar, BaseFunc>> functions{mod->functions.begin(),
+                                                            mod->functions.end()};
+      // Sort functions by name; "main" is mapped to "" so that it is printed first.
+      std::sort(functions.begin(), functions.end(), [](const auto& lhs, const auto& rhs) {
+        String lhs_name = lhs.first->name_hint;
+        String rhs_name = rhs.first->name_hint;
+        if (lhs_name == "main") {
+          lhs_name = "";
+        }
+        if (rhs_name == "main") {
+          rhs_name = "";
+        }
+        return lhs_name < rhs_name;
+      });
+      ICHECK(!d->mod.defined());  // this docsifier must not have recorded a module yet
+      d->mod = mod;
+      {
+        With<IRFrame> f(d);
+        (*f)->AddDispatchToken(d, "ir");
+        for (const auto& kv : functions) {
+          GlobalVar gv = kv.first;
+          BaseFunc func = kv.second;
+          (*f)->stmts.push_back(d->AsDoc<FunctionDoc>(func, p->Attr("functions")->MapValue(gv)));
+        }
+        return ClassDoc(IdDoc("Module"), {IR(d)->Attr("ir_module")}, (*f)->stmts);
+      }
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<DictAttrs>(  //
+        "", [](DictAttrs attrs_node, ObjectPath attrs_path, IRDocsifier docsifier) -> Doc {
+          // Print the attribute dictionary stored in the `dict` field.
+          ObjectPath dict_path = attrs_path->Attr("dict");
+          return docsifier->AsDoc(attrs_node->dict, dict_path);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<GlobalVar>("", [](GlobalVar global_var, ObjectPath, IRDocsifier) -> Doc {
+      // Rendered as: GlobalVar("<name_hint>")
+      ExprDoc name_doc = LiteralDoc::Str(global_var->name_hint);
+      return IdDoc("GlobalVar")->Call({name_doc});
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<Op>("", [](Op op_ref, ObjectPath, IRDocsifier) -> Doc {
+      // Rendered as: Op("<op name>")
+      ExprDoc name_doc = LiteralDoc::Str(op_ref->name);
+      return IdDoc("Op")->Call({name_doc});
+    });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc
new file mode 100644
index 0000000000..bd27921671
--- /dev/null
+++ b/src/script/printer/ir/misc.cc
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<String>("", [](String str, ObjectPath, IRDocsifier) -> Doc {
+      // A bare String object is printed as a string literal.
+      LiteralDoc literal = LiteralDoc::Str(str);
+      return literal;
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<Array<ObjectRef>>(  //
+        "", [](Array<ObjectRef> array, ObjectPath p, IRDocsifier d) -> Doc {
+          // Print every element under its own array-index path, then wrap them in a list.
+          Array<ExprDoc> elements;
+          elements.reserve(array.size());
+          for (int i = 0, n = array.size(); i < n; ++i) {
+            ObjectPath elem_path = p->ArrayIndex(i);
+            elements.push_back(d->AsDoc<ExprDoc>(array[i], elem_path));
+          }
+          return ListDoc(elements);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<Map<ObjectRef, ObjectRef>>(  //
+        "", [](Map<ObjectRef, ObjectRef> dict, ObjectPath p, IRDocsifier d) -> Doc {
+          using KV = std::pair<ObjectRef, ObjectRef>;
+          std::vector<KV> entries(dict.begin(), dict.end());
+          // Maps whose keys are all Strings are ordered by key text; any other map is
+          // ordered by key address.
+          // NOTE(review): address order may differ between runs, so output for
+          // non-String-keyed maps is not guaranteed to be reproducible.
+          bool all_str_keys = std::all_of(entries.begin(), entries.end(), [](const KV& kv) {
+            return kv.first.as<runtime::StringObj>() != nullptr;
+          });
+          std::sort(entries.begin(), entries.end(), [all_str_keys](const KV& lhs, const KV& rhs) {
+            if (all_str_keys) {
+              return Downcast<String>(lhs.first) < Downcast<String>(rhs.first);
+            }
+            return lhs.first.get() < rhs.first.get();
+          });
+          Array<ExprDoc> keys;
+          Array<ExprDoc> values;
+          keys.reserve(entries.size());
+          values.reserve(entries.size());
+          for (const KV& kv : entries) {
+            // Keys have no addressable path of their own; values are addressed by their key.
+            keys.push_back(d->AsDoc<ExprDoc>(kv.first, p->MissingMapEntry()));
+            values.push_back(d->AsDoc<ExprDoc>(kv.second, p->MapValue(kv.first)));
+          }
+          return DictDoc(keys, values);
+        });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer.cc b/src/script/printer/ir/utils.h
similarity index 50%
copy from src/script/printer.cc
copy to src/script/printer/ir/utils.h
index 051b774ba6..4065b895c1 100644
--- a/src/script/printer.cc
+++ b/src/script/printer/ir/utils.h
@@ -16,39 +16,46 @@
* specific language governing permissions and limitations
* under the License.
*/
+#ifndef TVM_SCRIPT_PRINTER_IR_UTILS_H_
+#define TVM_SCRIPT_PRINTER_IR_UTILS_H_
-#include <tvm/runtime/registry.h>
-#include <tvm/script/printer.h>
-#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/doc_printer.h>
-#include <tvm/script/printer/frame.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
+#include <tvm/ir/op.h>
#include <tvm/script/printer/ir_docsifier.h>
+#include <tvm/script/printer/printer.h>
+#include <tvm/support/with.h>
+
+#include <utility>
namespace tvm {
namespace script {
namespace printer {
-String Script( //
- const ObjectRef& root_node, //
- String ir_name, //
- Map<String, String> ir_prefix, //
- int indent_spaces, //
- bool print_line_numbers, //
- int num_context_lines, //
- Optional<ObjectPath> path_to_underline //
-) {
- IRDocsifier ir_docsifier(ir_prefix);
+// Returns the doc for the `tvm.script` prefix (IdDoc("tvm") with attribute "script").
+// NOTE(review): `d` is currently unused; the prefix is hard-coded rather than read from
+// the docsifier's ir_prefix map.
+inline ExprDoc IR(const IRDocsifier& d) { return IdDoc("tvm")->Attr("script"); }
- auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);
+/*! \brief Printer frame for IR-level constructs (used when printing an IRModule).
+ *  It declares no fields beyond those of FrameNode. */
+class IRFrameNode : public FrameNode {
+ public:
+ void VisitAttrs(AttrVisitor* v) { FrameNode::VisitAttrs(v); }
- Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));
+ static constexpr const char* _type_key = "script.printer.IRFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IRFrameNode, FrameNode);
+};
- return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
- path_to_underline);
-}
+/*! \brief Managed reference to IRFrameNode. */
+class IRFrame : public Frame {
+ public:
+ /*! \brief Create an empty IRFrame attached to docsifier `d`. */
+ explicit IRFrame(const IRDocsifier& d) {
+ ObjectPtr<IRFrameNode> n = make_object<IRFrameNode>();
+ n->stmts.clear();
+ // Stores the docsifier's raw node pointer (presumably non-owning; confirm in FrameNode).
+ n->d = d.get();
+ data_ = std::move(n);
+ }
-TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script);
+ TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode);
+};
} // namespace printer
} // namespace script
} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_IR_UTILS_H_
diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc
index 7f032ec502..8584f36031 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -20,21 +20,136 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/ir_docsifier.h>
-#include <tvm/script/printer/traced_object.h>
-#include <tvm/script/printer/traced_object_functor.h>
namespace tvm {
namespace script {
namespace printer {
-Doc IRDocsifierNode::AsDocImpl(const TracedObject<ObjectRef>& obj) const {
- return IRDocsifier::vtable()(dispatch_tokens.back(), obj, GetRef<IRDocsifier>(this));
+String GenerateUniqueName(std::string name_hint, std::unordered_set<String>* defined_names) {
+  // Sanitize the hint into an identifier: keep alphanumerics and '_', replace anything else
+  // with '_'. The cast matters: passing a negative char to std::isalnum is undefined behavior.
+  for (char& c : name_hint) {
+    if (c != '_' && !std::isalnum(static_cast<unsigned char>(c))) {
+      c = '_';
+    }
+  }
+  // Try the hint, then "<hint>_1", "<hint>_2", ...; insert() both tests and reserves the name.
+  std::string name = name_hint;
+  for (int i = 1; !defined_names->insert(name).second; ++i) {
+    name = name_hint + "_" + std::to_string(i);
+  }
+  return name;
+}
+
+IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) {
+  // Reserve a fresh sanitized name and record how later references to `obj` are printed.
+  String unique_name = GenerateUniqueName(name_hint, &this->defined_names);
+  DocCreator creator = [unique_name]() { return IdDoc(unique_name); };
+  bool inserted = obj2info.insert({obj, VariableInfo{std::move(creator), unique_name}}).second;
+  ICHECK(inserted) << "Duplicated object: " << obj;
+  // Undefine `obj` (and release its name) when the defining frame exits.
+  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
+  return IdDoc(unique_name);
+}
+
+void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) {
+  // Register `obj` with a caller-supplied doc factory; no name is reserved for it.
+  bool already_defined = obj2info.count(obj) != 0;
+  ICHECK(!already_defined) << "Duplicated object: " << obj;
+  // The factory is invoked once up front to reject factories that produce a plain IdDoc.
+  bool is_id_doc = doc_factory()->IsInstance<IdDocNode>();
+  ICHECK(!is_id_doc)
+      << "IRDocsifierNode::Define cannot be used for variable that's mapped to IdDoc.";
+  obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
+  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
+}
+
+Optional<ExprDoc> IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const {
+  // Build a fresh doc from the registered creator; NullOpt when `obj` is not defined.
+  auto found = obj2info.find(obj);
+  if (found != obj2info.end()) {
+    return found->second.creator();
+  }
+  return NullOpt;
+}
+
+// True iff `obj` currently has an entry in the variable table.
+bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const {
+  return obj2info.find(obj) != obj2info.end();
+}
+
+void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
+  auto entry = obj2info.find(obj);
+  ICHECK(entry != obj2info.end()) << "No such object: " << obj;
+  // Release the reserved name (if any) so later definitions may reuse it.
+  const Optional<String>& reserved = entry->second.name;
+  if (reserved.defined()) {
+    defined_names.erase(reserved.value());
+  }
+  obj2info.erase(entry);
+}
+
+void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
+                                      runtime::TypedPackedFunc<bool(ObjectRef)> is_var) {
+  // Walk everything reachable from `root` (reflected attributes, Array elements, Map keys and
+  // values). For each object accepted by `is_var`, record the longest common prefix of the
+  // ancestor stacks (each ending in the object itself) at which it is encountered.
+  class Visitor : public AttrVisitor {
+   public:
+    inline void operator()(ObjectRef obj) { Visit("", &obj); }
+
+   private:
+    // Non-object fields are not traversed.
+    void Visit(const char* key, double* value) final {}
+    void Visit(const char* key, int64_t* value) final {}
+    void Visit(const char* key, uint64_t* value) final {}
+    void Visit(const char* key, int* value) final {}
+    void Visit(const char* key, bool* value) final {}
+    void Visit(const char* key, std::string* value) final {}
+    void Visit(const char* key, void** value) final {}
+    void Visit(const char* key, DataType* value) final {}
+    void Visit(const char* key, runtime::NDArray* value) final {}
+    void Visit(const char* key, ObjectRef* value) final {
+      const Object* obj = value->get();
+      if (obj == nullptr) {
+        return;
+      }
+      stack_.push_back(obj);
+      if (obj->IsInstance<ArrayNode>()) {
+        const ArrayNode* array = static_cast<const ArrayNode*>(obj);
+        for (ObjectRef element : *array) {
+          this->Visit("", &element);
+        }
+      } else if (obj->IsInstance<MapNode>()) {
+        const MapNode* map = static_cast<const MapNode*>(obj);
+        for (std::pair<ObjectRef, ObjectRef> kv : *map) {
+          this->Visit("", &kv.first);
+          this->Visit("", &kv.second);
+        }
+      } else {
+        vtable_->VisitAttrs(const_cast<Object*>(obj), this);
+      }
+      // Child visits are balanced, so `stack_` is now the path root..obj.
+      if (is_var(GetRef<ObjectRef>(obj))) {
+        HandleVar(obj);
+      }
+      stack_.pop_back();
+    }
+
+    // Intersect the prefix recorded for `var` with the current stack.
+    void HandleVar(const Object* var) {
+      auto it = common_prefix.find(var);
+      if (it == common_prefix.end()) {
+        common_prefix.emplace(var, stack_);
+        return;
+      }
+      std::vector<const Object*>& prefix = it->second;
+      size_t n = std::min(prefix.size(), stack_.size());
+      size_t matched = 0;
+      while (matched < n && prefix[matched] == stack_[matched]) {
+        ++matched;
+      }
+      // Always truncate to the matched length. Truncating only on a mismatch would leave the
+      // recorded prefix too long whenever `stack_` is a strict prefix of it.
+      prefix.resize(matched);
+    }
+
+    ReflectionVTable* vtable_ = ReflectionVTable::Global();
+    std::vector<const Object*> stack_;
+
+   public:
+    runtime::TypedPackedFunc<bool(ObjectRef)> is_var;
+    std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
+  };
+  Visitor visitor;
+  visitor.is_var = is_var;
+  visitor(root);
+  this->common_prefix = std::move(visitor.common_prefix);
}
IRDocsifier::IRDocsifier(Map<String, String> ir_prefix) {
auto n = make_object<IRDocsifierNode>();
n->ir_prefix = std::move(ir_prefix);
- n->dispatch_tokens.push_back(kDefaultDispatchToken);
+ // Seed the dispatch-token stack with "", the token the default dispatches are registered under.
+ n->dispatch_tokens.push_back("");
data_ = std::move(n);
}
@@ -43,65 +158,8 @@ IRDocsifier::FType& IRDocsifier::vtable() {
return inst;
}
-RootNodeContainer::RootNodeContainer(ObjectRef root_node) {
- auto n = make_object<RootNodeContainerNode>();
- n->root_node = std::move(root_node);
- data_ = std::move(n);
-}
-
-// Add a default dispatch for the RootNodeContainer to throw error.
-// To add implementation for a new IR, RootNodeContainer needs to be
-// registered under the dispatch token of that IR, like:
-// \code
-// TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
-// .set_dispatch("relax", [](TracedObject<RootNodeContainer> obj, IRDocsifier p) {
-// const ObjectRef& root_node = obj.Get()->root_node;
-// \\ More specialized logic for your IR.
-// return p->AsDoc<Doc>(MakeTraced(root_node));
-// });
-// \endcode
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<RootNodeContainer>([](TracedObject<RootNodeContainer> obj, IRDocsifier p) -> Doc {
- String top_dispatch_token = p->dispatch_tokens.back();
- ICHECK_NE(top_dispatch_token, "");
- ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented.";
- throw;
- });
-
+TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifier").set_body_typed([](Map<String, String> ir_prefix) {
- return IRDocsifier(ir_prefix);
-});
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierAsDoc")
- .set_body_typed([](IRDocsifier p, ObjectRef obj, ObjectPath obj_path) {
- return p->AsDoc<Doc>(MakeTraced(obj, obj_path));
- });
-
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushDispatchToken")
- .set_body_typed([](IRDocsifier p, String token) { p->dispatch_tokens.push_back(token); });
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopDispatchToken").set_body_typed([](IRDocsifier p) {
- p->dispatch_tokens.pop_back();
-});
-
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPushFrame")
- .set_body_typed([](IRDocsifier p, Frame frame) { p->frames.push_back(frame); });
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierPopFrame").set_body_typed([](IRDocsifier p) {
- p->frames.pop_back();
-});
-
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierSetDispatch")
- .set_body_typed([](String token, uint64_t type_index, runtime::PackedFunc f) {
- IRDocsifier::vtable().set_dispatch(token, type_index, std::move(f));
- });
-TVM_REGISTER_GLOBAL("script.printer.IRDocsifierRemoveDispatch")
- .set_body_typed([](String token, uint64_t type_index) {
- IRDocsifier::vtable().remove_dispatch(token, type_index);
- });
-
-TVM_REGISTER_NODE_TYPE(RootNodeContainerNode);
-TVM_REGISTER_GLOBAL("script.printer.RootNodeContainer").set_body_typed([](ObjectRef root_node) {
- return RootNodeContainer(root_node);
-});
} // namespace printer
} // namespace script
diff --git a/src/script/printer.cc b/src/script/printer/printer.cc
similarity index 57%
rename from src/script/printer.cc
rename to src/script/printer/printer.cc
index 051b774ba6..47fd0b89b0 100644
--- a/src/script/printer.cc
+++ b/src/script/printer/printer.cc
@@ -16,38 +16,28 @@
* specific language governing permissions and limitations
* under the License.
*/
-
#include <tvm/runtime/registry.h>
-#include <tvm/script/printer.h>
-#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/doc_printer.h>
-#include <tvm/script/printer/frame.h>
-#include <tvm/script/printer/ir_docsifier.h>
+#include <tvm/script/printer/printer.h>
namespace tvm {
namespace script {
namespace printer {
-String Script( //
- const ObjectRef& root_node, //
- String ir_name, //
- Map<String, String> ir_prefix, //
- int indent_spaces, //
- bool print_line_numbers, //
- int num_context_lines, //
- Optional<ObjectPath> path_to_underline //
-) {
- IRDocsifier ir_docsifier(ir_prefix);
-
- auto dispatch_ctx = ir_docsifier->WithDispatchToken(ir_name);
-
- Doc doc = ir_docsifier->AsDoc<Doc>(MakeTraced(RootNodeContainer(root_node)));
-
+// Docsify `obj` starting from ObjectPath::Root() with a fresh IRDocsifier configured by
+// `ir_prefix`, then render the resulting Doc via DocToPythonScript with the given options.
+String Script(ObjectRef obj, Map<String, String> ir_prefix, int indent_spaces,
+ bool print_line_numbers, int num_context_lines,
+ Optional<ObjectPath> path_to_underline) {
+ IRDocsifier d(ir_prefix);
+ Doc doc = d->AsDoc(obj, ObjectPath::Root());
return DocToPythonScript(doc, indent_spaces, print_line_numbers, num_context_lines,
path_to_underline);
}
-TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(&Script);
+Default* Default::Instance() {
+  // Function-local static: constructed on first use and shared by every caller.
+  static Default singleton;
+  return &singleton;
+}
+
+TVM_REGISTER_GLOBAL("script.printer.Script").set_body_typed(Script);
} // namespace printer
} // namespace script
diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc
new file mode 100644
index 0000000000..f6dbf616a5
--- /dev/null
+++ b/src/script/printer/tir/block.cc
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p,  //
+               Optional<tir::BlockRealize> opt_realize, Optional<ObjectPath> opt_realize_p) {
+  With<TIRFrame> frame(d, block);
+  ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined());
+  // A bare Block (printed without its BlockRealize) passes NullOpt for both optionals.
+  // Calling `.value()` on an undefined Optional aborts, so fall back to nullptr instead.
+  const tir::BlockRealizeNode* realize =
+      opt_realize.defined() ? opt_realize.value().get() : nullptr;
+  const ObjectPathNode* realize_p =
+      opt_realize_p.defined() ? opt_realize_p.value().get() : nullptr;
+  // Step 1. Handle block var and block bindings
+  int n_vars = block->iter_vars.size();
+  for (int i = 0; i < n_vars; ++i) {
+    tir::IterVar iter_var = block->iter_vars[i];
+    // The path must name the reflected field `iter_vars`, or underlining points nowhere.
+    ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
+    ExprDoc rhs = TIR(d)->Attr("axis");
+    if (iter_var->iter_type == tir::IterVarType::kDataPar) {
+      rhs = rhs->Attr("spatial");
+    } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) {
+      rhs = rhs->Attr("reduce");
+    } else if (iter_var->iter_type == tir::IterVarType::kOrdered) {
+      rhs = rhs->Attr("scan");
+    } else if (iter_var->iter_type == tir::IterVarType::kOpaque) {
+      rhs = rhs->Attr("opaque");
+    } else {
+      LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: "
+                 << tir::IterVarType2String(iter_var->iter_type);
+    }
+    // Domain is printed as `extent` when min == 0, otherwise as the tuple (min, min + extent).
+    ExprDoc dom{nullptr};
+    if (tir::is_zero(iter_var->dom->min)) {
+      ExprDoc extent = d->AsDoc<ExprDoc>(iter_var->dom->extent,  //
+                                         iter_var_p->Attr("dom")->Attr("extent"));
+      dom = extent;
+    } else {
+      ExprDoc min = d->AsDoc<ExprDoc>(iter_var->dom->min, iter_var_p->Attr("dom")->Attr("min"));
+      ExprDoc max = d->AsDoc<ExprDoc>(iter_var->dom->min + iter_var->dom->extent,
+                                      iter_var_p->Attr("dom")->Attr("extent"));
+      dom = TupleDoc({min, max});
+    }
+    // The binding value is only available when printing through a BlockRealize.
+    if (realize) {
+      ExprDoc binding = d->AsDoc<ExprDoc>(realize->iter_values[i],  //
+                                          realize_p->Attr("iter_values")->ArrayIndex(i));
+      rhs = rhs->Call({dom, binding});
+    } else {
+      rhs = rhs->Call({dom});
+    }
+    (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt));
+  }
+  // Step 2. Handle block predicate (only printed when it is not trivially true)
+  if (realize) {
+    ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool());
+    if (!tir::is_one(realize->predicate)) {
+      (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("where")->Call(
+          {d->AsDoc<ExprDoc>(realize->predicate, realize_p->Attr("predicate"))})));
+    }
+  }
+  // Step 3. Handle block read/write regions
+  {
+    Array<ExprDoc> reads;
+    for (int i = 0, n = block->reads.size(); i < n; ++i) {
+      reads.push_back(d->AsDoc<ExprDoc>(block->reads[i], block_p->Attr("reads")->ArrayIndex(i)));
+    }
+    (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("reads")->Call(reads)));
+    Array<ExprDoc> writes;
+    for (int i = 0, n = block->writes.size(); i < n; ++i) {
+      writes.push_back(d->AsDoc<ExprDoc>(block->writes[i], block_p->Attr("writes")->ArrayIndex(i)));
+    }
+    (*frame)->stmts.push_back(ExprStmtDoc(TIR(d)->Attr("writes")->Call(writes)));
+  }
+  // Step 4. Handle block attributes
+  if (!block->annotations.empty()) {
+    (*frame)->stmts.push_back(ExprStmtDoc(
+        TIR(d)
+            ->Attr("block_attr")
+            ->Call({d->AsDoc<ExprDoc>(block->annotations, block_p->Attr("annotations"))})));
+  }
+  // Step 5. Handle `alloc_buffer`
+  for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) {
+    tir::Buffer buffer = block->alloc_buffers[i];
+    ObjectPath buffer_p = block_p->Attr("alloc_buffers")->ArrayIndex(i);
+    IdDoc lhs = DefineBuffer(buffer, *frame, d);
+    ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d);
+    (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+  }
+  // Step 6. Handle `match_buffer`
+  for (int i = 0, n = block->match_buffers.size(); i < n; ++i) {
+    tir::MatchBufferRegion buffer_region = block->match_buffers[i];
+    ObjectPath buffer_region_p = block_p->Attr("match_buffers")->ArrayIndex(i);
+    StmtDoc doc = d->AsDoc<StmtDoc>(buffer_region, buffer_region_p);
+    (*frame)->stmts.push_back(doc);
+  }
+  // Step 7. Handle init block (printed in its own nested frame)
+  if (block->init.defined()) {
+    tir::Stmt init = block->init.value();
+    With<TIRFrame> init_frame(d, init);
+    AsDocBody(init, block_p->Attr("init"), init_frame->get(), d);
+    (*frame)->stmts.push_back(
+        ScopeDoc(NullOpt, TIR(d)->Attr("init")->Call({}), (*init_frame)->stmts));
+  }
+  // Step 8. Handle block body
+  AsDocBody(block->body, block_p->Attr("body"), frame->get(), d);
+  return ScopeDoc(NullOpt, TIR(d)->Attr("block")->Call({LiteralDoc::Str(block->name_hint)}),
+                  (*frame)->stmts);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::BlockRealize>(  //
+        "", [](tir::BlockRealize br, ObjectPath br_p, IRDocsifier d) -> Doc {
+          // The realize supplies bindings and predicate; the block itself lives at `.block`.
+          ObjectPath block_p = br_p->Attr("block");
+          return PrintBlock(d, br->block, block_p, br, br_p);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Block>("", [](tir::Block blk, ObjectPath blk_p, IRDocsifier d) -> Doc {
+      // A standalone Block: no enclosing BlockRealize (nor its path) is available.
+      return PrintBlock(d, blk, blk_p, NullOpt, NullOpt);
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::MatchBufferRegion>(  //
+        "", [](tir::MatchBufferRegion match, ObjectPath match_p, IRDocsifier d) -> Doc {
+          // Emits `<buffer> = <match_buffer declaration over source>` in the innermost frame.
+          Frame current = d->frames.back();
+          ExprDoc buffer_doc = DefineBuffer(match->buffer, current, d);
+          ExprDoc source_doc = d->AsDoc<ExprDoc>(match->source, match_p->Attr("source"));
+          ExprDoc decl_doc = BufferDecl(match->buffer, "match_buffer", {source_doc},
+                                        match_p->Attr("buffer"), d->frames.back(), d);
+          return AssignDoc(buffer_doc, decl_doc, NullOpt);
+        });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc
new file mode 100644
index 0000000000..3e1d71af4a
--- /dev/null
+++ b/src/script/printer/tir/buffer.cc
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/device_api.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Collects the keyword attributes used when declaring `buffer`:
+// shape, dtype, data, strides, elem_offset, scope, align, offset_factor, type
+// and axis_separators. `shape` is always present. The others are omitted when
+// they equal their defaults: default dtype, empty strides, elem_offset 0,
+// scope "global", runtime::kAllocAlignment, offset_factor 1, kDefault type,
+// empty axis_separators.
+//
+// Special case: when `buffer.data` or a non-constant `buffer.elem_offset` is a
+// Var that has no doc yet, it is not emitted as a keyword. Instead it is
+// implicitly defined in `frame` as an attribute access on the buffer's doc
+// (e.g. `A.data`).
+//
+// \param buffer The buffer being declared.
+// \param p The object path of `buffer`.
+// \param frame The frame that owns any implicitly defined vars.
+// \param d The IRDocsifier.
+// \return A map from keyword name to the doc of its value.
+Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, const Frame& frame,
+ const IRDocsifier& d) {
+ Map<String, ExprDoc> kwargs;
+ // Prints `e` under `key`, unless `e` is a not-yet-defined Var. In that case
+ // `e` is implicitly defined as `<buffer doc>.<key>` and nothing is added to
+ // `kwargs`. Returns true iff such an implicit definition was made.
+ // Note: the lambda's `p` is the attribute's path and shadows the outer `p`.
+ auto implicit_var_def = [&](const PrimExpr& e, const ObjectPath& p, const String& key) {
+ if (Optional<ExprDoc> doc = d->GetVarDoc(e)) {
+ kwargs.Set(key, doc.value());
+ return false;
+ }
+ if (e->IsInstance<tir::VarNode>()) {
+ d->Define(e, frame, [=]() { return d->AsDoc<IdDoc>(buffer, p)->Attr(key); });
+ return true;
+ }
+ kwargs.Set(key, d->AsDoc<ExprDoc>(e, p));
+ return false;
+ };
+ // Prints every element of `array` and stores the results as a tuple under `key`.
+ auto array_out_line_var_def = [&](const Array<PrimExpr>& array, const ObjectPath& p,
+ const String& key) {
+ int n = array.size();
+ Array<ExprDoc> results;
+ results.reserve(n);
+ for (int i = 0; i < n; ++i) {
+ PrimExpr s = array[i];
+ ObjectPath s_path = p->ArrayIndex(i);
+ // Printing an undefined Var goes through the Var dispatch (PrintVar),
+ // which emits an out-of-line `T.var(...)` definition for it.
+ results.push_back(d->AsDoc<ExprDoc>(s, s_path));
+ }
+ kwargs.Set(key, TupleDoc(results));
+ };
+ // Step 1. Handle `buffer.shape`
+ array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
+ // Step 2. Handle `buffer.dtype`
+ if (buffer->dtype != Default::BufferDType()) {
+ kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype));
+ }
+ // Step 3. Handle `buffer.data`
+ implicit_var_def(buffer->data, p->Attr("data"), "data");
+ // Step 4. Handle `buffer.strides`
+ if (!buffer->strides.empty()) {
+ array_out_line_var_def(buffer->strides, p->Attr("strides"), "strides");
+ }
+ // Step 5. Handle `buffer.elem_offset`
+ // A constant offset is printed only when non-zero. A symbolic offset may be
+ // implicitly defined, and in that case `offset_factor` is forced out in Step 8.
+ bool needs_print_factor = false;
+ if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
+ if (int_imm->value != 0) {
+ kwargs.Set("elem_offset", d->AsDoc<ExprDoc>(buffer->elem_offset, p->Attr("elem_offset")));
+ }
+ } else {
+ needs_print_factor =
+ implicit_var_def(buffer->elem_offset, p->Attr("elem_offset"), "elem_offset");
+ }
+ // Step 6. Handle `buffer.scope`
+ {
+ String scope = buffer.scope();
+ if (scope != "global") {
+ kwargs.Set("scope", LiteralDoc::Str(scope));
+ }
+ }
+ // Step 7. Handle `buffer.data_alignment`
+ if (buffer->data_alignment != runtime::kAllocAlignment) {
+ kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment));
+ }
+ // Step 8. Handle `buffer.offset_factor`
+ if (needs_print_factor || buffer->offset_factor != 1) {
+ kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor));
+ }
+ // Step 9. Handle `buffer.buffer_type`
+ // Any non-default buffer type is printed as the string "auto".
+ if (buffer->buffer_type != tir::BufferType::kDefault) {
+ kwargs.Set("type", LiteralDoc::Str("auto"));
+ }
+ // Step 10. Handle `buffer.axis_separator`
+ if (!buffer->axis_separators.empty()) {
+ kwargs.Set("axis_separators",
+ d->AsDoc<ExprDoc>(buffer->axis_separators, p->Attr("axis_separators")));
+ }
+ return kwargs;
+}
+
+// Builds `prefix(*args, shape, dtype, key=value, ...)` from an attribute map
+// produced by BufferAttrs. `shape` and `dtype` (if present) are appended
+// positionally after `args`. The remaining known attributes become keyword
+// arguments in a fixed order, and absent entries are skipped.
+ExprDoc BufferCall(const ExprDoc& prefix, const Map<String, ExprDoc>& attrs, Array<ExprDoc> args) {
+  static const char* const kPositional[] = {"shape", "dtype"};
+  static const char* const kKeyword[] = {"data",  "strides",       "elem_offset",
+                                         "scope", "align",         "offset_factor",
+                                         "type",  "axis_separators"};
+  for (const char* name : kPositional) {
+    Optional<ExprDoc> value = attrs.Get(name);
+    if (value.defined()) {
+      args.push_back(value.value());
+    }
+  }
+  Array<String> keys;
+  Array<ExprDoc> values;
+  for (const char* name : kKeyword) {
+    Optional<ExprDoc> value = attrs.Get(name);
+    if (value.defined()) {
+      keys.push_back(name);
+      values.push_back(value.value());
+    }
+  }
+  return prefix->Call(args, keys, values);
+}
+
+// Prints a buffer declaration `T.<method>(*args, shape, dtype, **attrs)`.
+// Attributes come from BufferAttrs, which may define vars in `frame`.
+ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
+                   const ObjectPath& p, const Frame& frame, const IRDocsifier& d) {
+  ExprDoc callee = TIR(d)->Attr(method);
+  Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
+  return BufferCall(callee, attrs, args);
+}
+
+// Prints one buffer index. A Ramp with an integer-constant stride becomes the
+// slice `base:base + lanes * stride` with a `:stride` step added only when the
+// stride is not 1. Any other index is printed as an ordinary expression.
+Doc BufferIndex(const PrimExpr& index, const ObjectPath& p, const IRDocsifier& d) {
+  const auto* ramp = index.as<tir::RampNode>();
+  const IntImmNode* const_stride = ramp ? ramp->stride.as<IntImmNode>() : nullptr;
+  if (const_stride == nullptr) {
+    return d->AsDoc<ExprDoc>(index, p);
+  }
+  ExprDoc begin = d->AsDoc<ExprDoc>(ramp->base, p->Attr("base"));
+  ExprDoc end = d->AsDoc<ExprDoc>(ramp->base + ramp->lanes * ramp->stride, p->Attr("lanes"));
+  Optional<ExprDoc> step = NullOpt;
+  if (const_stride->value != 1) {
+    step = d->AsDoc<ExprDoc>(ramp->stride, p->Attr("stride"));
+  }
+  return SliceDoc(begin, end, step);
+}
+
+// Prints `buffer[i0, i1, ...]`. Each index goes through BufferIndex, so
+// constant-stride Ramps are rendered as slices.
+ExprDoc BufferIndices(const tir::Buffer& buffer, const Array<PrimExpr>& indices,
+                      const ObjectPath& p, const IRDocsifier& d) {
+  ObjectPath indices_p = p->Attr("indices");
+  Array<Doc> index_docs;
+  index_docs.reserve(indices.size());
+  for (int i = 0, n = indices.size(); i < n; ++i) {
+    index_docs.push_back(BufferIndex(indices[i], indices_p->ArrayIndex(i), d));
+  }
+  ExprDoc buffer_doc = d->AsDoc<ExprDoc>(buffer, p->Attr("buffer"));
+  return buffer_doc[index_docs];
+}
+
+// Prints a BufferRegion as `buffer[...]`. A range of extent 1 prints as its
+// start index, and any other range prints as the slice `min:min + extent`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::BufferRegion>(
+        "", [](tir::BufferRegion buffer_region, ObjectPath p, IRDocsifier d) -> Doc {
+          ExprDoc buffer_doc = d->AsDoc<ExprDoc>(buffer_region->buffer, p->Attr("buffer"));
+          ObjectPath region_p = p->Attr("region");
+          int n_dims = buffer_region->region.size();
+          Array<Doc> subscripts;
+          subscripts.reserve(n_dims);
+          for (int dim = 0; dim < n_dims; ++dim) {
+            Range range = buffer_region->region[dim];
+            ObjectPath range_p = region_p->ArrayIndex(dim);
+            ExprDoc start = d->AsDoc<ExprDoc>(range->min, range_p->Attr("min"));
+            if (!tir::is_one(range->extent)) {
+              ExprDoc stop = d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent"));
+              subscripts.push_back(SliceDoc(start, stop, NullOpt));
+              continue;
+            }
+            subscripts.push_back(start);
+          }
+          return buffer_doc[subscripts];
+        });
+
+// Prints a BufferStore as the assignment `buffer[indices] = value`.
+// Both sides can have side effects on the docsifier (printing an undefined Var
+// appends an out-of-line definition to a frame). C++ leaves the evaluation order
+// of function arguments unspecified, so the two sides are sequenced explicitly:
+// left-hand side first, then the value. This keeps output stable across compilers.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::BufferStore>(  //
+        "", [](tir::BufferStore store, ObjectPath p, IRDocsifier d) -> Doc {
+          ExprDoc lhs = BufferIndices(store->buffer, store->indices, p, d);
+          ExprDoc rhs = d->AsDoc<ExprDoc>(store->value, p->Attr("value"));
+          return AssignDoc(lhs, rhs, NullOpt);
+        });
+
+// Prints a BufferLoad as the subscript expression `buffer[indices]`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::BufferLoad>(  //
+        "", [](tir::BufferLoad ld, ObjectPath ld_p, IRDocsifier d) -> Doc {
+          ExprDoc subscript = BufferIndices(ld->buffer, ld->indices, ld_p, d);
+          return subscript;
+        });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
new file mode 100644
index 0000000000..f9b4eb6214
--- /dev/null
+++ b/src/script/printer/tir/expr.cc
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/tir/builtin.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Prints a reference to `var`. If `var` is not defined yet but FindLowestVarDef
+// finds a frame that can own it, an out-of-line definition is appended to that
+// frame first. Pointer-typed vars become `v = <pointer type annotation>`; all
+// others become `v = T.var(dtype)`. Aborts if `var` still has no doc afterwards.
+Doc PrintVar(const tir::Var& var, const ObjectPath& p, const IRDocsifier& d) {
+  if (!d->IsVarDefined(var)) {
+    Optional<Frame> owner = FindLowestVarDef(var, d);
+    if (owner.defined()) {
+      Frame frame = owner.value();
+      ExprDoc lhs = DefineVar(var, frame, d);
+      ExprDoc rhs{nullptr};
+      if (const auto* ptr_type = var->type_annotation.as<PointerTypeNode>()) {
+        ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
+        rhs = d->AsDoc<ExprDoc>(var->type_annotation, p->Attr("type_annotation"));
+      } else {
+        rhs = TIR(d)->Attr("var")->Call({LiteralDoc::DataType(var->dtype)});
+      }
+      frame->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+    }
+  }
+  Optional<ExprDoc> doc = d->GetVarDoc(var);
+  if (!doc.defined()) {
+    LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var;
+  }
+  return doc.value();
+}
+
+// tir::Var references are printed through PrintVar.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)  //
+    .set_dispatch<tir::Var>("", [](tir::Var v, ObjectPath v_p, IRDocsifier d) -> Doc {
+      Doc printed = PrintVar(v, v_p, d);
+      return printed;
+    });
+
+// tir::SizeVar references are printed through PrintVar, exactly like tir::Var.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)  //
+    .set_dispatch<tir::SizeVar>("", [](tir::SizeVar sv, ObjectPath sv_p, IRDocsifier d) -> Doc {
+      Doc printed = PrintVar(sv, sv_p, d);
+      return printed;
+    });
+
+// Prints an IterVar as `T.iter_var(var, dom, "<iter type>", "<thread tag>")`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::IterVar>("", [](tir::IterVar iv, ObjectPath iv_p, IRDocsifier d) -> Doc {
+      Array<ExprDoc> args;
+      args.push_back(d->AsDoc<ExprDoc>(iv->var, iv_p->Attr("var")));
+      args.push_back(d->AsDoc<ExprDoc>(iv->dom, iv_p->Attr("dom")));
+      args.push_back(LiteralDoc::Str(IterVarType2String(iv->iter_type)));
+      args.push_back(LiteralDoc::Str(iv->thread_tag));
+      return TIR(d)->Attr("iter_var")->Call(args);
+    });
+
+// Prints a reference to a buffer. If the buffer is not defined yet but some
+// frame can own it, `buf = T.buffer_decl(...)` is first appended to that frame.
+// Aborts if the buffer still has no doc afterwards.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)  //
+    .set_dispatch<tir::Buffer>("", [](tir::Buffer buf, ObjectPath buf_p, IRDocsifier d) -> Doc {
+      Optional<Frame> owner = NullOpt;
+      if (!d->IsVarDefined(buf)) {
+        owner = FindLowestVarDef(buf, d);
+      }
+      if (owner.defined()) {
+        Frame frame = owner.value();
+        ExprDoc lhs = DefineBuffer(buf, frame, d);
+        // TODO(@junrushao): the name `buffer_decl` is confusing
+        ExprDoc rhs = BufferDecl(buf, "buffer_decl", {}, buf_p, frame, d);
+        frame->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+      }
+      Optional<ExprDoc> doc = d->GetVarDoc(buf);
+      if (doc.defined()) {
+        return doc.value();
+      }
+      LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buf;
+    });
+
+// Logical negation. A non-literal operand uses the kNot operator form. A literal
+// operand is printed as the explicit call `T.Not(a)` instead.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Not>("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc {
+      ExprDoc operand = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
+      const bool is_literal = operand->IsInstance<LiteralDocNode>();
+      if (!is_literal) {
+        return OperationDoc(OperationDocNode::Kind::kNot, {operand});
+      }
+      return TIR(d)->Attr("Not")->Call({operand});
+    });
+
+// A StringImm is printed as the doc of its underlying string value.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::StringImm>("", [](tir::StringImm str, ObjectPath str_p, IRDocsifier d) -> Doc {
+      ObjectPath value_p = str_p->Attr("value");
+      return d->AsDoc<ExprDoc>(str->value, value_p);
+    });
+
+// Prints a Cast as `T.Cast(dtype, value)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Cast>("", [](tir::Cast c, ObjectPath c_p, IRDocsifier d) -> Doc {
+      ExprDoc target_dtype = LiteralDoc::DataType(c->dtype);
+      ExprDoc operand = d->AsDoc<ExprDoc>(c->value, c_p->Attr("value"));
+      ExprDoc callee = TIR(d)->Attr("Cast");
+      return callee->Call({target_dtype, operand});
+    });
+
+// Prints a Select as `T.Select(condition, true_value, false_value)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Select>("", [](tir::Select sel, ObjectPath sel_p, IRDocsifier d) -> Doc {
+      ExprDoc cond = d->AsDoc<ExprDoc>(sel->condition, sel_p->Attr("condition"));
+      ExprDoc if_true = d->AsDoc<ExprDoc>(sel->true_value, sel_p->Attr("true_value"));
+      ExprDoc if_false = d->AsDoc<ExprDoc>(sel->false_value, sel_p->Attr("false_value"));
+      return TIR(d)->Attr("Select")->Call({cond, if_true, if_false});
+    });
+
+// Prints a Ramp as `T.Ramp(base, stride, lanes)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Ramp>("", [](tir::Ramp r, ObjectPath r_p, IRDocsifier d) -> Doc {
+      ExprDoc base = d->AsDoc<ExprDoc>(r->base, r_p->Attr("base"));
+      ExprDoc stride = d->AsDoc<ExprDoc>(r->stride, r_p->Attr("stride"));
+      ExprDoc lanes = LiteralDoc::Int(r->lanes);
+      return TIR(d)->Attr("Ramp")->Call({base, stride, lanes});
+    });
+
+// Prints a Broadcast as `T.Broadcast(value, lanes)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Broadcast>("", [](tir::Broadcast b, ObjectPath b_p, IRDocsifier d) -> Doc {
+      ExprDoc scalar = d->AsDoc<ExprDoc>(b->value, b_p->Attr("value"));
+      ExprDoc lanes = LiteralDoc::Int(b->lanes);
+      return TIR(d)->Attr("Broadcast")->Call({scalar, lanes});
+    });
+
+// Prints a Shuffle as `T.Shuffle(vectors, indices)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Shuffle>(  //
+        "", [](tir::Shuffle shf, ObjectPath shf_p, IRDocsifier d) -> Doc {
+          ExprDoc vectors = d->AsDoc<ExprDoc>(shf->vectors, shf_p->Attr("vectors"));
+          ExprDoc indices = d->AsDoc<ExprDoc>(shf->indices, shf_p->Attr("indices"));
+          return TIR(d)->Attr("Shuffle")->Call({vectors, indices});
+        });
+
+// Prints a CommReducer as `T.comm_reducer(<lambda>, identity_element)`.
+// The lambda's parameters are all `lhs` vars followed by all `rhs` vars. Its
+// body is the single result expression, or a tuple when there are several.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::CommReducer>( //
+ "", [](tir::CommReducer r, ObjectPath p, IRDocsifier d) -> Doc {
+ ICHECK_EQ(r->lhs.size(), r->rhs.size());
+ LambdaDoc lambda{nullptr};
+ {
+ // The lambda parameters are defined in a temporary frame. The frame is
+ // closed at the end of this scope, i.e. before the identity element is
+ // printed, so the parameters are only visible inside the lambda.
+ With<TIRFrame> f(d, r);
+ int n_vars = r->lhs.size();
+ Array<IdDoc> vars;
+ vars.reserve(n_vars + n_vars);
+ for (int i = 0; i < n_vars; ++i) {
+ vars.push_back(DefineVar(r->lhs[i], *f, d));
+ }
+ for (int i = 0; i < n_vars; ++i) {
+ vars.push_back(DefineVar(r->rhs[i], *f, d));
+ }
+ int n_results = r->result.size();
+ Array<ExprDoc> results;
+ results.reserve(n_results);
+ for (int i = 0; i < n_results; ++i) {
+ results.push_back(d->AsDoc<ExprDoc>(r->result[i], p->Attr("result")->ArrayIndex(i)));
+ }
+ // Single result: bare expression body. Several results: a tuple body.
+ if (results.size() == 1) {
+ lambda = LambdaDoc(vars, results[0]);
+ } else {
+ lambda = LambdaDoc(vars, TupleDoc(results));
+ }
+ }
+ ExprDoc id = d->AsDoc<ExprDoc>(r->identity_element, p->Attr("identity_element"));
+ return TIR(d)->Attr("comm_reducer")->Call({lambda, id});
+ });
+
+// Prints a Let expression as `T.let(var, value, body)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Let>("", [](tir::Let let_expr, ObjectPath let_p, IRDocsifier d) -> Doc {
+      ExprDoc bound = d->AsDoc<ExprDoc>(let_expr->var, let_p->Attr("var"));
+      ExprDoc value = d->AsDoc<ExprDoc>(let_expr->value, let_p->Attr("value"));
+      ExprDoc body = d->AsDoc<ExprDoc>(let_expr->body, let_p->Attr("body"));
+      return TIR(d)->Attr("let")->Call({bound, value, body});
+    });
+
+// Prints a tir::Call.
+// - Op callee: printed as `T.<name>(...)`, where <name> is the op's registered
+//   `TScriptPrinterName` attribute. An op without that attribute falls back to
+//   its registered op name and logs a warning. Previously such an op made the
+//   attribute-map lookup abort the whole print.
+// - GlobalVar callee: printed as a string literal of its name hint.
+// For the intrinsics listed below, the call's result dtype is added as the
+// first argument (or, for `tvm_struct_get`, as the last argument).
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Call>("", [](tir::Call call, ObjectPath p, IRDocsifier d) -> Doc {
+      static const OpAttrMap<tir::TScriptPrinterName>& op_names =
+          Op::GetAttrMap<tir::TScriptPrinterName>("TScriptPrinterName");
+      // Intrinsics whose result dtype is printed as the leading argument.
+      static const std::unordered_set<const Object*> dtype_first_arg = {
+          tir::builtin::reinterpret().get(),
+          tir::builtin::call_extern().get(),
+          tir::builtin::call_llvm_intrin().get(),       //
+          tir::builtin::call_llvm_pure_intrin().get(),  //
+          tir::builtin::call_pure_extern().get(),       //
+          tir::builtin::ptx_mma().get(),
+          tir::builtin::ptx_mma_sp().get(),
+          tir::builtin::ptx_ldmatrix().get(),
+          tir::builtin::ptx_cp_async().get(),
+          tir::builtin::mma_store().get(),
+          tir::builtin::mma_fill().get(),
+          tir::builtin::vectorlow().get(),
+          tir::builtin::vectorhigh().get(),
+          tir::builtin::vectorcombine().get(),
+          Op::Get("tir.type_annotation").get(),
+      };
+      // Intrinsics whose result dtype is printed as the trailing argument.
+      static const std::unordered_set<const Object*> dtype_last_arg = {
+          tir::builtin::tvm_struct_get().get(),
+      };
+      ExprDoc prefix{nullptr};
+      if (const auto* op = call->op.as<OpNode>()) {
+        Op op_ref = GetRef<Op>(op);
+        if (!op_names.count(op_ref)) {
+          LOG(WARNING) << "No TScriptPrinterName attribute registered for op " << op->name
+                       << "; printing it with its op name instead";
+        }
+        String name = op_names.get(op_ref, op->name);
+        prefix = TIR(d)->Attr(name);
+      } else if (const auto* gv = call->op.as<GlobalVarNode>()) {
+        prefix = LiteralDoc::Str(gv->name_hint);
+      } else {
+        LOG(FATAL) << "call: " << call;
+      }
+      Array<ExprDoc> args;
+      int n_args = call->args.size();
+      args.reserve(n_args + 1);
+      if (dtype_first_arg.count(call->op.get())) {
+        args.push_back(LiteralDoc::DataType(call->dtype));
+      }
+      for (int i = 0; i < n_args; ++i) {
+        args.push_back(d->AsDoc<ExprDoc>(call->args[i], p->Attr("args")->ArrayIndex(i)));
+      }
+      if (dtype_last_arg.count(call->op.get())) {
+        args.push_back(LiteralDoc::DataType(call->dtype));
+      }
+      return prefix->Call(args);
+    });
+
+// Prints tir::Any as the nullary call `T.Any()`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Any>("", [](tir::Any, ObjectPath, IRDocsifier d) -> Doc {
+      ExprDoc callee = TIR(d)->Attr("Any");
+      return callee->Call({});
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::Reduce>("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc {
+ LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r;
+ });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+ .set_dispatch<tir::ProducerLoad>(
+ "", [](tir::ProducerLoad load, ObjectPath p, IRDocsifier d) -> Doc {
+ LOG(FATAL) << "ValueError: ProducerLoad should never exist in TIR: " << load;
+ });
+
+// The deprecated tir::Load is rejected; BufferLoad replaces it.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Load>("", [](tir::Load legacy_load, ObjectPath, IRDocsifier) -> Doc {
+      LOG(FATAL) << "ValueError: Load has been deprecated for BufferLoad: " << legacy_load;
+    });
+
+// Registers the printer for a binary node that is always printed as the call
+// `T.<OpString>(a, b)`.
+// Neither macro's expansion ends with a semicolon; each use site supplies it.
+// Previously both the macro and the use site added one, leaving a stray empty
+// declaration at namespace scope (warned about under -Wextra-semi / -pedantic).
+#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString)                   \
+  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                 \
+      .set_dispatch<tir::NodeType>(                                          \
+          "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc {   \
+            ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));            \
+            ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));            \
+            return TIR(d)->Attr(OpString)->Call({a, b});                     \
+          })
+
+// Registers the printer for a binary node that normally uses the infix
+// operator `OpKind`. It falls back to `T.<OpString>(a, b)` when both operands
+// are literals.
+#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind)         \
+  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                          \
+      .set_dispatch<tir::NodeType>(                                                   \
+          "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc {            \
+            ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));                     \
+            ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));                     \
+            if (a->IsInstance<LiteralDocNode>() && b->IsInstance<LiteralDocNode>()) { \
+              return TIR(d)->Attr(OpString)->Call({a, b});                            \
+            }                                                                         \
+            return OperationDoc(OperationDocNode::Kind::OpKind, {a, b});              \
+          })
+
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd);
+TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr);
+
+TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod");
+TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min");
+TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max");
+
+#undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR
+#undef TVM_SCRIPT_PRINTER_DEF_BINARY
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc
new file mode 100644
index 0000000000..6a375935bd
--- /dev/null
+++ b/src/script/printer/tir/for_loop.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Prints a tir::For.
+//
+// Some nested chains collapse into one `for i, j, ... in T.grid(e0, e1, ...):`.
+// The chain must have two or more perfectly nested loops. Each loop must be
+// serial, start at 0 and carry no annotations, and its extent must not use the
+// loop vars of the enclosing loops already in the chain.
+//
+// Any other loop prints as `for i in <iter>(...)`. Here <iter> is `range` for a
+// serial loop without annotations, or one of `T.serial`, `T.parallel`,
+// `T.unroll`, `T.vectorized`, `T.thread_binding`. The `thread` and
+// `annotations` keyword arguments are added when applicable.
+//
+// Fix: the non-grid path now passes the body's own path (`<loop>.body`) to
+// AsDocBody. Previously it passed the loop's path, which was inconsistent with
+// the grid path and with the PrimFunc and Block printers.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::For>("", [](tir::For loop, ObjectPath p, IRDocsifier d) -> Doc {
+      // Step 1. Check syntactic sugar: `T.grid`
+      std::vector<const tir::ForNode*> grid;
+      std::unordered_set<const tir::VarNode*> grid_loop_vars;
+      // True iff `e` references a loop var already collected into the grid.
+      auto f_var_dep = [&grid_loop_vars](const PrimExpr& e) -> bool {
+        return tir::UsesVar(e, [&grid_loop_vars](const tir::VarNode* v) -> bool {  //
+          return grid_loop_vars.count(v);
+        });
+      };
+      for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
+        ICHECK(l->loop_var->dtype == l->min->dtype);
+        ICHECK(l->loop_var->dtype == l->extent->dtype);
+        if (l->kind != tir::ForKind::kSerial ||  //
+            !tir::is_zero(l->min) ||             //
+            !l->annotations.empty() ||           //
+            f_var_dep(l->extent)) {
+          break;
+        }
+        grid.push_back(l);
+        grid_loop_vars.insert(l->loop_var.get());
+      }
+      With<TIRFrame> f(d, loop);
+      // Step 2. Construct `T.grid`
+      if (grid.size() > 1) {
+        int n = grid.size();
+        Array<ExprDoc> lhs;
+        Array<ExprDoc> rhs;
+        lhs.reserve(n);
+        rhs.reserve(n);
+        // `grid_p` walks down the chain: on iteration i it is the path of grid[i].
+        ObjectPath grid_p = p;
+        for (int i = 0; i < n; ++i) {
+          const tir::ForNode* grid_loop = grid[i];
+          lhs.push_back(DefineVar(grid_loop->loop_var, *f, d));
+          rhs.push_back(d->AsDoc<ExprDoc>(grid_loop->extent, grid_p->Attr("extent")));
+          grid_p = grid_p->Attr("body");
+        }
+        // After the walk, `grid_p` is the path of the innermost grid loop's body.
+        AsDocBody(grid.back()->body, grid_p, (*f).get(), d);
+        return ForDoc(TupleDoc(lhs), TIR(d)->Attr("grid")->Call(rhs), (*f)->stmts);
+      }
+      // Step 3. If not `T.grid`, print loop kind accordingly
+      IdDoc lhs = DefineVar(loop->loop_var, *f, d);
+      Optional<ExprDoc> min = NullOpt;
+      Optional<ExprDoc> max = NullOpt;
+      Optional<ExprDoc> annotations = NullOpt;
+      Optional<ExprDoc> thread = NullOpt;
+      // The printed upper bound is `min + extent`; `min` is omitted when it is 0.
+      if (tir::is_zero(loop->min)) {
+        max = d->AsDoc<ExprDoc>(loop->extent, p->Attr("extent"));
+      } else {
+        min = d->AsDoc<ExprDoc>(loop->min, p->Attr("min"));
+        max = d->AsDoc<ExprDoc>(loop->min + loop->extent, p->Attr("extent"));
+      }
+      if (!loop->annotations.empty()) {
+        annotations = d->AsDoc<ExprDoc>(loop->annotations, p->Attr("annotations"));
+      }
+      ExprDoc prefix = TIR(d);
+      if (loop->kind == tir::ForKind::kSerial) {
+        // A serial loop uses the builtin `range` unless it must carry annotations.
+        if (loop->annotations.empty()) {
+          prefix = IdDoc("range");
+        } else {
+          prefix = prefix->Attr("serial");
+        }
+      } else if (loop->kind == tir::ForKind::kParallel) {
+        prefix = prefix->Attr("parallel");
+      } else if (loop->kind == tir::ForKind::kUnrolled) {
+        prefix = prefix->Attr("unroll");
+      } else if (loop->kind == tir::ForKind::kVectorized) {
+        prefix = prefix->Attr("vectorized");
+      } else if (loop->kind == tir::ForKind::kThreadBinding) {
+        prefix = prefix->Attr("thread_binding");
+        thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag);
+      } else {
+        LOG(FATAL) << "ValueError: Unknown ForKind: " << tir::ForKind2String(loop->kind);
+      }
+      Array<ExprDoc> args;
+      Array<String> kwargs_keys;
+      Array<ExprDoc> kwargs_values;
+      if (min.defined()) {
+        args.push_back(min.value());
+      }
+      if (max.defined()) {
+        args.push_back(max.value());
+      }
+      if (thread.defined()) {
+        kwargs_keys.push_back("thread");
+        kwargs_values.push_back(thread.value());
+      }
+      if (annotations.defined()) {
+        kwargs_keys.push_back("annotations");
+        kwargs_values.push_back(annotations.value());
+      }
+      ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
+      // The body lives at `<loop>.body`, matching the grid case above.
+      AsDocBody(loop->body, p->Attr("body"), (*f).get(), d);
+      return ForDoc(lhs, rhs, (*f)->stmts);
+    });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc
new file mode 100644
index 0000000000..d47a60209e
--- /dev/null
+++ b/src/script/printer/tir/function.cc
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Returns the global name that binds `f` in the IRModule being printed.
+// Returns "main" when no module is attached or `f` is not one of its functions.
+String FindFunctionName(const IRDocsifier& d, const tir::PrimFunc& f) {
+  if (d->mod.defined()) {
+    for (const auto& entry : d->mod.value()->functions) {
+      const GlobalVar& gv = entry.first;
+      if (entry.second.same_as(f)) {
+        return gv->name_hint;
+      }
+    }
+  }
+  return "main";
+}
+
+// Prints a PrimFunc as a `@T.prim_func`-decorated function, in this order:
+// - parameters, each with its type annotation;
+// - an optional `T.func_attr(...)` statement;
+// - one `buf = T.match_buffer(param, ...)` per buffer_map entry;
+// - the body.
+//
+// Fix: in the buffer_map step the printed parameter doc used to be named
+// `param`, shadowing the `tir::Var param`. As a result the buffer's ObjectPath
+// was keyed by a Doc (`MapValue(<doc>)`) instead of by the map's actual key.
+// The path is now keyed by the Var itself.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
+      d->SetCommonPrefix(func, [](const ObjectRef& obj) {
+        return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>();
+      });
+      With<TIRFrame> frame(d, func);
+      (*frame)->AddDispatchToken(d, "tir");
+      int n_args = func->params.size();
+      // Step 1. Handle `func->params`
+      Array<AssignDoc> args;
+      args.reserve(n_args);
+      for (int i = 0; i < n_args; ++i) {
+        tir::Var var = func->params[i];
+        ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
+        ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
+        args.push_back(AssignDoc(DefineVar(var, *frame, d), NullOpt, a));
+      }
+      // Step 2. Handle `func->attrs`
+      if (func->attrs.defined() && !func->attrs->dict.empty()) {
+        (*frame)->stmts.push_back(
+            ExprStmtDoc(TIR(d)
+                            ->Attr("func_attr")  //
+                            ->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
+      }
+      // Step 3. Handle `func->buffer_map`
+      for (int i = 0; i < n_args; ++i) {
+        tir::Var param = func->params[i];
+        if (!func->buffer_map.count(param)) {
+          continue;
+        }
+        tir::Buffer buffer = func->buffer_map[param];
+        // The doc of the i-th parameter, passed as the match_buffer source.
+        ExprDoc param_doc = args[i]->lhs;
+        // buffer_map is keyed by the tir::Var, so the path must be keyed by it too.
+        ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
+        // TODO(@junrushao): switch `lhs` and `rhs`
+        ExprDoc lhs = DefineBuffer(buffer, *frame, d);
+        ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *frame, d);
+        (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
+      }
+      // Step 4. Handle `func->body`
+      AsDocBody(func->body, p->Attr("body"), frame->get(), d);
+      // Print the return type before reading the frame's statements, so the
+      // order is fixed rather than depending on argument evaluation order.
+      ExprDoc ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
+      return FunctionDoc(
+          /*name=*/IdDoc(FindFunctionName(d, func)),
+          /*args=*/args,
+          /*decorators=*/{TIR(d)->Attr("prim_func")},
+          /*return_type=*/ret_type,
+          /*body=*/(*frame)->stmts);
+    });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc
new file mode 100644
index 0000000000..f4e3762fc0
--- /dev/null
+++ b/src/script/printer/tir/ir.cc
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/target/target.h>
+
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+// Registers TIRFrameNode as a node type with TVM's object registry.
+TVM_REGISTER_NODE_TYPE(TIRFrameNode);
+
+// Integer immediates print in one of three forms:
+// - the default integer dtype: a bare integer literal;
+// - bool: a boolean literal;
+// - any other dtype: the constructor call `T.<dtype>(value)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<IntImm>("", [](IntImm imm, ObjectPath p, IRDocsifier d) -> Doc {
+      const DataType& dtype = imm->dtype;
+      if (dtype != Default::IntDType() && dtype != DataType::Bool()) {
+        ExprDoc ctor = TIR(d)->Attr(runtime::DLDataType2String(dtype));
+        return ctor->Call({LiteralDoc::Int(imm->value)});
+      }
+      if (dtype == Default::IntDType()) {
+        return LiteralDoc::Int(imm->value);
+      }
+      return LiteralDoc::Boolean(imm->value);
+    });
+
+// Float immediates of the default float dtype print as bare literals. Any other
+// float dtype prints as `T.<dtype>(value)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<FloatImm>("", [](FloatImm imm, ObjectPath p, IRDocsifier d) -> Doc {
+      if (imm->dtype != Default::FloatDType()) {
+        ExprDoc ctor = TIR(d)->Attr(runtime::DLDataType2String(imm->dtype));
+        return ctor->Call({LiteralDoc::Float(imm->value)});
+      }
+      return LiteralDoc::Float(imm->value);
+    });
+
+// Prints a Range as `T.Range(min, extent)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<Range>("", [](Range rng, ObjectPath rng_p, IRDocsifier d) -> Doc {
+      ExprDoc lo = d->AsDoc<ExprDoc>(rng->min, rng_p->Attr("min"));
+      ExprDoc extent = d->AsDoc<ExprDoc>(rng->extent, rng_p->Attr("extent"));
+      return TIR(d)->Attr("Range")->Call({lo, extent});
+    });
+
+// Prints a PrimType as `T.<dtype>`, or as `T.void` for the void dtype.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<PrimType>("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc {
+      if (ty->dtype.is_void()) {
+        return TIR(d)->Attr("void");
+      }
+      return TIR(d)->Attr(runtime::DLDataType2String(ty->dtype));
+    });
+
+// Prints a PointerType as `T.Ptr(element_type)`. The storage scope is added as
+// a second, string argument when it is non-empty.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<PointerType>("", [](PointerType ty, ObjectPath p, IRDocsifier d) -> Doc {
+      Array<ExprDoc> args;
+      args.push_back(d->AsDoc<ExprDoc>(ty->element_type, p->Attr("element_type")));
+      if (!ty->storage_scope.empty()) {
+        args.push_back(LiteralDoc::Str(ty->storage_scope));
+      }
+      return TIR(d)->Attr("Ptr")->Call(args);
+    });
+
+// Prints a TupleType as `T.Tuple(field0, field1, ...)`. An empty tuple type
+// prints as `None`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<TupleType>("", [](TupleType ty, ObjectPath p, IRDocsifier d) -> Doc {
+      if (!ty->fields.empty()) {
+        ListDoc field_docs = d->AsDoc<ListDoc>(ty->fields, p->Attr("fields"));
+        return TIR(d)->Attr("Tuple")->Call(field_docs->elements);
+      }
+      return LiteralDoc::None();
+    });
+
+// Prints a Target as `T.target(<exported config dict>)`.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<Target>("", [](Target tgt, ObjectPath tgt_p, IRDocsifier d) -> Doc {
+      Map<String, ObjectRef> exported = tgt->Export();
+      ExprDoc config_doc = d->AsDoc<ExprDoc>(exported, tgt_p);
+      return TIR(d)->Attr("target")->Call({config_doc});
+    });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
new file mode 100644
index 0000000000..03e5657d24
--- /dev/null
+++ b/src/script/printer/tir/stmt.cc
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cmath>
+
+#include "../../../tir/transforms/ir_utils.h"
+#include "./utils.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Emit a scope-introducing statement in either concise or scoped form.
+ * \param lhs Optional name bound by the scope.
+ * \param rhs The expression that introduces the scope.
+ * \param stmts The body statements; when concise, the header is inserted at the front in place.
+ * \param concise_scoping Whether the concise (flattened) form may be used.
+ * \return A StmtBlockDoc (header followed by body) when concise, otherwise a ScopeDoc.
+ */
+Doc DoConciseScoping(const Optional<ExprDoc>& lhs, const ExprDoc& rhs, Array<StmtDoc>* stmts,
+                     bool concise_scoping) {
+  if (!concise_scoping) {
+    return ScopeDoc(lhs, rhs, *stmts);
+  }
+  // Concise: `lhs = rhs` (or bare `rhs`) becomes the first statement of the flattened body.
+  StmtDoc header = lhs.defined() ? StmtDoc(AssignDoc(lhs.value(), rhs, NullOpt))
+                                 : StmtDoc(ExprStmtDoc(rhs));
+  stmts->insert(stmts->begin(), header);
+  return StmtBlockDoc(*stmts);
+}
+
+/*!
+ * \brief Whether the statement currently being printed may use concise scoping.
+ * \param d The IRDocsifier; its innermost frame must be a TIRFrame.
+ * \return The innermost TIRFrame's `allow_concise_scoping` flag.
+ */
+bool AllowConciseScoping(const IRDocsifier& d) {
+  ICHECK(!d->frames.empty());
+  const auto* tir_frame = d->frames.back().as<TIRFrameNode>();
+  if (tir_frame == nullptr) {
+    // Only TIR frames carry the flag; printing outside one is unsupported.
+    LOG(FATAL) << "NotImplementedError: fragment printing";
+  }
+  return tir_frame->allow_concise_scoping;
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
+      ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
+      // A Call is emitted directly as an expression statement; any other value is
+      // wrapped in `evaluate(...)` under the TIR prefix.
+      if (!eval->value->IsInstance<tir::CallNode>()) {
+        value = TIR(d)->Attr("evaluate")->Call({value});
+      }
+      return ExprStmtDoc(value);
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
+      bool concise = AllowConciseScoping(d);
+      // The bound value is printed before the let's own frame is opened.
+      ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
+      With<TIRFrame> f(d, stmt);
+      // Reuse the var's existing doc when it is already defined; otherwise define it
+      // in the let's frame.
+      ExprDoc lhs{nullptr};
+      if (d->IsVarDefined(stmt->var)) {
+        lhs = d->GetVarDoc(stmt->var).value();
+      } else {
+        lhs = DefineVar(stmt->var, *f, d);
+      }
+      AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+      Array<StmtDoc>* stmts = &(*f)->stmts;
+      if (!concise) {
+        // Scoped form: `let(lhs, rhs)` wrapping the body; no type annotation is printed.
+        ExprDoc let_call = TIR(d)->Attr("let")->Call({lhs, rhs});
+        return ScopeDoc(NullOpt, let_call, *stmts);
+      }
+      // Concise form: an AssignDoc `lhs = rhs` (with the var's type annotation) is
+      // prepended to the body.
+      Type type = stmt->var->type_annotation;
+      Optional<ExprDoc> type_doc =
+          d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
+      // An empty TupleType annotation is dropped rather than printed.
+      const auto* tuple_type = type.as<TupleTypeNode>();
+      if (tuple_type != nullptr && tuple_type->fields.empty()) {
+        type_doc = NullOpt;
+      }
+      stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
+      return StmtBlockDoc(*stmts);
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::AssertStmt>(
+        "", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
+          ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
+          With<TIRFrame> f(d, stmt);
+          AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+          Array<StmtDoc>* stmts = &(*f)->stmts;
+          if (!concise) {
+            // Scoped form: `Assert(cond, msg)` wrapping the body.
+            return ScopeDoc(NullOpt, TIR(d)->Attr("Assert")->Call({cond, msg}), *stmts);
+          }
+          // Concise form: an AssertDoc prepended to the flattened body.
+          stmts->insert(stmts->begin(), AssertDoc(cond, msg));
+          return StmtBlockDoc(*stmts);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::While>("", [](tir::While loop, ObjectPath path, IRDocsifier d) -> Doc {
+      // The condition is printed in the enclosing frame; the body gets its own frame.
+      ExprDoc condition = d->AsDoc<ExprDoc>(loop->condition, path->Attr("condition"));
+      With<TIRFrame> body_frame(d, loop);
+      AsDocBody(loop->body, path->Attr("body"), body_frame->get(), d);
+      Array<StmtDoc> body = (*body_frame)->stmts;
+      return WhileDoc(condition, body);
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::DeclBuffer>(  //
+        "", [](tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          // The `decl_buffer(...)` expression is built against the enclosing frame,
+          // before the buffer's own scope is opened.
+          Frame outer = d->frames.back();
+          ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), outer, d);
+          With<TIRFrame> scope(d, stmt);
+          // The buffer name itself is bound inside the new scope.
+          ExprDoc lhs = DefineBuffer(stmt->buffer, *scope, d);
+          AsDocBody(stmt->body, p->Attr("body"), scope->get(), d);
+          return DoConciseScoping(lhs, rhs, &(*scope)->stmts, concise);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::IfThenElse>(  //
+        "", [](tir::IfThenElse stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
+          // Print one branch inside a frame of its own and return the collected statements.
+          auto print_branch = [&d](const tir::Stmt& branch, const ObjectPath& branch_p) {
+            With<TIRFrame> f(d, branch);
+            AsDocBody(branch, branch_p, f->get(), d);
+            return (*f)->stmts;
+          };
+          Array<StmtDoc> then_branch;
+          Array<StmtDoc> else_branch;
+          if (stmt->then_case.defined()) {
+            then_branch = print_branch(stmt->then_case, p->Attr("then_case"));
+          }
+          if (stmt->else_case.defined()) {
+            else_branch = print_branch(stmt->else_case.value(), p->Attr("else_case"));
+          }
+          return IfDoc(cond, then_branch, else_branch);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::SeqStmt>("", [](tir::SeqStmt seq, ObjectPath p, IRDocsifier d) -> Doc {
+      // TODO(@junrushao): revisit for fragment printing
+      // The sequence is printed into a fresh frame and returned as a flat statement block.
+      With<TIRFrame> frame(d, seq);
+      AsDocBody(seq, p, frame->get(), d);
+      Array<StmtDoc> body = (*frame)->stmts;
+      return StmtBlockDoc(body);
+    });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Prefetch>(  //
+        "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          // Printed as the expression statement `prefetch(buffer, bounds)`.
+          ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer"));
+          ExprDoc bounds = d->AsDoc<ExprDoc>(stmt->bounds, p->Attr("bounds"));
+          ExprDoc call = TIR(d)->Attr("prefetch")->Call({buffer, bounds});
+          return ExprStmtDoc(call);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Allocate>(  //
+        "", [](tir::Allocate stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
+          // Positional arguments: extents, dtype, storage scope, and then the condition
+          // unless it satisfies tir::is_one.
+          Array<ExprDoc> args{
+              d->AsDoc<ExprDoc>(stmt->extents, p->Attr("extents")),
+              LiteralDoc::DataType(stmt->dtype),
+              LiteralDoc::Str(storage_scope),
+          };
+          if (!tir::is_one(stmt->condition)) {
+            args.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition")));
+          }
+          // Annotations are passed as a keyword argument, and only when non-empty.
+          Array<String> kwargs_keys;
+          Array<ExprDoc> kwargs_values;
+          if (!stmt->annotations.empty()) {
+            kwargs_keys.push_back("annotations");
+            kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->annotations, p->Attr("annotations")));
+          }
+          // The buffer var is bound in the enclosing frame, before the allocation's own
+          // frame is opened.
+          ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d);
+          With<TIRFrame> f(d, stmt);
+          ExprDoc rhs = TIR(d)->Attr("allocate")->Call(args, kwargs_keys, kwargs_values);
+          AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+          return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
+        });
+
+/*!
+ * \brief Decode an IEEE-754 binary16 (half precision) bit pattern into a double.
+ * \param bits The raw 16-bit encoding.
+ * \return The encoded value, including signed zero, subnormals, infinity and NaN.
+ */
+static double HalfBitsToDouble(uint16_t bits) {
+  const bool negative = (bits & 0x8000) != 0;
+  const int exponent = (bits >> 10) & 0x1F;
+  const int mantissa = bits & 0x3FF;
+  double magnitude = 0.0;
+  if (exponent == 0) {
+    // Zero or subnormal: mantissa * 2^-24.
+    magnitude = std::ldexp(static_cast<double>(mantissa), -24);
+  } else if (exponent == 0x1F) {
+    // All-ones exponent: infinity when the mantissa is zero, NaN otherwise.
+    magnitude = (mantissa == 0) ? HUGE_VAL : std::nan("");
+  } else {
+    // Normal: 1.mantissa * 2^(exponent - 15) == (0x400 | mantissa) * 2^(exponent - 25).
+    magnitude = std::ldexp(static_cast<double>(mantissa | 0x400), exponent - 25);
+  }
+  return negative ? -magnitude : magnitude;
+}
+
+/*!
+ * \brief Print the leading elements of a constant NDArray as a flat list literal.
+ * \tparam T The C type used to read each element from the array's buffer
+ *         (float16 arrays are read through a 16-bit integer type).
+ * \param arr The array to print.
+ * \return A ListDoc holding at most NUM_PRINT elements, in storage order.
+ * \note NOTE(review): elements are read directly from `arr->data`, so this assumes the
+ *       array is host-accessible and compact -- confirm for non-CPU constants.
+ */
+template <typename T>
+ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) {
+  // FIXME(@junrushao): this is a hack and can be wrong in most of the cases
+  // (the output is truncated, so large constants cannot round-trip).
+  constexpr int64_t NUM_PRINT = 200;
+  // int64_t: the product of the extents can exceed the range of `int`.
+  int64_t tot_dim = 1;
+  for (int i = 0; i < arr->ndim; ++i) {
+    tot_dim *= arr->shape[i];
+  }
+  // Print exactly min(tot_dim, NUM_PRINT) elements.
+  const int64_t num_print = tot_dim < NUM_PRINT ? tot_dim : NUM_PRINT;
+  // Per DLPack, the first element lives at `data + byte_offset`.
+  const T* data_ptr =
+      reinterpret_cast<const T*>(static_cast<const char*>(arr->data) + arr->byte_offset);
+  runtime::DataType dtype = arr.DataType();
+  // float16 payloads arrive as raw 16-bit integers and must be decoded explicitly;
+  // printing the integer bit pattern as a float would produce wrong values.
+  const bool is_half = dtype.is_float() && dtype.bits() == 16;
+  Array<ExprDoc> result;
+  result.reserve(num_print);
+  for (int64_t i = 0; i < num_print; ++i) {
+    if (is_half) {
+      result.push_back(LiteralDoc::Float(HalfBitsToDouble(static_cast<uint16_t>(data_ptr[i]))));
+    } else if (dtype.is_float()) {
+      result.push_back(LiteralDoc::Float(data_ptr[i]));
+    } else {
+      result.push_back(LiteralDoc::Int(data_ptr[i]));
+    }
+  }
+  return ListDoc(result);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::AllocateConst>(
+        "", [](tir::AllocateConst stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          // NOTE(review): the storage scope is not printed; the call is presumably kept for
+          // the pointer-type validation inside GetPtrStorageScope -- confirm.
+          String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
+          // Select the element C type that matches the constant's dtype, then print the data.
+          ExprDoc data_doc{nullptr};
+          const int bits = stmt->dtype.bits();
+          if (stmt->dtype.is_int()) {
+            switch (bits) {
+              case 8:
+                data_doc = PrintNDArray<int8_t>(stmt->data.value());
+                break;
+              case 16:
+                data_doc = PrintNDArray<int16_t>(stmt->data.value());
+                break;
+              case 32:
+                data_doc = PrintNDArray<int32_t>(stmt->data.value());
+                break;
+              case 64:
+                data_doc = PrintNDArray<int64_t>(stmt->data.value());
+                break;
+              default:
+                LOG(FATAL) << "DataType not supported";
+            }
+          } else if (stmt->dtype.is_uint()) {
+            switch (bits) {
+              case 8:
+                data_doc = PrintNDArray<uint8_t>(stmt->data.value());
+                break;
+              case 16:
+                data_doc = PrintNDArray<uint16_t>(stmt->data.value());
+                break;
+              case 32:
+                data_doc = PrintNDArray<uint32_t>(stmt->data.value());
+                break;
+              case 64:
+                data_doc = PrintNDArray<uint64_t>(stmt->data.value());
+                break;
+              default:
+                LOG(FATAL) << "DataType not supported";
+            }
+          } else if (stmt->dtype.is_float()) {
+            switch (bits) {
+              case 16:
+                // float16 elements are read through int16_t (no native half type here).
+                data_doc = PrintNDArray<int16_t>(stmt->data.value());
+                break;
+              case 32:
+                data_doc = PrintNDArray<float>(stmt->data.value());
+                break;
+              case 64:
+                data_doc = PrintNDArray<double>(stmt->data.value());
+                break;
+              default:
+                LOG(FATAL) << "DataType not supported";
+            }
+          } else {
+            LOG(FATAL) << "DataType not supported";
+          }
+          // `allocate_const(data, dtype, extents)`, built before the body frame is opened.
+          Array<ExprDoc> args{data_doc, LiteralDoc::DataType(stmt->dtype),
+                              d->AsDoc<ExprDoc>(stmt->extents, p->Attr("extents"))};
+          ExprDoc rhs = TIR(d)->Attr("allocate_const")->Call(args);
+          With<TIRFrame> f(d, stmt);
+          ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d);
+          AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+          return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
+        });
+
+/*!
+ * \brief Build the `realize(buffer[min:min + extent, ...], value?, condition=...)` call
+ *        for a BufferRealize node.
+ * \param stmt The BufferRealize node.
+ * \param value Optional extra positional argument placed after the sliced buffer.
+ * \param p The object path of `stmt`.
+ * \param d The IRDocsifier.
+ * \return The `realize` call expression.
+ */
+ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDoc> value,  //
+                             ObjectPath p, IRDocsifier d) {
+  ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer"));
+  // Each realized Range becomes a slice `min:min + extent`.
+  ObjectPath bounds_p = p->Attr("bounds");
+  Array<Doc> slices;
+  slices.reserve(stmt->bounds.size());
+  int index = 0;
+  for (const Range& range : stmt->bounds) {
+    ObjectPath range_p = bounds_p->ArrayIndex(index++);
+    ExprDoc start = d->AsDoc<ExprDoc>(range->min, range_p->Attr("min"));
+    ExprDoc stop = d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent"));
+    slices.push_back(SliceDoc(start, stop, NullOpt));
+  }
+  Array<ExprDoc> args{buffer[slices]};
+  if (value.defined()) {
+    args.push_back(value.value());
+  }
+  // The condition is passed by keyword, and only when it does not satisfy tir::is_one.
+  Array<String> kwargs_keys;
+  Array<ExprDoc> kwargs_values;
+  if (!tir::is_one(stmt->condition)) {
+    kwargs_keys.push_back("condition");
+    kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition")));
+  }
+  return TIR(d)->Attr("realize")->Call(args, kwargs_keys, kwargs_values);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::BufferRealize>(  //
+        "", [](tir::BufferRealize stmt, ObjectPath path, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          // A standalone BufferRealize has no extra value argument.
+          ExprDoc realize = DocsifyBufferRealize(stmt.get(), NullOpt, path, d);
+          With<TIRFrame> frame(d, stmt);
+          AsDocBody(stmt->body, path->Attr("body"), frame->get(), d);
+          return DoConciseScoping(NullOpt, realize, &(*frame)->stmts, concise);
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::AttrStmt>(  //
+        "", [](tir::AttrStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
+          bool concise = AllowConciseScoping(d);
+          // `rhs` is the scope-introducing call; `body`/`body_p` is what is printed inside it.
+          Optional<ExprDoc> rhs = NullOpt;
+          tir::Stmt body = stmt->body;
+          ObjectPath body_p = p->Attr("body");
+          if (stmt->attr_key == "realize_scope") {
+            // A `realize_scope` attribute that directly wraps a BufferRealize of the same
+            // buffer is folded into one `realize` call, and the BufferRealize is skipped.
+            if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>()) {
+              if (realize->buffer.same_as(stmt->node)) {
+                rhs = DocsifyBufferRealize(realize,
+                                           /*value=*/d->AsDoc<ExprDoc>(stmt->value, p->Attr("value")),
+                                           /*p=*/p->Attr("body"), d);
+                body = realize->body;
+                body_p = body_p->Attr("body");
+              }
+            }
+          }
+          if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
+            if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
+              // The `env_thread` definition and the `launch_thread` use both print the
+              // IterVar's `var` field, so they share one object path.
+              ObjectPath var_p = p->Attr("node")->Attr("var");
+              if (!d->IsVarDefined(iter_var->var)) {
+                // The thread var is not defined in the current frame; its `env_thread`
+                // assignment goes into the frame chosen by FindLowestVarDef from the
+                // var's common-prefix path.
+                Optional<Frame> def_frame = FindLowestVarDef(iter_var->var, d);
+                ICHECK(def_frame.defined())
+                    << "ValueError: Cannot find a frame to define the thread variable: "
+                    << iter_var->var;
+                Frame f = def_frame.value();
+                DefineVar(iter_var->var, f, d);
+                f->stmts.push_back(AssignDoc(
+                    d->AsDoc<ExprDoc>(iter_var->var, var_p),
+                    TIR(d)->Attr("env_thread")->Call({LiteralDoc::Str(iter_var->thread_tag)}),
+                    NullOpt));
+              }
+              rhs = TIR(d)->Attr("launch_thread")->Call({
+                  d->AsDoc<ExprDoc>(iter_var->var, var_p),
+                  d->AsDoc<ExprDoc>(stmt->value, p->Attr("value")),
+              });
+            }
+          }
+          if (!rhs.defined()) {
+            // Generic form: `attr(node, attr_key, value)`.
+            rhs = TIR(d)->Attr("attr")->Call({
+                d->AsDoc<ExprDoc>(stmt->node, p->Attr("node")),
+                LiteralDoc::Str(stmt->attr_key),
+                d->AsDoc<ExprDoc>(stmt->value, p->Attr("value")),
+            });
+          }
+          With<TIRFrame> f(d, stmt);
+          AsDocBody(body, body_p, f->get(), d);
+          return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
+        });
+
+// The following node kinds must not reach the TIR printer; each dispatch aborts with an error.
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::ProducerRealize>(  //
+        "", [](tir::ProducerRealize node, ObjectPath, IRDocsifier) -> Doc {
+          LOG(FATAL) << "ValueError: ProducerRealize should never exist in TIR: " << node;
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::ProducerStore>(  //
+        "", [](tir::ProducerStore node, ObjectPath, IRDocsifier) -> Doc {
+          LOG(FATAL) << "ValueError: ProducerStore should never exist in TIR: " << node;
+        });
+
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::Store>(  //
+        "", [](tir::Store node, ObjectPath, IRDocsifier) -> Doc {
+          LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << node;
+        });
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h
new file mode 100644
index 0000000000..6cae378d0e
--- /dev/null
+++ b/src/script/printer/tir/utils.h
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_TIR_UTILS_H_
+#define TVM_SCRIPT_PRINTER_TIR_UTILS_H_
+
+#include <tvm/script/printer/ir_docsifier.h>
+#include <tvm/script/printer/printer.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/buffer.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*! \brief A printer frame for TIR fragment */
+/*
+ * One frame is pushed per printed TIR scope (via With<TIRFrame>); the frame's `stmts`
+ * collect the printed body of that scope.
+ */
+class TIRFrameNode : public FrameNode {
+ public:
+ /*! \brief The TIR fragment the frame corresponds to */
+ ObjectRef tir;
+ /*!
+ * \brief Whether or not the frame allows concise scoping.
+ * AsDocBody sets it to true only for the last statement of the frame's body.
+ */
+ bool allow_concise_scoping{false};
+
+ // Reflection: visits the base-class fields, then this frame's own fields.
+ void VisitAttrs(AttrVisitor* v) {
+ FrameNode::VisitAttrs(v);
+ v->Visit("tir", &tir);
+ v->Visit("allow_concise_scoping", &allow_concise_scoping);
+ }
+
+ static constexpr const char* _type_key = "script.printer.TIRFrame";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TIRFrameNode, FrameNode);
+};
+
+/*! \brief Managed reference to TIRFrameNode */
+class TIRFrame : public Frame {
+ public:
+  /*!
+   * \brief Create a frame for a TIR fragment.
+   * \param d The IRDocsifier that owns the frame.
+   * \param tir The TIR object this frame prints.
+   */
+  explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) {
+    ObjectPtr<TIRFrameNode> node = make_object<TIRFrameNode>();
+    node->tir = tir;
+    node->d = d.get();
+    node->stmts.clear();
+    data_ = std::move(node);
+  }
+
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode);
+};
+
+/*!
+ * \brief Creates the TIR common prefix, which is by default `T`
+ * \param d The IRDocsifier; a "tir" entry in its `ir_prefix` overrides the default.
+ * \return The prefix as an IdDoc.
+ */
+inline IdDoc TIR(const IRDocsifier& d) {
+  Optional<String> custom = d->ir_prefix.Get("tir");
+  return IdDoc(custom.defined() ? custom.value() : String("T"));
+}
+
+/*!
+ * \brief Defines a variable in the IRDocsifier at the given frame,
+ *        and returns the corresponding IdDoc
+ * \param var The variable to define
+ * \param frame The frame to define the variable in
+ * \param d The IRDocsifier
+ * \return The IdDoc corresponding to the variable
+ */
+inline IdDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) {
+  // Variables without a name hint are given the hint "v".
+  String name_hint = var->name_hint;
+  if (name_hint.empty()) {
+    name_hint = String("v");
+  }
+  return d->Define(var, frame, name_hint);
+}
+
+/*!
+ * \brief Defines a buffer in the IRDocsifier at the given frame,
+ *        and returns the corresponding IdDoc
+ * \param buffer The buffer to define
+ * \param frame The frame to define the buffer in
+ * \param d The IRDocsifier
+ * \return The IdDoc corresponding to the buffer
+ */
+inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const IRDocsifier& d) {
+  // Buffers without a name are given the hint "buffer".
+  String name_hint = buffer->name;
+  if (name_hint.empty()) {
+    name_hint = String("buffer");
+  }
+  return d->Define(buffer, frame, name_hint);
+}
+
+/*!
+ * \brief Recursively process the body statements of a TIR fragment represented by a frame
+ *
+ * A SeqStmt is flattened: each element is printed in order, and only the last one may use
+ * concise scoping. A printed StmtBlockDoc is spliced into the frame's statements rather
+ * than nested.
+ * \param stmt The body statement to process
+ * \param p The object path of `stmt`
+ * \param f The frame receiving the printed statements
+ * \param d The IRDocsifier
+ */
+inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) {
+  // Print one statement and append its output to the frame.
+  auto emit = [f, &d](const tir::Stmt& s, const ObjectPath& s_p) {
+    Doc doc = d->AsDoc(s, s_p);
+    if (const auto* block = doc.as<StmtBlockDocNode>()) {
+      f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end());
+    } else {
+      f->stmts.push_back(Downcast<StmtDoc>(doc));
+    }
+  };
+  const auto* seq_stmt = stmt.as<tir::SeqStmtNode>();
+  if (seq_stmt == nullptr) {
+    f->allow_concise_scoping = true;
+    emit(stmt, p);
+    return;
+  }
+  ObjectPath seq_p = p->Attr("seq");
+  Array<tir::Stmt> body = seq_stmt->seq;
+  const int n = body.size();
+  for (int i = 0; i < n; ++i) {
+    f->allow_concise_scoping = (i == n - 1);
+    emit(body[i], seq_p->ArrayIndex(i));
+  }
+}
+
+/*!
+ * \brief Find the frame on the stack in which a definition of `var` should be placed.
+ *
+ * `d->common_prefix` maps `var` to a path of IR nodes. The path is scanned from its last
+ * entry towards its first, and the first entry that is the fragment of a TIRFrame on the
+ * stack selects that frame. If several frames print the same fragment, the one nearest
+ * the bottom of the stack (smallest index) is used.
+ * \param var The var to be defined
+ * \param d The IRDocsifier
+ * \return The selected frame, or NullOpt when `var` has no common-prefix entry or no
+ *         path entry matches a TIR frame on the stack.
+ */
+inline Optional<Frame> FindLowestVarDef(const ObjectRef& var, const IRDocsifier& d) {
+  auto prefix_it = d->common_prefix.find(var.get());
+  if (prefix_it == d->common_prefix.end()) {
+    return NullOpt;
+  }
+  // Index TIR frames by the fragment they print. Scanning bottom-up with `emplace`
+  // keeps the first (bottom-most) frame when a fragment appears more than once.
+  std::unordered_map<const Object*, const FrameNode*> tir_to_frame;
+  tir_to_frame.reserve(d->frames.size());
+  for (const auto& frame : d->frames) {
+    if (const auto* tir_frame = frame.as<TIRFrameNode>()) {
+      tir_to_frame.emplace(tir_frame->tir.get(), tir_frame);
+    }
+  }
+  const std::vector<const Object*>& path = prefix_it->second;
+  for (auto it = path.rbegin(); it != path.rend(); ++it) {
+    auto hit = tir_to_frame.find(*it);
+    if (hit != tir_to_frame.end()) {
+      return GetRef<Frame>(hit->second);
+    }
+  }
+  return NullOpt;
+}
+
+/*!
+ * \brief Declare and define a buffer
+ * \param buffer The buffer to be defined
+ * \param method The method used to declare the buffer
+ * \param args The extra arguments used to declare the buffer
+ * \param p The object path
+ * \param frame The frame in which the buffer's dependent definitions are placed
+ * (NOTE(review): inferred from the DeclBuffer caller passing the enclosing frame -- confirm)
+ * \param d The IRDocsifier
+ * \return The ExprDoc corresponding to the buffer declaration
+ */
+ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array<ExprDoc>& args,
+ const ObjectPath& p, const Frame& frame, const IRDocsifier& d);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_TIR_UTILS_H_
diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc
deleted file mode 100644
index 43160c7f4b..0000000000
--- a/src/script/printer/traced_object_functor.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <tvm/script/printer/traced_object_functor.h>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-// Looks up the function registered for `type_index` under `token`. Returns nullptr when the
-// token has no table, the index is past the end of the table, or the slot is undefined.
-const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table,
- const String& token, uint32_t type_index) {
- auto it = table.find(token);
- if (it == table.end()) {
- return nullptr;
- }
- const std::vector<runtime::PackedFunc>& tab = it->second;
- if (type_index >= tab.size()) {
- return nullptr;
- }
- const PackedFunc* f = &tab[type_index];
- if (f->defined()) {
- return f;
- } else {
- return nullptr;
- }
-}
-
-// Returns the function for (token, type_index), falling back to kDefaultDispatchToken;
-// fails an ICHECK when neither is registered.
-const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table,
- const String& token, uint32_t type_index) {
- if (const runtime::PackedFunc* pf =
- GetDispatchFunctionForToken(dispatch_table, token, type_index)) {
- return *pf;
- } else if (const runtime::PackedFunc* pf =
- GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) {
- // Fallback to function with the default dispatch token
- return *pf;
- } else {
- ICHECK(false) << "ObjectFunctor calls un-registered function on type: "
- << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")";
- throw;
- }
-}
-
-// Stores `f` at (token, type_index), growing the token's table as needed; fails an ICHECK
-// when the slot is already occupied.
-void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index,
- runtime::PackedFunc f) {
- std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
- if (table->size() <= type_index) {
- table->resize(type_index + 1, nullptr);
- }
- runtime::PackedFunc& slot = (*table)[type_index];
- if (slot != nullptr) {
- ICHECK(false) << "Dispatch for type is already registered: "
- << runtime::Object::TypeIndex2Key(type_index);
- }
- slot = f;
-}
-
-// Clears the (token, type_index) slot; no-op when the index is past the table's end.
-// Note: operator[] creates an empty table for a previously unseen token.
-void RemoveDispatchFunction(DispatchTable* dispatch_table, const String& token,
- uint32_t type_index) {
- std::vector<runtime::PackedFunc>* table = &(*dispatch_table)[token];
- if (table->size() <= type_index) {
- return;
- }
- (*table)[type_index] = nullptr;
-}
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
deleted file mode 100644
index abe7ce5e9a..0000000000
--- a/src/script/printer/utils.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
-#define TVM_SCRIPT_PRINTER_UTILS_H_
-
-#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/ir_docsifier.h>
-
-#include <utility>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-// Converts each traced element with AsExprDoc and collects the results as DocType.
-template <typename DocType, typename NodeType>
-Array<DocType> AsDocArray(const TracedArray<NodeType>& refs, const IRDocsifier& ir_docsifier) {
- Array<DocType> result;
- for (auto ref : refs) {
- result.push_back(ir_docsifier->AsExprDoc(ref));
- }
- return result;
-}
-
-// Same as above, for a braced list of traced objects.
-template <typename DocType, typename NodeType>
-Array<DocType> AsDocArray(std::initializer_list<NodeType>&& refs, const IRDocsifier& ir_docsifier) {
- Array<DocType> result;
- for (auto& ref : refs) {
- result.push_back(ir_docsifier->AsExprDoc(ref));
- }
- return result;
-}
-
-// ExprDoc-typed conveniences over AsDocArray.
-template <typename RefType>
-Array<ExprDoc> AsExprDocArray(const TracedArray<RefType>& refs, const IRDocsifier& ir_docsifier) {
- return AsDocArray<ExprDoc>(refs, ir_docsifier);
-}
-
-template <typename RefType>
-Array<ExprDoc> AsExprDocArray(std::initializer_list<RefType>&& refs,
- const IRDocsifier& ir_docsifier) {
- return AsDocArray<ExprDoc>(std::move(refs), ir_docsifier);
-}
-
-// Keys become string literals, values go through AsExprDoc; the map's path is recorded
-// on the resulting DictDoc.
-inline DictDoc AsDictDoc(const TracedMap<String, ObjectRef>& dict,
- const IRDocsifier& ir_docsifier) {
- Array<ExprDoc> keys;
- Array<ExprDoc> values;
-
- for (auto p : dict) {
- keys.push_back(LiteralDoc::Str(p.first));
- values.push_back(ir_docsifier->AsExprDoc(p.second));
- }
-
- auto doc = DictDoc(keys, values);
- doc->source_paths.push_back(dict.GetPath());
- return doc;
-}
-
-// ListDoc / TupleDoc of the converted elements, tagged with the array's path.
-template <typename T>
-inline ListDoc AsListDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
- auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier));
- ret->source_paths.push_back(arr.GetPath());
- return ret;
-}
-
-template <typename T>
-inline TupleDoc AsTupleDoc(const TracedArray<T>& arr, const IRDocsifier& ir_docsifier) {
- auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier));
- ret->source_paths.push_back(arr.GetPath());
- return ret;
-}
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
-
-#endif // TVM_SCRIPT_PRINTER_UTILS_H_
diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc
deleted file mode 100644
index 62d8b2f66c..0000000000
--- a/src/script/printer/var_table.cc
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <tvm/node/object_path.h>
-#include <tvm/runtime/container/optional.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/script/printer/var_table.h>
-
-namespace tvm {
-namespace script {
-namespace printer {
-
-// Returns `name_hint` if unused, else the first of `name_hint_1`, `name_hint_2`, ... that is
-// not in `defined_names`; the returned name is inserted into `defined_names`.
-String GenerateUniqueName(const String& name_hint, std::unordered_set<String>* defined_names) {
- String name = name_hint;
- for (int i = 1; !defined_names->insert(name).second; ++i) {
- name = name_hint + "_" + std::to_string(i);
- }
- return name;
-}
-
-// Binds `obj` to a fresh unique IdDoc; the binding is removed when `frame` exits.
-// Note: the name is reserved before the duplicate check, so a failing ICHECK leaves it reserved.
-IdDoc VarTableNode::Define(const ObjectRef& obj, const String& name_hint,
- const ObjectPath& object_path, const Frame& frame) {
- String name = GenerateUniqueName(name_hint, &this->defined_names);
- DocFactory doc_factory = [name]() { return IdDoc(name); };
-
- auto result = obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
- ICHECK(result.second) << "Duplicated object: " << obj;
-
- IdDoc def_doc(name);
- def_doc->source_paths.push_back(object_path);
-
- frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
-
- return def_doc;
-}
-
-// Binds `obj` to a factory producing a non-IdDoc (no name is reserved). The factory is
-// invoked once here only to check that its result is not an IdDoc.
-void VarTableNode::DefineByDoc(const ObjectRef& obj, DocFactory doc_factory, const Frame& frame) {
- ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
-
- ICHECK(!doc_factory()->IsInstance<IdDocNode>())
- << "VarTableNode::Define cannot be used for variable that's mapped to IdDoc.";
-
- obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
-
- frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
-}
-
-// Returns a freshly built doc for `obj` tagged with `object_path`, or NullOpt if unbound.
-Optional<ExprDoc> VarTableNode::GetVarDoc(const ObjectRef& obj,
- const ObjectPath& object_path) const {
- auto it = obj2info.find(obj);
- if (it == obj2info.end()) {
- return NullOpt;
- }
- ExprDoc doc = it->second.doc_factory();
- doc->source_paths.push_back(object_path);
- return doc;
-}
-
-bool VarTableNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
-
-// Unbinds `obj` and releases its reserved name, if it had one.
-void VarTableNode::RemoveVar(const ObjectRef& obj) {
- auto it = obj2info.find(obj);
- ICHECK(it != obj2info.end()) << "No such object: " << obj;
-
- if (it->second.name.defined()) {
- defined_names.erase(it->second.name.value());
- }
- obj2info.erase(it);
-}
-
-VarTable::VarTable() { data_ = make_object<VarTableNode>(); }
-
-// Registrations exposing VarTable through the global function registry.
-TVM_REGISTER_NODE_TYPE(VarTableNode);
-TVM_REGISTER_GLOBAL("script.printer.VarTable").set_body_typed([]() { return VarTable(); });
-TVM_REGISTER_GLOBAL("script.printer.VarTableDefine")
- .set_body_method<VarTable, VarTableNode, IdDoc, const ObjectRef&, const String&,
- const ObjectPath&, const Frame&>(&VarTableNode::Define);
-TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc")
- .set_body_typed([](VarTable var_table, const ObjectRef& obj, runtime::PackedFunc factory,
- Frame frame) {
- var_table->DefineByDoc(
- obj, [f = std::move(factory)]() { return f(); }, frame);
- });
-TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc")
- .set_body_method<VarTable, VarTableNode, Optional<ExprDoc>, const ObjectRef&,
- const ObjectPath&>(&VarTableNode::GetVarDoc);
-TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined")
- .set_body_method<VarTable>(&VarTableNode::IsVarDefined);
-
-} // namespace printer
-} // namespace script
-} // namespace tvm
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index a8d8936c90..af6997a72a 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -1088,7 +1088,7 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) {
return tir::Call(dtype, op, {}, span);
}
-TVM_REGISTER_OP("tir.type_annotation")
+// NOTE(review): the "tir." prefix is presumably supplied by TVM_TIR_REGISTER_OP (it replaces
+// TVM_REGISTER_OP("tir.type_annotation")) -- confirm against the macro definition.
+TVM_TIR_REGISTER_OP("type_annotation")
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
} // namespace tir
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 56ecba9e9e..dc3208f484 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -36,7 +36,7 @@ namespace builtin {
static const Op& op = Op::Get("tir." #OpName); \
return op; \
} \
- TVM_REGISTER_OP("tir." #OpName)
+ TVM_TIR_REGISTER_OP(#OpName)
TIR_DEFINE_BUILTIN_FUNC(reinterpret)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
@@ -181,10 +181,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_stack_make_array)
// When num_inputs are not set, the function is assumed to be variable length.
TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed)
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+ // NOTE(review): TScriptPrinterName presumably overrides the op name printed by TVMScript
+ // (here `call_packed`); plevel=20 is presumably the attribute priority level -- confirm.
+ .set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_packed"), /*plevel=*/20);
TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked)
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+ .set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_cpacked"), /*plevel=*/20);
TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
@@ -198,10 +200,14 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_context)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_BUILTIN_FUNC(tvm_call_packed_lowered)
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+ .set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_packed_lowered"),
+ /*plevel=*/20);
TIR_DEFINE_BUILTIN_FUNC(tvm_call_cpacked_lowered)
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
+ .set_attr<TScriptPrinterName>("TScriptPrinterName", String("call_cpacked_lowered"),
+ /*plevel=*/20);
TIR_DEFINE_BUILTIN_FUNC(tvm_call_trace_packed_lowered)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 044d8fd08d..078e32ca57 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -39,13 +39,13 @@ namespace tvm {
using namespace tir;
// macro to register an unary op
-#define TIR_REGISTER_PURE_UNARY_OP(OpName) \
- TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
+#define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) \
+ TVM_TIR_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
"TCallEffectKind", Integer(CallEffectKind::kPure))
// macro to register an binary op
-#define TIR_REGISTER_PURE_BINARY_OP(OpName) \
- TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr<TCallEffectKind>( \
+#define TVM_TIR_REGISTER_PURE_BINARY_OP(OpName) \
+ TVM_TIR_REGISTER_OP(OpName).set_num_inputs(2).set_attr<TCallEffectKind>( \
"TCallEffectKind", Integer(CallEffectKind::kPure))
runtime::DataType GetRuntimeDataType(const Type& type) {
@@ -657,7 +657,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
return tir::Call(x.dtype(), op, {x, y}, span);
}
-TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr<TVectorizable>("TVectorizable", true);
// abs
PrimExpr abs(PrimExpr x, Span span) {
@@ -685,7 +685,7 @@ PrimExpr abs(PrimExpr x, Span span) {
}
}
-TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("fabs").set_attr<TVectorizable>("TVectorizable", true);
// isnan
PrimExpr isnan(PrimExpr x, Span span) {
@@ -783,7 +783,7 @@ PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) {
return tir::Call(x.dtype(), op, {x, y}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.fmod");
+TVM_TIR_REGISTER_PURE_BINARY_OP("fmod");
// floor
PrimExpr floor(PrimExpr x, Span span) {
@@ -797,7 +797,7 @@ PrimExpr floor(PrimExpr x, Span span) {
return tir::Call(x.dtype(), op, {x}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable", true);
// ceil
PrimExpr ceil(PrimExpr x, Span span) {
@@ -811,7 +811,7 @@ PrimExpr ceil(PrimExpr x, Span span) {
return tir::Call(x.dtype(), op, {x}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable", true);
// round
PrimExpr round(PrimExpr x, Span span) {
@@ -825,7 +825,7 @@ PrimExpr round(PrimExpr x, Span span) {
return tir::Call(x.dtype(), op, {x}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable", true);
// nearbyint
PrimExpr nearbyint(PrimExpr x, Span span) {
@@ -839,7 +839,7 @@ PrimExpr nearbyint(PrimExpr x, Span span) {
return tir::Call(x.dtype(), op, {x}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
+TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
// trunc
PrimExpr trunc(PrimExpr x, Span span) {
@@ -856,67 +856,77 @@ PrimExpr trunc(PrimExpr x, Span span) {
return tir::Call(x.dtype(), op, {x}, span);
}
-TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr<TVectorizable>("TVectorizable", true);
// unary op registration.
-TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("exp").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("exp2").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("exp10").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.erf");
+TVM_TIR_REGISTER_PURE_UNARY_OP("erf");
-TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("tanh").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("sigmoid").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("sqrt").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt");
+TVM_TIR_REGISTER_PURE_UNARY_OP("rsqrt");
-TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("log").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("log2").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.log1p");
+TVM_TIR_REGISTER_PURE_UNARY_OP("log1p");
-TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("log10").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("tan").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("cos").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("cosh").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("sin").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr<TVectorizable>("TVectorizable", true);
+TVM_TIR_REGISTER_PURE_UNARY_OP("sinh").set_attr<TVectorizable>("TVectorizable", true);
-TIR_REGISTER_PURE_UNARY_OP("tir.asin");
+TVM_TIR_REGISTER_PURE_UNARY_OP("asin");
-TIR_REGISTER_PURE_UNARY_OP("tir.acos");
+TVM_TIR_REGISTER_PURE_UNARY_OP("acos");
-TIR_REGISTER_PURE_UNARY_OP("tir.atan");
+TVM_TIR_REGISTER_PURE_UNARY_OP("atan");
-TIR_REGISTER_PURE_UNARY_OP("tir.acosh");
+TVM_TIR_REGISTER_PURE_UNARY_OP("acosh");
-TIR_REGISTER_PURE_UNARY_OP("tir.asinh");
+TVM_TIR_REGISTER_PURE_UNARY_OP("asinh");
-TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
+TVM_TIR_REGISTER_PURE_UNARY_OP("atanh");
-TIR_REGISTER_PURE_UNARY_OP("tir.clz");
+TVM_TIR_REGISTER_PURE_UNARY_OP("clz");
// binary intrinsics
-TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
+TVM_TIR_REGISTER_PURE_BINARY_OP("atan2");
-TIR_REGISTER_PURE_BINARY_OP("tir.nextafter");
+TVM_TIR_REGISTER_PURE_BINARY_OP("nextafter");
-TIR_REGISTER_PURE_BINARY_OP("tir.hypot");
+TVM_TIR_REGISTER_PURE_BINARY_OP("hypot");
-TIR_REGISTER_PURE_BINARY_OP("tir.copysign");
+TVM_TIR_REGISTER_PURE_BINARY_OP("copysign");
-TIR_REGISTER_PURE_BINARY_OP("tir.ldexp");
+TVM_TIR_REGISTER_PURE_BINARY_OP("ldexp");
+
+TVM_TIR_REGISTER_OP("TVMBackendAllocWorkspace")
+ .set_num_inputs(5)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace")
+ .set_num_inputs(3)
+ .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace")
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
// expose basic functions to node namespace
TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
diff --git a/src/tir/op/runtime.cc b/src/tir/op/runtime.cc
deleted file mode 100644
index adabae9e75..0000000000
--- a/src/tir/op/runtime.cc
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tir/op/runtime.cc
- * \brief TIR ops for runtime functions.
- */
-#include <tvm/ir/op.h>
-#include <tvm/tir/op_attr_types.h>
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_OP("tir.TVMBackendAllocWorkspace")
- .set_num_inputs(5)
- .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendAllocWorkspace")
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-
-TVM_REGISTER_OP("tir.TVMBackendFreeWorkspace")
- .set_num_inputs(3)
- .set_attr<TGlobalSymbol>("TGlobalSymbol", "TVMBackendFreeWorkspace")
- .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
-
-} // namespace tir
-} // namespace tvm
diff --git a/tests/cpp/traced_object_test.cc b/tests/cpp/traced_object_test.cc
deleted file mode 100644
index 7890a67eef..0000000000
--- a/tests/cpp/traced_object_test.cc
+++ /dev/null
@@ -1,268 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <dmlc/logging.h>
-#include <gtest/gtest.h>
-#include <tvm/node/repr_printer.h>
-#include <tvm/runtime/container/map.h>
-#include <tvm/script/printer/traced_object.h>
-
-using namespace tvm;
-
-namespace {
-
-class DummyObjectNode : public Object {
- public:
- void VisitAttrs(AttrVisitor* v) {}
-
- static constexpr const char* _type_key = "TracedObjectTestDummyObject";
- TVM_DECLARE_FINAL_OBJECT_INFO(DummyObjectNode, Object);
-};
-
-class DummyObject : public ObjectRef {
- public:
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(DummyObject, ObjectRef, DummyObjectNode);
-};
-
-TVM_REGISTER_NODE_TYPE(DummyObjectNode);
-
-class ObjectWithAttrsNode : public Object {
- public:
- int64_t int64_attr = 5;
- Map<String, String> map_attr;
- Array<String> array_attr;
- DummyObject obj_attr;
-
- ObjectWithAttrsNode() : obj_attr(make_object<DummyObjectNode>()) {}
-
- void VisitAttrs(AttrVisitor* v) {
- v->Visit("int64_attr", &int64_attr);
- v->Visit("map_attr", &map_attr);
- v->Visit("array_attr", &array_attr);
- v->Visit("obj_attr", &obj_attr);
- }
-
- static constexpr const char* _type_key = "TracedObjectTestObjectWithAttrs";
- TVM_DECLARE_FINAL_OBJECT_INFO(ObjectWithAttrsNode, Object);
-};
-
-class ObjectWithAttrs : public ObjectRef {
- public:
- TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ObjectWithAttrs, ObjectRef, ObjectWithAttrsNode);
-};
-
-TVM_REGISTER_NODE_TYPE(ObjectWithAttrsNode);
-
-} // anonymous namespace
-
-TEST(TracedObjectTest, MakeTraced_RootObject) {
- ObjectWithAttrs root(make_object<ObjectWithAttrsNode>());
- auto root_traced = MakeTraced(root);
-
- static_assert(std::is_same<decltype(root_traced), TracedObject<ObjectWithAttrs>>::value);
- ICHECK(root_traced.GetPath()->PathsEqual(ObjectPath::Root()));
- ICHECK_EQ(root_traced.Get().get(), root.get());
-}
-
-TEST(TracedObjectTest, MakeTraced_WithPath) {
- ObjectWithAttrs obj(make_object<ObjectWithAttrsNode>());
- auto traced = MakeTraced(obj, ObjectPath::Root()->Attr("foo"));
-
- static_assert(std::is_same<decltype(traced), TracedObject<ObjectWithAttrs>>::value);
- ICHECK(traced.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo")));
- ICHECK_EQ(traced.Get().get(), obj.get());
-}
-
-TEST(TracedObjectTest, TracedObject_ImplicitConversionFromDerived) {
- DummyObject obj(make_object<DummyObjectNode>());
- auto traced = MakeTraced(obj);
- static_assert(std::is_same<decltype(traced), TracedObject<DummyObject>>::value);
-
- // Check that TracedObject<DummyObject> is implicitly converted to TracedObject<ObjectRef>
- auto base_traced = [](const TracedObject<ObjectRef>& base) { return base; }(traced);
-
- static_assert(std::is_same<decltype(base_traced), TracedObject<ObjectRef>>::value);
-}
-
-TEST(TracedObjectTest, TracedObject_GetAttr_ObjectRef) {
- ObjectWithAttrs root(make_object<ObjectWithAttrsNode>());
- auto root_traced = MakeTraced(root);
- auto obj_attr = root_traced.GetAttr(&ObjectWithAttrsNode::obj_attr);
- static_assert(std::is_same<decltype(obj_attr), TracedObject<DummyObject>>::value);
- ICHECK(obj_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("obj_attr")));
- ICHECK_EQ(obj_attr.Get().get(), root->obj_attr.get());
-}
-
-TEST(TracedObjectTest, TracedObject_GetAttr_Map) {
- ObjectWithAttrs root(make_object<ObjectWithAttrsNode>());
- root->map_attr.Set("foo", "bar");
-
- auto root_traced = MakeTraced(root);
- auto map_attr = root_traced.GetAttr(&ObjectWithAttrsNode::map_attr);
- static_assert(std::is_same<decltype(map_attr), TracedMap<String, String>>::value);
- ICHECK(map_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr")));
- ICHECK_EQ(map_attr.Get().get(), root->map_attr.get());
-
- auto map_val = map_attr.at("foo");
- ICHECK_EQ(map_val.Get(), "bar");
- ICHECK(
- map_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("map_attr")->MapValue(String("foo"))));
-}
-
-TEST(TracedObjectTest, TracedObject_GetAttr_Array) {
- ObjectWithAttrs root(make_object<ObjectWithAttrsNode>());
- root->array_attr.push_back("foo");
- root->array_attr.push_back("bar");
-
- auto root_traced = MakeTraced(root);
- auto array_attr = root_traced.GetAttr(&ObjectWithAttrsNode::array_attr);
- static_assert(std::is_same<decltype(array_attr), TracedArray<String>>::value);
- ICHECK(array_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr")));
- ICHECK_EQ(array_attr.Get().get(), root->array_attr.get());
-
- auto array_val = array_attr[1];
- ICHECK_EQ(array_val.Get(), "bar");
- ICHECK(array_val.GetPath()->PathsEqual(ObjectPath::Root()->Attr("array_attr")->ArrayIndex(1)));
-}
-
-TEST(TracedObjectTest, TracedObject_GetAttr_Int64) {
- ObjectWithAttrs root(make_object<ObjectWithAttrsNode>());
- auto root_traced = MakeTraced(root);
-
- auto int64_attr = root_traced.GetAttr(&ObjectWithAttrsNode::int64_attr);
- static_assert(std::is_same<decltype(int64_attr), TracedBasicValue<int64_t>>::value);
- ICHECK_EQ(int64_attr.Get(), 5);
- ICHECK(int64_attr.GetPath()->PathsEqual(ObjectPath::Root()->Attr("int64_attr")));
-}
-
-TEST(TracedObjectTest, TracedObject_IsInstance) {
- ObjectRef dummy(make_object<DummyObjectNode>());
- auto traced = MakeTraced(dummy);
- ICHECK(traced.IsInstance<DummyObject>());
- ICHECK(!traced.IsInstance<ObjectWithAttrs>());
-}
-
-TEST(TracedObjectTest, TracedObject_Downcast) {
- ObjectRef root(make_object<DummyObjectNode>());
- auto traced = MakeTraced(root);
-
- auto as_dummy = traced.Downcast<DummyObject>();
- static_assert(std::is_same<decltype(as_dummy), TracedObject<DummyObject>>::value);
- ICHECK_EQ(as_dummy.Get(), root);
-
- // Try downcasting to a wrong type
- bool caught = false;
- try {
- traced.Downcast<ObjectWithAttrs>();
- } catch (std::exception& e) {
- caught = strstr(e.what(),
- "Downcast from TracedObjectTestDummyObject to TracedObjectTestObjectWithAttrs "
- "failed") != nullptr;
- }
- ICHECK(caught);
-}
-
-TEST(TracedObjectTest, TracedObject_TryDowncast) {
- ObjectRef root(make_object<DummyObjectNode>());
- auto traced = MakeTraced(root);
-
- auto as_dummy = traced.TryDowncast<DummyObject>();
- static_assert(std::is_same<decltype(as_dummy), TracedOptional<DummyObject>>::value);
- ICHECK(as_dummy.defined());
- ICHECK_EQ(as_dummy.value().Get(), root);
-
- // Try downcasting to a wrong type
- ICHECK(!traced.TryDowncast<ObjectWithAttrs>().defined());
-}
-
-TEST(TracedObjectTest, TracedMap_At) {
- Map<String, String> m({{"k1", "foo"}, {"k2", "bar"}});
- auto traced = MakeTraced(m);
-
- auto traced_foo = traced.at("k1");
- static_assert(std::is_same<decltype(traced_foo), TracedObject<String>>::value);
- ICHECK_EQ(traced_foo.Get(), "foo");
- ICHECK(traced_foo.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1"))));
-}
-
-TEST(TracedObjectTest, TracedMap_Iterator) {
- Map<String, String> m({{"k1", "foo"}, {"k2", "bar"}});
- auto traced = MakeTraced(m);
-
- size_t k1_count = 0;
- size_t k2_count = 0;
-
- for (const auto& kv : traced) {
- if (kv.first == "k1") {
- ++k1_count;
- ICHECK_EQ(kv.second.Get(), "foo");
- ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k1"))));
- } else if (kv.first == "k2") {
- ++k2_count;
- ICHECK_EQ(kv.second.Get(), "bar");
- ICHECK(kv.second.GetPath()->PathsEqual(ObjectPath::Root()->MapValue(String("k2"))));
- } else {
- ICHECK(false);
- }
- }
-
- ICHECK_EQ(k1_count, 1);
- ICHECK_EQ(k2_count, 1);
-}
-
-TEST(TracedObjectTest, TracedArray_Index) {
- Array<String> a = {"foo", "bar"};
- auto traced = MakeTraced(a);
-
- auto traced_bar = traced[1];
- static_assert(std::is_same<decltype(traced_bar), TracedObject<String>>::value);
- ICHECK_EQ(traced_bar.Get(), "bar");
- ICHECK(traced_bar.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1)));
-}
-
-TEST(TracedObjectTest, TracedArray_Iterator) {
- Array<String> a = {"foo", "bar"};
- auto traced = MakeTraced(a);
-
- size_t index = 0;
- for (const auto& x : traced) {
- if (index == 0) {
- ICHECK_EQ(x.Get(), "foo");
- ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(0)));
- } else if (index == 1) {
- ICHECK_EQ(x.Get(), "bar");
- ICHECK(x.GetPath()->PathsEqual(ObjectPath::Root()->ArrayIndex(1)));
- } else {
- ICHECK(false);
- }
- ++index;
- }
-
- ICHECK_EQ(index, 2);
-}
-
-TEST(TracedObjectTest, TracedBasicValue_ApplyFunc) {
- auto traced = MakeTraced(123, ObjectPath::Root()->Attr("foo"));
- static_assert(std::is_same<decltype(traced), TracedBasicValue<int>>::value);
-
- auto transformed = traced.ApplyFunc([](int x) { return x + 4.0; });
- static_assert(std::is_same<decltype(transformed), TracedBasicValue<double>>::value);
-
- ICHECK(transformed.GetPath()->PathsEqual(ObjectPath::Root()->Attr("foo")));
-}
diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc
deleted file mode 100644
index 8c68399df2..0000000000
--- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <dmlc/logging.h>
-#include <gtest/gtest.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/runtime/memory.h>
-#include <tvm/script/printer/doc.h>
-#include <tvm/script/printer/ir_docsifier.h>
-#include <tvm/script/printer/traced_object.h>
-
-using namespace tvm;
-using namespace tvm::script::printer;
-
-class TestObjectNode : public Object {
- public:
- void VisitAttrs(AttrVisitor* v) {}
-
- static constexpr const char* _type_key = "test.script.printer.irdocsifier.TestObject";
- TVM_DECLARE_FINAL_OBJECT_INFO(TestObjectNode, Object);
-};
-
-class TestObject : public ObjectRef {
- public:
- TestObject() : ObjectRef(runtime::make_object<TestObjectNode>()) {}
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TestObject, ObjectRef, TestObjectNode);
-};
-
-TVM_REGISTER_NODE_TYPE(TestObjectNode);
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<TestObject>([](TracedObject<TestObject> obj, IRDocsifier p) {
- return IdDoc("x");
- });
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<TestObject>("tir", [](TracedObject<TestObject> obj, IRDocsifier p) {
- return IdDoc("tir");
- });
-
-TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
- .set_dispatch<TestObject>("relax", [](TracedObject<TestObject> obj, IRDocsifier p) {
- return IdDoc("relax");
- });
-
-TEST(PrinterIRDocsifierTest, AsDoc) {
- IRDocsifier p(Map<String, String>{});
- ObjectPath path = ObjectPath::Root();
- TestObject obj;
-
- IdDoc doc = p->AsDoc<IdDoc>(MakeTraced(obj, path));
-
- ICHECK_EQ(doc->name, "x");
-}
-
-TEST(PrinterIRDocsifierTest, AsExprDoc) {
- IRDocsifier p(Map<String, String>{});
- ObjectPath path = ObjectPath::Root();
- TestObject obj;
-
- ExprDoc doc = p->AsExprDoc(MakeTraced(obj, path));
-
- ICHECK_EQ(Downcast<IdDoc>(doc)->name, "x");
-}
-
-TEST(PrinterIRDocsifierTest, WithDispatchToken) {
- IRDocsifier p(Map<String, String>{});
- TracedObject<TestObject> obj = MakeTraced(TestObject(), ObjectPath::Root());
-
- ICHECK_EQ(p->AsDoc<IdDoc>(obj)->name, "x");
-
- {
- auto ctx = p->WithDispatchToken("tir");
- ICHECK_EQ(p->AsDoc<IdDoc>(obj)->name, "tir");
-
- {
- auto ctx = p->WithDispatchToken("relax");
- ICHECK_EQ(p->AsDoc<IdDoc>(obj)->name, "relax");
- }
-
- ICHECK_EQ(p->AsDoc<IdDoc>(obj)->name, "tir");
- }
-
- ICHECK_EQ(p->AsDoc<IdDoc>(obj)->name, "x");
-}
-
-TEST(PrinterIRDocsifierTest, WithFrame) {
- IRDocsifier p(Map<String, String>{});
- TestObject obj;
-
- {
- VarDefFrame frame;
- auto ctx = p->WithFrame(frame);
- ICHECK_EQ(p->frames.size(), 1);
-
- p->vars->Define(obj, "x", ObjectPath::Root(), frame);
- ICHECK(p->vars->IsVarDefined(obj));
- }
- ICHECK_EQ(p->frames.size(), 0);
- ICHECK(!p->vars->IsVarDefined(obj));
-}
diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
deleted file mode 100644
index d662ce1324..0000000000
--- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <dmlc/logging.h>
-#include <gtest/gtest.h>
-#include <tvm/node/object_path.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/script/printer/traced_object.h>
-#include <tvm/script/printer/traced_object_functor.h>
-
-using namespace tvm;
-using namespace tvm::script::printer;
-
-namespace {
-
-class FooObjectNode : public Object {
- public:
- void VisitAttrs(AttrVisitor* v) {}
-
- static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject";
- TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object);
-};
-
-class FooObject : public ObjectRef {
- public:
- FooObject() { this->data_ = make_object<FooObjectNode>(); }
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode);
-};
-
-TVM_REGISTER_NODE_TYPE(FooObjectNode);
-
-class BarObjectNode : public Object {
- public:
- void VisitAttrs(AttrVisitor* v) {}
-
- static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject";
- TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object);
-};
-
-class BarObject : public ObjectRef {
- public:
- BarObject() { this->data_ = make_object<BarObjectNode>(); }
- TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode);
-};
-
-TVM_REGISTER_NODE_TYPE(BarObjectNode);
-
-String ComputeFoo(TracedObject<FooObject> foo) { return "Foo"; }
-
-} // anonymous namespace
-
-TEST(TracedObjectFunctorTest, NormalRegistration) {
- TracedObjectFunctor<String> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch<BarObject>([](TracedObject<BarObject> o) -> String { return "Bar"; });
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
- ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar");
-}
-
-TEST(TracedObjectFunctorTest, RegistrationWithFunction) {
- TracedObjectFunctor<String> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "FooLambda"; });
- functor.set_dispatch<FooObject>("tir", ComputeFoo);
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda");
- ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
-}
-
-TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) {
- TracedObjectFunctor<String> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch<FooObject>("tir",
- [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
- functor.set_dispatch<FooObject>("relax",
- [](TracedObject<FooObject> o) -> String { return "Foo relax"; });
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
- ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
- ICHECK_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax");
- ICHECK_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo");
-}
-
-TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) {
- TracedObjectFunctor<String> functor;
- ObjectPath path = ObjectPath::Root();
-
- auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); };
- auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); };
-
- functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default));
- functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir));
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "default");
- ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir");
-}
-
-TEST(TracedObjectFunctorTest, ExtraArg) {
- TracedObjectFunctor<int, int> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
- functor.set_dispatch<BarObject>([](TracedObject<BarObject> o, int x) { return x + 1; });
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
- ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3);
- ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3);
-}
-
-TEST(TracedObjectFunctorTest, RemoveDispatchFunction) {
- TracedObjectFunctor<String> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o) -> String { return "Foo"; });
- functor.set_dispatch<FooObject>("tir",
- [](TracedObject<FooObject> o) -> String { return "Foo tir"; });
-
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo");
- ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir");
-
- functor.remove_dispatch("tir", FooObjectNode::RuntimeTypeIndex());
- ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo");
-}
-
-TEST(TracedObjectFunctorTest, CallWithUnregisteredType) {
- TracedObjectFunctor<int, int> functor;
- ObjectPath path = ObjectPath::Root();
-
- bool failed = false;
- try {
- ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2);
- } catch (...) {
- failed = true;
- }
- ASSERT_EQ(failed, true);
-}
-
-TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) {
- TracedObjectFunctor<int, int> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
-
- bool failed = false;
- try {
- functor.set_dispatch<FooObject>([](TracedObject<FooObject> o, int x) { return x; });
- } catch (...) {
- failed = true;
- }
- ASSERT_EQ(failed, true);
-}
-
-TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) {
- TracedObjectFunctor<int, int> functor;
- ObjectPath path = ObjectPath::Root();
-
- functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
-
- bool failed = false;
- try {
- functor.set_dispatch<FooObject>("tir", [](TracedObject<FooObject> o, int x) { return x; });
- } catch (...) {
- failed = true;
- }
- ASSERT_EQ(failed, true);
-}
diff --git a/tests/cpp/tvmscript_printer_var_table_test.cc b/tests/cpp/tvmscript_printer_var_table_test.cc
deleted file mode 100644
index b447c81ac0..0000000000
--- a/tests/cpp/tvmscript_printer_var_table_test.cc
+++ /dev/null
@@ -1,158 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <dmlc/logging.h>
-#include <gtest/gtest.h>
-#include <tvm/node/object_path.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/script/printer/frame.h>
-#include <tvm/script/printer/var_table.h>
-#include <tvm/tir/var.h>
-
-using namespace tvm;
-using namespace tvm::script::printer;
-
-TEST(PrinterVarTableTest, Define) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
-
- IdDoc doc = vars->Define(x, "x", object_path, frame);
-
- ICHECK_EQ(doc->name, "x");
-
- IdDoc second_doc = Downcast<IdDoc>(vars->GetVarDoc(x, object_path).value());
-
- ICHECK_EQ(second_doc->name, "x");
-}
-
-TEST(PrinterVarTableTest, DefineByDoc) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
-
- auto doc_factory = []() { return LiteralDoc::Str("x"); };
-
- vars->DefineByDoc(x, doc_factory, frame);
-
- ExprDoc doc = vars->GetVarDoc(x, object_path).value();
-
- ICHECK_EQ(Downcast<String>(Downcast<LiteralDoc>(doc)->value), "x");
-}
-
-TEST(PrinterVarTableTest, GetVarDocWithUnknownVariable) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- tir::Var y("y");
- ObjectPath object_path = ObjectPath::Root();
-
- Doc doc = vars->Define(x, "x", object_path, frame);
- ICHECK(!vars->GetVarDoc(y, object_path).defined());
-}
-
-TEST(PrinterVarTableTest, GetVarDocWithObjectPath) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
- ObjectPath second_object_path = ObjectPath::Root()->Attr("x");
-
- IdDoc doc = vars->Define(x, "x", object_path, frame);
- ICHECK_EQ(doc->source_paths[0], object_path);
- ICHECK_EQ(doc->source_paths.size(), 1);
-
- Doc second_doc = vars->GetVarDoc(x, second_object_path).value();
- ICHECK_EQ(second_doc->source_paths[0], second_object_path);
- ICHECK_EQ(second_doc->source_paths.size(), 1);
-}
-
-TEST(PrinterVarTableTest, IsVarDefined) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- tir::Var y("y");
- ObjectPath object_path = ObjectPath::Root();
-
- vars->Define(x, "x", object_path, frame);
- ICHECK(vars->IsVarDefined(x));
- ICHECK(!vars->IsVarDefined(y));
-}
-
-TEST(PrinterVarTableTest, VarRemovedAfterFrameOutOfScope) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
-
- vars->Define(x, "x", object_path, frame);
- ICHECK(vars->IsVarDefined(x));
-
- frame->ExitWithScope();
- ICHECK(!vars->IsVarDefined(x));
-}
-
-TEST(PrinterVarTableTest, DefineDuplicateName) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- tir::Var y("y");
- ObjectPath object_path = ObjectPath::Root();
-
- IdDoc x_doc = vars->Define(x, "x", object_path, frame);
- IdDoc y_doc = vars->Define(y, "x", object_path, frame);
-
- ICHECK_NE(x_doc->name, y_doc->name);
-}
-
-TEST(PrinterVarTableTest, DefineDuplicateVariable) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
-
- vars->Define(x, "x", object_path, frame);
-
- bool failed = false;
- try {
- vars->Define(x, "x", object_path, frame);
- } catch (...) {
- failed = true;
- }
- ASSERT_EQ(failed, true);
-}
-
-TEST(PrinterVarTableTest, DefineByDocWithIdDoc) {
- VarTable vars;
- MetadataFrame frame;
- tir::Var x("x");
- ObjectPath object_path = ObjectPath::Root();
-
- bool failed = false;
- try {
- // User has to use `Define` if variable needs to be mapped to IdDoc
- vars->DefineByDoc(
- x, []() { return IdDoc("x"); }, frame);
- } catch (...) {
- failed = true;
- }
- ASSERT_EQ(failed, true);
-}
diff --git a/tests/python/unittest/test_tvmscript_printer_entry_point.py b/tests/python/unittest/test_tvmscript_printer_entry_point.py
deleted file mode 100644
index 208386dbdd..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_entry_point.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-
-from tvm.error import TVMError
-from tvm.script.printer import script
-from tvm.tir import FloatImm
-
-
-def test_as_script_unknown_ir():
- ir_node = FloatImm("float32", 1.0)
-
- with pytest.raises(TVMError) as e:
- script(ir_node, "test_xyz", {})
-
- assert "test_xyz" in str(e.value)
diff --git a/tests/python/unittest/test_tvmscript_printer_frame.py b/tests/python/unittest/test_tvmscript_printer_frame.py
deleted file mode 100644
index bd98d64456..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_frame.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from tvm.script.printer.frame import MetadataFrame
-
-
-def test_frame_add_callback():
- frame = MetadataFrame()
-
- flag = 0
-
- def callback1():
- nonlocal flag
- flag += 1
-
- def callback2():
- nonlocal flag
- flag += 5
-
- frame.add_exit_callback(callback1)
- with frame:
- frame.add_exit_callback(callback2)
- assert flag == 0
-
- assert flag == 6
-
-
-def test_frame_clear_callbacks_after_exit():
- frame = MetadataFrame()
-
- flag = 0
-
- def callback():
- nonlocal flag
- flag += 1
-
- frame.add_exit_callback(callback)
-
- with frame:
- pass
-
- assert flag == 1
-
- with frame:
- pass
-
- assert flag == 1
diff --git a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py b/tests/python/unittest/test_tvmscript_printer_irdocsifier.py
deleted file mode 100644
index d9d552ce4b..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_irdocsifier.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-
-from tvm.runtime import ObjectPath
-from tvm.script.printer.doc import IdDoc
-from tvm.script.printer.frame import MetadataFrame, VarDefFrame
-from tvm.script.printer.ir_docsifier import IRDocsifier, RootNodeContainer
-from tvm.tir import Var
-
-
-@pytest.fixture
-def ir_docsifier():
- """
- Creates an IRDocsifier instance with a special dispatch token.
- """
- _ir_docsifier = IRDocsifier({})
- with _ir_docsifier.dispatch_token(f"{__file__}"):
- yield _ir_docsifier
-
-
-def _get_id_doc_printer(id_name):
- def printer(obj, object_path, ir_docsifier): # pylint: disable=unused-argument
- return IdDoc(id_name)
-
- return printer
-
-
-def _root_dispatch_function(obj, ir_docsifier):
- doc = ir_docsifier.as_doc(obj, ObjectPath.root())
- doc.source_paths = [ObjectPath.root().attr("irdocsifier_test")]
- return doc
-
-
-# Because the dispatch table is global, tests should only set dispatch function under
-# unique dispatch token.
-IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x"), f"{__file__}")
-IRDocsifier.set_root_dispatch(f"{__file__}", _root_dispatch_function)
-
-
-def test_set_dispatch(ir_docsifier):
- IRDocsifier.set_dispatch(Var, _get_id_doc_printer("x2"), f"{__file__}-2")
- with ir_docsifier.dispatch_token(f"{__file__}-2"):
- doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root())
- assert doc.name == "x2"
-
- doc = ir_docsifier.as_doc(Var("x", dtype="int8"), ObjectPath.root())
- assert doc.name == "x"
-
-
-def test_set_root_dispatch(ir_docsifier):
- doc = ir_docsifier.as_doc(RootNodeContainer(Var("x", dtype="int8")), ObjectPath.root())
- assert ObjectPath.root().attr("irdocsifier_test") in doc.source_paths
-
-
-def test_as_doc(ir_docsifier):
- object_path = ObjectPath.root()
- doc = ir_docsifier.as_doc(Var("x", "int8"), ObjectPath.root())
- assert doc.name == "x"
- assert list(doc.source_paths) == [object_path]
-
-
-def test_with_dispatch_token(ir_docsifier):
- initial_token_count = len(ir_docsifier.dispatch_tokens)
-
- with ir_docsifier.dispatch_token("tir"):
- assert len(ir_docsifier.dispatch_tokens) == initial_token_count + 1
-
- assert len(ir_docsifier.dispatch_tokens) == initial_token_count
-
-
-def test_with_frame(ir_docsifier):
- initial_frame_count = len(ir_docsifier.frames)
-
- frame = VarDefFrame()
- is_callback_called = False
-
- def callback():
- nonlocal is_callback_called
- is_callback_called = True
-
- frame.add_exit_callback(callback)
-
- with ir_docsifier.frame(frame):
- assert len(ir_docsifier.frames) == initial_frame_count + 1
- assert not is_callback_called
-
- assert len(ir_docsifier.frames) == initial_frame_count
- assert is_callback_called
-
-
-def test_get_frame(ir_docsifier):
- with ir_docsifier.frame(VarDefFrame()) as frame_a:
- assert ir_docsifier.get_frame(MetadataFrame) is None
- assert ir_docsifier.get_frame(VarDefFrame) == frame_a
-
- with ir_docsifier.frame(VarDefFrame()) as frame_b:
- assert ir_docsifier.get_frame(MetadataFrame) is None
- assert ir_docsifier.get_frame(VarDefFrame) == frame_b
-
- with ir_docsifier.frame(MetadataFrame()) as frame_c:
- assert ir_docsifier.get_frame(MetadataFrame) == frame_c
- assert ir_docsifier.get_frame(VarDefFrame) == frame_b
-
- assert ir_docsifier.get_frame(MetadataFrame) is None
- assert ir_docsifier.get_frame(VarDefFrame) == frame_b
-
- assert ir_docsifier.get_frame(MetadataFrame) is None
- assert ir_docsifier.get_frame(VarDefFrame) == frame_a
diff --git a/tests/python/unittest/test_tvmscript_printer_var_table.py b/tests/python/unittest/test_tvmscript_printer_var_table.py
deleted file mode 100644
index eab63a08dd..0000000000
--- a/tests/python/unittest/test_tvmscript_printer_var_table.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This file tests the FFI binding of script.printer.VarTable.
-These only make sure parameter can be passed to the C++ functions
-correctly. The test for the functionality of VarTable is in C++.
-"""
-
-from tvm.runtime import ObjectPath
-from tvm.script.printer.doc import LiteralDoc
-from tvm.script.printer.frame import VarDefFrame
-from tvm.script.printer.var_table import VarTable
-from tvm.tir import Var
-
-
-def test_define():
- var_table = VarTable()
- var_name = "a"
- var_obj = Var(var_name, dtype="int32")
- object_path = ObjectPath.root().attr("a")
- frame = VarDefFrame()
-
- id_doc = var_table.define(var_obj, var_name, object_path, frame)
-
- assert id_doc.name == "a"
- assert list(id_doc.source_paths) == [object_path]
-
- id_doc = var_table.get_var_doc(var_obj, object_path)
-
- assert id_doc.name == "a"
- assert list(id_doc.source_paths) == [object_path]
-
-
-def test_define_by_doc():
- var_table = VarTable()
- var_name = "a"
- var_obj = Var(var_name, dtype="int32")
- object_path = ObjectPath.root().attr("a")
- frame = VarDefFrame()
-
- var_table.define_by_doc(var_obj, lambda: LiteralDoc(var_name), frame)
-
- var_doc = var_table.get_var_doc(var_obj, object_path)
-
- assert isinstance(var_doc, LiteralDoc)
- assert var_doc.value == var_name
- assert list(var_doc.source_paths) == [object_path]
-
-
-def test_is_var_defined():
- var_table = VarTable()
- a = Var("a", dtype="int32")
- object_path = ObjectPath.root().attr("a")
- frame = VarDefFrame()
-
- var_table.define(a, "a", object_path, frame)
-
- assert var_table.is_var_defined(a)
- assert a in var_table
-
-
-def test_var_out_of_scope():
- var_table = VarTable()
- var_name = "a"
- var_obj = Var(var_name, dtype="int32")
- object_path = ObjectPath.root().attr("a")
- frame = VarDefFrame()
-
- var_table.define(var_obj, var_name, object_path, frame)
-
- with frame:
- assert var_obj in var_table
-
- assert var_obj not in var_table
- assert var_table.get_var_doc(var_obj, object_path) is None