You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/27 16:27:53 UTC
[tvm] branch main updated: [TVMScript] ExprDoc (#12048)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 584b0f31d8 [TVMScript] ExprDoc (#12048)
584b0f31d8 is described below
commit 584b0f31d8994390c81f96a783621982153adf8b
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Wed Jul 27 12:27:48 2022 -0400
[TVMScript] ExprDoc (#12048)
This PR adds:
- All ExprDoc subclasses
- Their Python bindings
- Support of ExprDoc in PythonDocPrinter
- Unit tests for ExprDoc in PythonDocPrinter
Tracking issue: https://github.com/apache/tvm/issues/11912
---
include/tvm/script/printer/doc.h | 482 +++++++++++++++++++++
python/tvm/script/printer/doc.py | 248 ++++++++++-
src/script/printer/base_doc_printer.cc | 20 +
src/script/printer/base_doc_printer.h | 50 +++
src/script/printer/doc.cc | 147 +++++++
src/script/printer/python_doc_printer.cc | 203 +++++++++
.../python/unittest/test_tvmscript_printer_doc.py | 217 +++++++++-
.../test_tvmscript_printer_python_doc_printer.py | 380 +++++++++++++++-
8 files changed, 1741 insertions(+), 6 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index 67c27bd45a..f3f980e53f 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -63,6 +63,8 @@ class Doc : public ObjectRef {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode);
};
+// Forward declaration: ExprDocNode (below) declares helper methods that return ExprDoc.
+class ExprDoc;
+
/*!
* \brief The base class of expression doc.
*
@@ -70,6 +72,34 @@ class Doc : public ObjectRef {
*/
class ExprDocNode : public DocNode {
public:
+ /*!
+ * \brief Create a doc representing attribute access on the current ExprDoc
+ * \param attr The attribute to access.
+ * \return An AttrAccessDoc whose value is this doc and whose name is `attr`.
+ */
+ ExprDoc Attr(String attr) const;
+
+ /*!
+ * \brief Create a doc representing index access on the current ExprDoc
+ * \param indices The indices to access. Each element is an ExprDoc or a SliceDoc.
+ * \return An IndexDoc whose value is this doc.
+ */
+ ExprDoc operator[](Array<Doc> indices) const;
+
+ /*!
+ * \brief Create a doc representing calling the current ExprDoc
+ * \param args The positional arguments of the function call.
+ * \return A CallDoc with this doc as callee and no keyword arguments.
+ */
+ ExprDoc Call(Array<ExprDoc, void> args) const;
+
+ /*!
+ * \brief Create a doc representing calling the current ExprDoc with keyword arguments
+ * \param args The positional arguments of the function call.
+ * \param kwargs_keys Keys of keyword arguments of the function call.
+ * \param kwargs_values Values of keyword arguments of the function call. The i-th value
+ * belongs to the i-th key, so it must have the same length as `kwargs_keys`.
+ * \return A CallDoc with this doc as callee.
+ */
+ ExprDoc Call(Array<ExprDoc, void> args, //
+ Array<String> kwargs_keys, //
+ Array<ExprDoc, void> kwargs_values) const;
+
void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); }
static constexpr const char* _type_key = "script.printer.ExprDoc";
@@ -158,6 +188,458 @@ class LiteralDoc : public ExprDoc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
};
+/*!
+ * \brief Doc that represents an identifier.
+ *
+ * The name is stored exactly as given; the constructor does not validate it,
+ * and PythonDocPrinter emits it verbatim.
+ *
+ * \sa IdDoc
+ */
+class IdDocNode : public ExprDocNode {
+ public:
+ /*! \brief The name of the identifier */
+ String name;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ }
+
+ static constexpr const char* _type_key = "script.printer.IdDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of IdDocNode.
+ *
+ * \sa IdDocNode
+ */
+class IdDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of IdDoc.
+ * \param name The name of the identifier.
+ */
+ explicit IdDoc(String name);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode);
+};
+
+/*!
+ * \brief Doc that represents attribute access on another expression.
+ *
+ * PythonDocPrinter prints it as `<value>.<name>`.
+ *
+ * \sa AttrAccessDoc
+ */
+class AttrAccessDocNode : public ExprDocNode {
+ public:
+ /*! \brief The target expression to be accessed (null until set by the AttrAccessDoc constructor) */
+ ExprDoc value{nullptr};
+ /*! \brief The name of the attribute to be accessed */
+ String name;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("value", &value);
+ v->Visit("name", &name);
+ }
+
+ static constexpr const char* _type_key = "script.printer.AttrAccessDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of AttrAccessDocNode.
+ *
+ * \sa AttrAccessDocNode
+ */
+class AttrAccessDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of AttrAccessDoc
+ * \param value The target expression of attribute access.
+ * \param name The name of the attribute to access.
+ */
+ explicit AttrAccessDoc(ExprDoc value, String name);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode);
+};
+
+/*!
+ * \brief Doc that represents index access on another expression.
+ *
+ * PythonDocPrinter prints it as `value[i0, i1, ...]`; an empty `indices`
+ * array is printed as `value[()]`.
+ *
+ * \sa IndexDoc
+ */
+class IndexDocNode : public ExprDocNode {
+ public:
+ /*! \brief The container value to be accessed */
+ ExprDoc value{nullptr};
+ /*!
+ * \brief The indices to access
+ *
+ * Possible actual types:
+ * - ExprDoc (single point access like a[1, 2])
+ * - SliceDoc (slice access like a[1:5, 2])
+ */
+ Array<Doc> indices; // Each element is either a SliceDoc or an ExprDoc
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("value", &value);
+ v->Visit("indices", &indices);
+ }
+
+ static constexpr const char* _type_key = "script.printer.IndexDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of IndexDocNode.
+ *
+ * \sa IndexDocNode
+ */
+class IndexDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of IndexDoc
+ * \param value The target expression of index access.
+ * \param indices The indices to access; each is an ExprDoc or a SliceDoc.
+ */
+ explicit IndexDoc(ExprDoc value, Array<Doc> indices);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IndexDoc, ExprDoc, IndexDocNode);
+};
+
+/*!
+ * \brief Doc that represents function call.
+ *
+ * Keyword arguments are stored as two parallel arrays, `kwargs_keys` and
+ * `kwargs_values`, paired by index.
+ *
+ * \sa CallDoc
+ */
+class CallDocNode : public ExprDocNode {
+ public:
+ /*! \brief The callee of this function call */
+ ExprDoc callee{nullptr};
+ /*! \brief The positional arguments */
+ Array<ExprDoc> args;
+ /*! \brief The keys of keyword arguments */
+ Array<String> kwargs_keys;
+ /*!
+ * \brief The values of keyword arguments.
+ *
+ * The i-th element is the value of the i-th key in `kwargs_keys`.
+ * It must have the same length as `kwargs_keys`.
+ */
+ Array<ExprDoc> kwargs_values;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("callee", &callee);
+ v->Visit("args", &args);
+ v->Visit("kwargs_keys", &kwargs_keys);
+ v->Visit("kwargs_values", &kwargs_values);
+ }
+
+ static constexpr const char* _type_key = "script.printer.CallDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of CallDocNode.
+ *
+ * \sa CallDocNode
+ */
+class CallDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of CallDoc
+ * \param callee The callee of this function call.
+ * \param args The positional arguments.
+ * \param kwargs_keys Keys of keyword arguments.
+ * \param kwargs_values Values of keyword arguments, must have the same length as `kwargs_keys`.
+ */
+ CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys,
+ Array<ExprDoc> kwargs_values);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CallDoc, ExprDoc, CallDocNode);
+};
+
+/*!
+ * \brief Doc that represents operation.
+ *
+ * It can be unary, binary and other special operators (for example,
+ * the if-then-else expression).
+ *
+ * \sa OperationDoc
+ */
+class OperationDocNode : public ExprDocNode {
+ public:
+ // The k*Start / k*End entries mark the boundaries of each operator group
+ // (unary, binary, special) rather than concrete operators. The numeric values
+ // are mirrored by the Python `OperationKind` enum, so they must stay in sync.
+ enum class Kind : int32_t {
+ // Unary operators
+ kUnaryStart = 0,
+ kUSub = 1, // -x
+ kInvert = 2, // ~x
+ kUnaryEnd = 3,
+
+ // Binary operators
+ kBinaryStart = 4,
+ kAdd = 5, // +
+ kSub = 6, // -
+ kMult = 7, // *
+ kDiv = 8, // /
+ kFloorDiv = 9, // // in Python
+ kMod = 10, // % in Python
+ kPow = 11, // ** in Python
+ kLShift = 12, // <<
+ kRShift = 13, // >>
+ kBitAnd = 14, // &
+ kBitOr = 15, // |
+ kBitXor = 16, // ^
+ kLt = 17, // <
+ kLtE = 18, // <=
+ kEq = 19, // ==
+ kNotEq = 20, // !=
+ kGt = 21, // >
+ kGtE = 22, // >=
+ kBinaryEnd = 23,
+
+ // Special
+ kSpecialStart = 24,
+ kIfThenElse = 25, // <operands[1]> if <operands[0]> else <operands[2]>
+ kSpecialEnd = 26
+ };
+
+ /*! \brief The kind of operation (operator) */
+ Kind kind;
+ /*! \brief Operands of this expression (kIfThenElse uses three, in the order noted above) */
+ Array<ExprDoc> operands;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("kind", &kind);
+ v->Visit("operands", &operands);
+ }
+
+ static constexpr const char* _type_key = "script.printer.OperationDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of OperationDocNode.
+ *
+ * \sa OperationDocNode
+ */
+class OperationDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of OperationDoc
+ * \param kind The kind of operation.
+ * \param operands Operands of this expression.
+ */
+ explicit OperationDoc(OperationDocNode::Kind kind, Array<ExprDoc> operands);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OperationDoc, ExprDoc, OperationDocNode);
+};
+
+/*!
+ * \brief Doc that represents an anonymous function (a Python `lambda`).
+ *
+ * LambdaDoc can only have positional arguments without type annotation,
+ * and a single expression as body.
+ *
+ * \sa LambdaDoc
+ */
+class LambdaDocNode : public ExprDocNode {
+ public:
+ /*! \brief The arguments of this anonymous function */
+ Array<IdDoc> args;
+ /*! \brief The body of this anonymous function */
+ ExprDoc body{nullptr};
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("args", &args);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.LambdaDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of LambdaDocNode.
+ *
+ * \sa LambdaDocNode
+ */
+class LambdaDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Constructor of LambdaDoc
+ * \param args Arguments of this function (plain identifiers only).
+ * \param body Body expression of this function.
+ */
+ explicit LambdaDoc(Array<IdDoc> args, ExprDoc body);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, ExprDoc, LambdaDocNode);
+};
+
+/*!
+ * \brief Doc that represents a tuple literal.
+ *
+ * \sa TupleDoc
+ */
+class TupleDocNode : public ExprDocNode {
+ public:
+ /*! \brief Elements of tuple */
+ Array<ExprDoc> elements;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("elements", &elements);
+ }
+
+ static constexpr const char* _type_key = "script.printer.TupleDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of TupleDocNode.
+ *
+ * \sa TupleDocNode
+ */
+class TupleDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Create an empty TupleDoc (a freshly allocated node with no elements)
+ */
+ TupleDoc() : TupleDoc(runtime::make_object<TupleDocNode>()) {}
+ /*!
+ * \brief Constructor of TupleDoc
+ * \param elements Elements of tuple.
+ */
+ explicit TupleDoc(Array<ExprDoc> elements);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleDoc, ExprDoc, TupleDocNode);
+};
+
+/*!
+ * \brief Doc that represents a list literal.
+ *
+ * \sa ListDoc
+ */
+class ListDocNode : public ExprDocNode {
+ public:
+ /*! \brief Elements of list */
+ Array<ExprDoc> elements;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("elements", &elements);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ListDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of ListDocNode.
+ *
+ * \sa ListDocNode
+ */
+class ListDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Create an empty ListDoc (a freshly allocated node with no elements)
+ */
+ ListDoc() : ListDoc(runtime::make_object<ListDocNode>()) {}
+ /*!
+ * \brief Constructor of ListDoc
+ * \param elements Elements of list.
+ */
+ explicit ListDoc(Array<ExprDoc> elements);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ListDoc, ExprDoc, ListDocNode);
+};
+
+/*!
+ * \brief Doc that represents a dictionary literal.
+ *
+ * Entries are stored as two parallel arrays, `keys` and `values`, paired by index.
+ *
+ * \sa DictDoc
+ */
+class DictDocNode : public ExprDocNode {
+ public:
+ /*! \brief Keys of dictionary */
+ Array<ExprDoc> keys;
+ /*!
+ * \brief Values of dictionary
+ *
+ * The i-th element is the value of the i-th element of `keys`.
+ * It must have the same length as `keys`.
+ */
+ Array<ExprDoc> values;
+
+ void VisitAttrs(AttrVisitor* v) {
+ ExprDocNode::VisitAttrs(v);
+ v->Visit("keys", &keys);
+ v->Visit("values", &values);
+ }
+
+ static constexpr const char* _type_key = "script.printer.DictDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of DictDocNode.
+ *
+ * \sa DictDocNode
+ */
+class DictDoc : public ExprDoc {
+ public:
+ /*!
+ * \brief Create an empty dictionary
+ */
+ DictDoc() : DictDoc(runtime::make_object<DictDocNode>()) {}
+ /*!
+ * \brief Constructor of DictDoc
+ * \param keys Keys of dictionary.
+ * \param values Values of dictionary, must have same length as `keys`.
+ */
+ explicit DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DictDoc, ExprDoc, DictDocNode);
+};
+
+/*!
+ * \brief Doc that represents slice in Index expression.
+ *
+ * This doc can only appear in IndexDoc::indices, which is why it derives from
+ * DocNode rather than ExprDocNode. Each of `start`, `stop` and `step` is optional.
+ *
+ * \sa SliceDoc
+ */
+class SliceDocNode : public DocNode {
+ public:
+ /*! \brief The start of slice */
+ Optional<ExprDoc> start;
+ /*! \brief The exclusive end of slice */
+ Optional<ExprDoc> stop;
+ /*! \brief The step of slice */
+ Optional<ExprDoc> step;
+
+ void VisitAttrs(AttrVisitor* v) {
+ DocNode::VisitAttrs(v);
+ v->Visit("start", &start);
+ v->Visit("stop", &stop);
+ v->Visit("step", &step);
+ }
+
+ static constexpr const char* _type_key = "script.printer.SliceDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of SliceDocNode.
+ *
+ * \sa SliceDocNode
+ */
+class SliceDoc : public Doc {
+ public:
+ /*!
+ * \brief Constructor of SliceDoc
+ * \param start The start of slice.
+ * \param stop The exclusive end of slice.
+ * \param step The step of slice.
+ */
+ explicit SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode);
+};
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index f6179d7351..acdb63dcf2 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -16,8 +16,13 @@
# under the License.
"""Doc types for TVMScript Unified Printer"""
+from typing import List, Dict, Tuple, Optional, Union, Sequence
+from enum import IntEnum, unique
+
import tvm._ffi
+import tvm.ir.container
from tvm.runtime import Object
+from tvm.tir import FloatImm, IntImm
from . import _ffi_api
@@ -29,12 +34,79 @@ class Doc(Object):
class ExprDoc(Object):
"""Base class of all expression Docs"""
+ def attr(self, name: str) -> "AttrAccessDoc":
+ """
+ Create a doc that represents attribute access on self.
+
+ Parameters
+ ----------
+ name : str
+ The attribute name to access
+
+ Returns
+ -------
+ doc : AttrAccessDoc
+ """
+ return _ffi_api.ExprDocAttr(self, name) # type: ignore
+
+ def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc":
+ """
+ Create a doc that represents function call, with self as callee.
+
+ Parameters
+ ----------
+ *args : ExprDoc
+ The positional arguments of the function call.
+ **kwargs
+ The keyword arguments of the function call.
+
+ Returns
+ -------
+ doc : CallDoc
+ """
+ kwargs_keys = list(kwargs.keys())
+ kwargs_values = list(kwargs.values())
+ return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore
+
+ _IndexType = Union["ExprDoc", "SliceDoc"]
+
+ def __getitem__(self, indices: Union[Tuple[_IndexType], _IndexType]) -> "IndexDoc":
+ """
+ Create a doc that represents index access on self.
+
+ Parameters
+ ----------
+ indices : Union[Tuple[Union["ExprDoc", "SliceDoc"]], Union["ExprDoc", "SliceDoc"]]
+ The indices to access
+
+ Returns
+ -------
+ doc : IndexDoc
+ """
+ if not isinstance(indices, tuple):
+ indices = (indices,)
+ return _ffi_api.ExprDocIndex(self, indices) # type: ignore
+
+ def __iter__(self):
+ """
+ This is implemented to prevent confusing error message when trying to use ExprDoc
+ as iterable. According to PEP-234, An object can be iterated over if it
+ implements __iter__() or __getitem__(). If an object has only __getitem__
+ but not __iter__, interpreter will iterate the object by calling
+ __getitem__ with 0, 1, 2, ..., until an IndexError is raised.
+
+ https://peps.python.org/pep-0234/#python-api-specification
+ """
+ raise RuntimeError(f"{self.__class__} cannot be used as iterable.")
+
@tvm._ffi.register_object("script.printer.LiteralDoc")
class LiteralDoc(ExprDoc):
"""Doc that represents literal value"""
- def __init__(self, value):
+ value: Union[str, IntImm, FloatImm, None]
+
+ def __init__(self, value: Union[str, float, bool, int, None]):
if value is None:
self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore
elif isinstance(value, str):
@@ -47,3 +119,177 @@ class LiteralDoc(ExprDoc):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore
else:
raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")
+
+
+@tvm._ffi.register_object("script.printer.IdDoc")
+class IdDoc(ExprDoc):
+ """Doc that represents identifier"""
+
+ name: str
+
+ def __init__(self, name: str):
+ self.__init_handle_by_constructor__(_ffi_api.IdDoc, name) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.AttrAccessDoc")
+class AttrAccessDoc(ExprDoc):
+ """Doc that represents attribute access on an expression"""
+
+ value: ExprDoc
+ name: str
+
+ def __init__(self, value: ExprDoc, name: str):
+ self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.IndexDoc")
+class IndexDoc(ExprDoc):
+ """Doc that represents index access on an expression"""
+
+ value: ExprDoc
+ indices: Sequence[Union[ExprDoc, "SliceDoc"]]
+
+ def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]):
+ self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.CallDoc")
+class CallDoc(ExprDoc):
+ """Doc that represents function call"""
+
+ callee: ExprDoc
+ args: Sequence[ExprDoc]
+ kwargs_keys: Sequence[str]
+ kwargs_values: Sequence[ExprDoc]
+
+ def __init__(self, callee: ExprDoc, *args: Tuple[ExprDoc], **kwargs: Dict[str, ExprDoc]):
+ kwargs_keys = list(kwargs.keys())
+ kwargs_values = list(kwargs.values())
+ self.__init_handle_by_constructor__(
+ _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore
+ )
+
+
+@unique
+class OperationKind(IntEnum):
+ """
+ This enum represents the kind of operation (operator) in OperationDoc
+
+ It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h,
+ so every numeric value here must match the C++ enum. The underscore-prefixed
+ members (``_UnaryStart``, ``_BinaryEnd``, ...) mirror the C++ group-boundary
+ markers rather than concrete operators.
+ """
+
+ # The name convention follows https://docs.python.org/3/library/ast.html
+ # pylint: disable=invalid-name
+
+ _UnaryStart = 0
+ USub = 1
+ Invert = 2
+ _UnaryEnd = 3
+
+ _BinaryStart = 4
+ Add = 5
+ Sub = 6
+ Mult = 7
+ Div = 8
+ FloorDiv = 9
+ Mod = 10
+ Pow = 11
+ LShift = 12
+ RShift = 13
+ BitAnd = 14
+ BitOr = 15
+ BitXor = 16
+ Lt = 17
+ LtE = 18
+ Eq = 19
+ NotEq = 20
+ Gt = 21
+ GtE = 22
+ _BinaryEnd = 23
+
+ _SpecialStart = 24
+ # Operands are ordered (condition, then-value, else-value), as in the C++ enum comment.
+ IfThenElse = 25
+ _SpecialEnd = 26
+
+ # pylint: enable=invalid-name
+
+
+@tvm._ffi.register_object("script.printer.OperationDoc")
+class OperationDoc(ExprDoc):
+ """
+ Doc that represents operation
+
+ It can be unary, binary and other special operators (for example, the
+ if-then-else expression).
+ """
+
+ kind: OperationKind
+ operands: Sequence[ExprDoc]
+
+ def __init__(self, kind: OperationKind, operands: List[ExprDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.LambdaDoc")
+class LambdaDoc(ExprDoc):
+ """Doc that represents lambda function"""
+
+ args: Sequence[IdDoc]
+ body: ExprDoc
+
+ def __init__(self, args: List[IdDoc], body: ExprDoc):
+ self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.TupleDoc")
+class TupleDoc(ExprDoc):
+ """Doc that represents tuple literal"""
+
+ elements: Sequence[ExprDoc]
+
+ def __init__(self, elements: List[ExprDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.ListDoc")
+class ListDoc(ExprDoc):
+ """Doc that represents list literal"""
+
+ elements: Sequence[ExprDoc]
+
+ def __init__(self, elements: List[ExprDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.DictDoc")
+class DictDoc(ExprDoc):
+ """Doc that represents dict literal"""
+
+ keys: Sequence[ExprDoc]
+ values: Sequence[ExprDoc]
+
+ def __init__(self, content: Dict[ExprDoc, ExprDoc]):
+ keys = list(content.keys())
+ values = list(content.values())
+ self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values) # type: ignore
+
+
+@tvm._ffi.register_object("script.printer.SliceDoc")
+class SliceDoc(ExprDoc):
+ """
+ Doc that represents slice in Index expression
+
+ This doc can only appear in `IndexDoc.indices`.
+ """
+
+ start: Optional[ExprDoc]
+ stop: Optional[ExprDoc]
+ step: Optional[ExprDoc]
+
+ def __init__(
+ self,
+ start: Optional[ExprDoc] = None,
+ stop: Optional[ExprDoc] = None,
+ step: Optional[ExprDoc] = None,
+ ):
+ self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore
diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc
index f6874ba1a2..42d3f2d8f3 100644
--- a/src/script/printer/base_doc_printer.cc
+++ b/src/script/printer/base_doc_printer.cc
@@ -38,6 +38,26 @@ String DocPrinter::GetString() const {
void DocPrinter::PrintDoc(const Doc& doc) {
if (const auto* doc_node = doc.as<LiteralDocNode>()) {
PrintTypedDoc(GetRef<LiteralDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<IdDocNode>()) {
+ PrintTypedDoc(GetRef<IdDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<AttrAccessDocNode>()) {
+ PrintTypedDoc(GetRef<AttrAccessDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<IndexDocNode>()) {
+ PrintTypedDoc(GetRef<IndexDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<OperationDocNode>()) {
+ PrintTypedDoc(GetRef<OperationDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<CallDocNode>()) {
+ PrintTypedDoc(GetRef<CallDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<LambdaDocNode>()) {
+ PrintTypedDoc(GetRef<LambdaDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ListDocNode>()) {
+ PrintTypedDoc(GetRef<ListDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<TupleDocNode>()) {
+ PrintTypedDoc(GetRef<TupleDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<DictDocNode>()) {
+ PrintTypedDoc(GetRef<DictDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<SliceDocNode>()) {
+ PrintTypedDoc(GetRef<SliceDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h
index 128fcef2ea..d5bfdcd94c 100644
--- a/src/script/printer/base_doc_printer.h
+++ b/src/script/printer/base_doc_printer.h
@@ -83,6 +83,56 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;
+ /*!
+ * \brief Virtual method to print an IdDoc
+ */
+ virtual void PrintTypedDoc(const IdDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an AttrAccessDoc
+ */
+ virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an IndexDoc
+ */
+ virtual void PrintTypedDoc(const IndexDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an OperationDoc
+ */
+ virtual void PrintTypedDoc(const OperationDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a CallDoc
+ */
+ virtual void PrintTypedDoc(const CallDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a LambdaDoc
+ */
+ virtual void PrintTypedDoc(const LambdaDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a ListDoc
+ */
+ virtual void PrintTypedDoc(const ListDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a TupleDoc
+ */
+ virtual void PrintTypedDoc(const TupleDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a DictDoc
+ */
+ virtual void PrintTypedDoc(const DictDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a SliceDoc
+ *
+ * SliceDoc is not an ExprDoc; it only appears inside IndexDoc::indices.
+ */
+ virtual void PrintTypedDoc(const SliceDoc& doc) = 0;
+
/*!
* \brief Increase the indent level of any content to be
* printed after this call
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index e54adbd36b..ed81f9d2dd 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -23,14 +23,108 @@ namespace tvm {
namespace script {
namespace printer {
+// Sugar: `doc->Attr("x")` builds `doc.x`.
+ExprDoc ExprDocNode::Attr(String attr) const {
+  ExprDoc self = GetRef<ExprDoc>(this);
+  return AttrAccessDoc(self, attr);
+}
+
+// Sugar: `(*doc)[{i, j}]` builds `doc[i, j]`.
+ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
+  ExprDoc self = GetRef<ExprDoc>(this);
+  return IndexDoc(self, indices);
+}
+
+// Positional-only call: forward to the general overload with no keyword arguments.
+ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args) const {
+  return Call(args, Array<String>(), Array<ExprDoc>());
+}
+
+ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_keys,
+                          Array<ExprDoc, void> kwargs_values) const {
+  ExprDoc self = GetRef<ExprDoc>(this);
+  return CallDoc(self, args, kwargs_keys, kwargs_values);
+}
+
+// Wrap `value` in a new LiteralDocNode as-is (no inspection or conversion).
LiteralDoc::LiteralDoc(ObjectRef value) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
this->data_ = std::move(n);
}
+IdDoc::IdDoc(String name) {
+  auto node = make_object<IdDocNode>();
+  node->name = std::move(name);
+  data_ = std::move(node);
+}
+
+AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) {
+  auto node = make_object<AttrAccessDocNode>();
+  node->value = std::move(value);
+  node->name = std::move(name);
+  data_ = std::move(node);
+}
+
+IndexDoc::IndexDoc(ExprDoc value, Array<Doc> indices) {
+  auto node = make_object<IndexDocNode>();
+  node->value = std::move(value);
+  node->indices = std::move(indices);
+  data_ = std::move(node);
+}
+
+/*!
+ * Build a CallDoc. `kwargs_values[i]` is the value for `kwargs_keys[i]`, so the two
+ * arrays must be the same length; otherwise some keys or values would have no partner.
+ */
+CallDoc::CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys,
+                 Array<ExprDoc> kwargs_values) {
+  ICHECK_EQ(kwargs_keys.size(), kwargs_values.size())
+      << "CallDoc requires kwargs_keys and kwargs_values to have the same length";
+  ObjectPtr<CallDocNode> n = make_object<CallDocNode>();
+  n->callee = callee;
+  n->args = args;
+  n->kwargs_keys = kwargs_keys;
+  n->kwargs_values = kwargs_values;
+  this->data_ = std::move(n);
+}
+
+OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array<ExprDoc> operands) {
+  auto node = make_object<OperationDocNode>();
+  node->kind = kind;
+  node->operands = std::move(operands);
+  data_ = std::move(node);
+}
+
+LambdaDoc::LambdaDoc(Array<IdDoc> args, ExprDoc body) {
+  auto node = make_object<LambdaDocNode>();
+  node->args = std::move(args);
+  node->body = std::move(body);
+  data_ = std::move(node);
+}
+
+TupleDoc::TupleDoc(Array<ExprDoc> elements) {
+  auto node = make_object<TupleDocNode>();
+  node->elements = std::move(elements);
+  data_ = std::move(node);
+}
+
+ListDoc::ListDoc(Array<ExprDoc> elements) {
+  auto node = make_object<ListDocNode>();
+  node->elements = std::move(elements);
+  data_ = std::move(node);
+}
+
+/*!
+ * Build a DictDoc. `values[i]` is the value for `keys[i]`, so the two arrays must be
+ * the same length; otherwise some keys or values would have no partner.
+ */
+DictDoc::DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values) {
+  ICHECK_EQ(keys.size(), values.size())
+      << "DictDoc requires keys and values to have the same length";
+  ObjectPtr<DictDocNode> n = make_object<DictDocNode>();
+  n->keys = keys;
+  n->values = values;
+  this->data_ = std::move(n);
+}
+
+// Any of start/stop/step may be NullOpt; they are stored unchanged.
+SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) {
+  auto node = make_object<SliceDocNode>();
+  node->start = std::move(start);
+  node->stop = std::move(stop);
+  node->step = std::move(step);
+  data_ = std::move(node);
+}
+
TVM_REGISTER_NODE_TYPE(DocNode);
+
TVM_REGISTER_NODE_TYPE(ExprDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method<ExprDoc>(&ExprDocNode::Attr);
+TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
+ .set_body_method<ExprDoc>(&ExprDocNode::operator[]);
+TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
+ .set_body_method<ExprDoc, ExprDocNode, ExprDoc, Array<ExprDoc>, Array<String>, Array<ExprDoc>>(
+ &ExprDocNode::Call);
+
TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
@@ -38,6 +132,59 @@ TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDo
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
+TVM_REGISTER_NODE_TYPE(IdDocNode);
+TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String id_name) -> IdDoc {
+  return IdDoc(id_name);
+});
+
+TVM_REGISTER_NODE_TYPE(AttrAccessDocNode);
+TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc")
+    .set_body_typed([](ExprDoc target, String attr_name) -> AttrAccessDoc {
+      return AttrAccessDoc(target, attr_name);
+    });
+
+TVM_REGISTER_NODE_TYPE(IndexDocNode);
+TVM_REGISTER_GLOBAL("script.printer.IndexDoc")
+    .set_body_typed([](ExprDoc container, Array<Doc> idx) -> IndexDoc {
+      return IndexDoc(container, idx);
+    });
+
+TVM_REGISTER_NODE_TYPE(CallDocNode);
+TVM_REGISTER_GLOBAL("script.printer.CallDoc")
+    .set_body_typed([](ExprDoc fn, Array<ExprDoc> positional, Array<String> kw_names,
+                       Array<ExprDoc> kw_values) -> CallDoc {
+      return CallDoc(fn, positional, kw_names, kw_values);
+    });
+
+TVM_REGISTER_NODE_TYPE(OperationDocNode);
+TVM_REGISTER_GLOBAL("script.printer.OperationDoc")
+    .set_body_typed([](int32_t raw_kind, Array<ExprDoc> operands) -> OperationDoc {
+      // Python passes the OperationKind IntEnum as a plain integer.
+      auto kind = static_cast<OperationDocNode::Kind>(raw_kind);
+      return OperationDoc(kind, operands);
+    });
+
+TVM_REGISTER_NODE_TYPE(LambdaDocNode);
+TVM_REGISTER_GLOBAL("script.printer.LambdaDoc")
+    .set_body_typed([](Array<IdDoc> params, ExprDoc body) -> LambdaDoc {
+      return LambdaDoc(params, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(TupleDocNode);
+TVM_REGISTER_GLOBAL("script.printer.TupleDoc")
+    .set_body_typed([](Array<ExprDoc> items) -> TupleDoc { return TupleDoc(items); });
+
+TVM_REGISTER_NODE_TYPE(ListDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ListDoc")
+    .set_body_typed([](Array<ExprDoc> items) -> ListDoc { return ListDoc(items); });
+
+TVM_REGISTER_NODE_TYPE(DictDocNode);
+TVM_REGISTER_GLOBAL("script.printer.DictDoc")
+    .set_body_typed([](Array<ExprDoc> dict_keys, Array<ExprDoc> dict_values) -> DictDoc {
+      return DictDoc(dict_keys, dict_values);
+    });
+
+TVM_REGISTER_NODE_TYPE(SliceDocNode);
+TVM_REGISTER_GLOBAL("script.printer.SliceDoc")
+    .set_body_typed([](Optional<ExprDoc> lo, Optional<ExprDoc> hi,
+                       Optional<ExprDoc> stride) -> SliceDoc { return SliceDoc(lo, hi, stride); });
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc
index cd816e4f70..5c7b048f81 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/python_doc_printer.cc
@@ -17,6 +17,7 @@
* under the License.
*/
+#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include "../../support/str_escape.h"
@@ -34,6 +35,30 @@ class PythonDocPrinter : public DocPrinter {
using DocPrinter::PrintDoc;
void PrintTypedDoc(const LiteralDoc& doc) final;
+ void PrintTypedDoc(const IdDoc& doc) final;
+ void PrintTypedDoc(const AttrAccessDoc& doc) final;
+ void PrintTypedDoc(const IndexDoc& doc) final;
+ void PrintTypedDoc(const OperationDoc& doc) final;
+ void PrintTypedDoc(const CallDoc& doc) final;
+ void PrintTypedDoc(const LambdaDoc& doc) final;
+ void PrintTypedDoc(const ListDoc& doc) final;
+ void PrintTypedDoc(const DictDoc& doc) final;
+ void PrintTypedDoc(const TupleDoc& doc) final;
+ void PrintTypedDoc(const SliceDoc& doc) final;
+
+ private:
+  /*!
+   * \brief Print every doc in `docs` in order, writing `separator` between
+   *        consecutive docs (never before the first or after the last).
+   * \param docs The docs to print.
+   * \param separator Text emitted between two adjacent docs.
+   */
+  template <typename DocType>
+  void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
+    bool need_separator = false;
+    for (const DocType& doc : docs) {
+      if (need_separator) {
+        output_ << separator;
+      }
+      PrintDoc(doc);
+      need_separator = true;
+    }
+  }
};
void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
@@ -57,6 +82,184 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
}
}
+void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) {
+  // Identifiers are written out as-is; no quoting or escaping is applied.
+  output_ << doc->name;
+}
+
+void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
+  // Attribute access binds tighter than any operator or lambda, so such a receiver must be
+  // parenthesized to keep the printed Python equivalent to the doc tree:
+  // `(a + b).attr` rather than `a + b.attr`.
+  const bool needs_parens =
+      doc->value->IsInstance<OperationDocNode>() || doc->value->IsInstance<LambdaDocNode>();
+  if (needs_parens) {
+    output_ << "(";
+  }
+  PrintDoc(doc->value);
+  if (needs_parens) {
+    output_ << ")";
+  }
+  output_ << "." << doc->name;
+}
+
+void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
+  // Subscription binds tighter than any operator or lambda, so such a value must be
+  // parenthesized: `(a + b)[0]` rather than `a + b[0]`.
+  const bool needs_parens =
+      doc->value->IsInstance<OperationDocNode>() || doc->value->IsInstance<LambdaDocNode>();
+  if (needs_parens) {
+    output_ << "(";
+  }
+  PrintDoc(doc->value);
+  if (needs_parens) {
+    output_ << ")";
+  }
+
+  if (doc->indices.size() == 0) {
+    // `x[()]` is Python's spelling for indexing with an empty tuple.
+    output_ << "[()]";
+  } else {
+    output_ << "[";
+    PrintJoinedDocs(doc->indices, ", ");
+    output_ << "]";
+  }
+}
+
+/*!
+ * \brief Map an OperationDocNode::Kind to its Python operator token.
+ * \param operation_kind A unary or binary operator kind.
+ * \return The Python token, e.g. "+" for kAdd or "~" for kInvert.
+ * \note ICHECK-fails for kinds that have no single-token spelling, i.e. sentinel values
+ *       such as kUnaryEnd and special operations such as kIfThenElse.
+ */
+const std::string OperatorToString(OperationDocNode::Kind operation_kind) {
+  // Built once on first use: a dense table indexed by the enum's integer value. Kinds with
+  // no token are left as empty strings.
+  static const std::vector<std::string> op_kind2str = []() {
+    using OpKind = OperationDocNode::Kind;
+    std::map<OpKind, std::string> raw_table = {
+        {OpKind::kUSub, "-"},       //
+        {OpKind::kInvert, "~"},     //
+        {OpKind::kAdd, "+"},        //
+        {OpKind::kSub, "-"},        //
+        {OpKind::kMult, "*"},       //
+        {OpKind::kDiv, "/"},        //
+        {OpKind::kFloorDiv, "//"},  //
+        {OpKind::kMod, "%"},        //
+        {OpKind::kPow, "**"},       //
+        {OpKind::kLShift, "<<"},    //
+        {OpKind::kRShift, ">>"},    //
+        {OpKind::kBitAnd, "&"},     //
+        {OpKind::kBitOr, "|"},      //
+        {OpKind::kBitXor, "^"},     //
+        {OpKind::kLt, "<"},         //
+        {OpKind::kLtE, "<="},       //
+        {OpKind::kEq, "=="},        //
+        {OpKind::kNotEq, "!="},     //
+        {OpKind::kGt, ">"},         //
+        {OpKind::kGtE, ">="},       //
+    };
+
+    std::vector<std::string> table;
+    table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1);
+
+    for (const auto& kv : raw_table) {
+      table[static_cast<int>(kv.first)] = kv.second;
+    }
+
+    return table;
+  }();
+
+  const int op_index = static_cast<int>(operation_kind);
+  // Check both bounds explicitly instead of comparing a signed int against size_t.
+  ICHECK(op_index >= 0 && static_cast<size_t>(op_index) < op_kind2str.size())
+      << "OperationDocNode::Kind " << op_index << " is out of range.";
+  // The table is static, so bind by reference rather than copying the entry.
+  const std::string& str = op_kind2str[op_index];
+  ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast<int>(operation_kind)
+                       << " cannot be converted to operator token in Python directly.";
+  return str;
+}
+
+void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
+  using OpKind = OperationDocNode::Kind;
+
+  // Print one operand, wrapping it in parentheses when it is itself an operation or a
+  // lambda. Without this, nested docs print as different Python: Mult(Add(a, b), c) would
+  // come out as `a + b * c`. The rule is deliberately conservative and may emit
+  // parentheses that are not strictly required.
+  auto print_operand = [this](const ExprDoc& operand) {
+    const bool needs_parens =
+        operand->IsInstance<OperationDocNode>() || operand->IsInstance<LambdaDocNode>();
+    if (needs_parens) {
+      output_ << "(";
+    }
+    PrintDoc(operand);
+    if (needs_parens) {
+      output_ << ")";
+    }
+  };
+
+  if (doc->kind < OpKind::kUnaryEnd) {
+    // Unary operator: `<op><operand>`
+    ICHECK_EQ(doc->operands.size(), 1)
+        << "ValueError: Unary operator requires 1 operand, but got " << doc->operands.size();
+    output_ << OperatorToString(doc->kind);
+    print_operand(doc->operands[0]);
+  } else if (doc->kind < OpKind::kBinaryEnd) {
+    // Binary operator: `<lhs> <op> <rhs>`
+    ICHECK_EQ(doc->operands.size(), 2)
+        << "ValueError: Binary operator requires 2 operands, but got " << doc->operands.size();
+    print_operand(doc->operands[0]);
+    output_ << " " << OperatorToString(doc->kind) << " ";
+    print_operand(doc->operands[1]);
+  } else if (doc->kind == OpKind::kIfThenElse) {
+    // Operands are (condition, then_value, else_value); Python spells this as
+    // `then_value if condition else else_value`.
+    ICHECK_EQ(doc->operands.size(), 3)
+        << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
+    print_operand(doc->operands[1]);
+    output_ << " if ";
+    print_operand(doc->operands[0]);
+    output_ << " else ";
+    print_operand(doc->operands[2]);
+  } else {
+    LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
+  }
+}
+
+void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
+  // Validate up front so a malformed doc fails before any partial output is written.
+  ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size())
+      << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values.";
+
+  // An operation or lambda used as the callee must be parenthesized; otherwise
+  // `(lambda x: x)(1)` would print as `lambda x: x(1)`.
+  const bool callee_needs_parens = doc->callee->IsInstance<OperationDocNode>() ||
+                                   doc->callee->IsInstance<LambdaDocNode>();
+  if (callee_needs_parens) {
+    output_ << "(";
+  }
+  PrintDoc(doc->callee);
+  if (callee_needs_parens) {
+    output_ << ")";
+  }
+
+  output_ << "(";
+  // Positional args come first, then `key=value` keyword args; all comma-separated.
+  PrintJoinedDocs(doc->args, ", ");
+  for (size_t i = 0; i < doc->kwargs_keys.size(); ++i) {
+    if (i > 0 || !doc->args.empty()) {
+      output_ << ", ";
+    }
+    output_ << doc->kwargs_keys[i] << "=";
+    PrintDoc(doc->kwargs_values[i]);
+  }
+  output_ << ")";
+}
+
+// Prints `lambda <args>: <body>`; with no args this yields `lambda : <body>`.
+void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
+  const auto& params = doc->args;
+  const auto& result = doc->body;
+  output_ << "lambda ";
+  PrintJoinedDocs(params, ", ");
+  output_ << ": ";
+  PrintDoc(result);
+}
+
+// Prints a Python list display: `[e0, e1, ...]`, or `[]` when there are no elements.
+void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
+  output_ << '[';
+  PrintJoinedDocs(doc->elements, ", ");
+  output_ << ']';
+}
+
+void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) {
+  output_ << "(";
+  PrintJoinedDocs(doc->elements, ", ");
+  // A one-element tuple needs a trailing comma; `(x)` would just be a parenthesized
+  // expression.
+  if (doc->elements.size() == 1) {
+    output_ << ",";
+  }
+  output_ << ")";
+}
+
+// Prints `{k0: v0, k1: v1, ...}`, pairing keys[i] with values[i].
+void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) {
+  ICHECK_EQ(doc->keys.size(), doc->values.size())
+      << "DictDoc should have equal number of elements in keys and values.";
+  output_ << "{";
+  const size_t num_entries = doc->keys.size();
+  for (size_t i = 0; i < num_entries; ++i) {
+    if (i != 0) {
+      output_ << ", ";
+    }
+    PrintDoc(doc->keys[i]);
+    output_ << ": ";
+    PrintDoc(doc->values[i]);
+  }
+  output_ << "}";
+}
+
+// Prints Python slice syntax `start:stop:step`. Missing bounds are left blank, and the
+// second colon is written only when a step is present.
+void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
+  const auto print_if_present = [this](const Optional<ExprDoc>& part) {
+    if (part.defined()) {
+      PrintDoc(part.value());
+    }
+  };
+  print_if_present(doc->start);
+  output_ << ":";
+  print_if_present(doc->stop);
+  if (doc->step.defined()) {
+    output_ << ":";
+    print_if_present(doc->step);
+  }
+}
+
String DocToPythonScript(Doc doc, int indent_spaces) {
PythonDocPrinter printer(indent_spaces);
printer.Append(doc);
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index 6330d33bf2..4ff6a0f547 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -16,8 +16,20 @@
# under the License.
import pytest
-from tvm.tir import IntImm
-from tvm.script.printer.doc import LiteralDoc
+from tvm.script.printer.doc import (
+ LiteralDoc,
+ IdDoc,
+ AttrAccessDoc,
+ IndexDoc,
+ CallDoc,
+ OperationKind,
+ OperationDoc,
+ LambdaDoc,
+ TupleDoc,
+ ListDoc,
+ DictDoc,
+ SliceDoc,
+)
@pytest.mark.parametrize(
@@ -26,8 +38,209 @@ from tvm.script.printer.doc import LiteralDoc
)
def test_literal_doc_construction(value):
doc = LiteralDoc(value)
+
if isinstance(value, float):
# FloatImm cannot be compared with Python's float directly
assert float(doc.value) == pytest.approx(value)
else:
assert doc.value == value
+
+
+def test_id_doc():
+ doc = IdDoc("name")
+
+ assert doc.name == "name"
+
+
+def test_attr_access_doc():
+ target = IdDoc("x")
+
+ doc = AttrAccessDoc(target, "attribute")
+
+ assert doc.value == target
+ assert doc.name == "attribute"
+
+
+@pytest.mark.parametrize(
+ "indices",
+ [
+ [],
+ [LiteralDoc(1)],
+ [LiteralDoc(2), IdDoc("x")],
+ [SliceDoc(LiteralDoc(1), LiteralDoc(2))],
+ [SliceDoc(LiteralDoc(1)), IdDoc("y")],
+ ],
+)
+def test_index_doc(indices):
+ target = IdDoc("x")
+
+ doc = IndexDoc(target, indices)
+
+ assert doc.value == target
+ assert list(doc.indices) == indices
+
+
+@pytest.mark.parametrize(
+ "args, kwargs",
+ [
+ ([], {}),
+ ([LiteralDoc("arg")], {}),
+ ([LiteralDoc("arg"), IdDoc("x")], {}),
+ ([], {"x": LiteralDoc("x")}),
+ ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ],
+)
+def test_call_doc(args, kwargs):
+ target = IdDoc("x")
+
+ doc = CallDoc(target, *args, **kwargs)
+
+ assert doc.callee == target
+ assert list(doc.args) == args
+ assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs
+
+
+@pytest.mark.parametrize(
+ "operands",
+ [
+ [],
+ [LiteralDoc(1)],
+ [LiteralDoc(2), IdDoc("x")],
+ [LiteralDoc(2), IdDoc("x"), LiteralDoc("y")],
+ ],
+)
+def test_operation_doc(operands):
+ # Here we just test the contructor and attr visitor of OperationDoc
+ # so the choice of OperationKind doesn't matter
+ operator = OperationKind.Add
+
+ doc = OperationDoc(OperationKind.Add, operands)
+
+ assert doc.kind == operator
+ assert list(doc.operands) == operands
+
+
+@pytest.mark.parametrize(
+ "args",
+ [
+ [],
+ [IdDoc("x")],
+ [IdDoc("x"), IdDoc("y")],
+ ],
+)
+def test_lambda_doc(args):
+ body = LiteralDoc(1)
+
+ doc = LambdaDoc(args, body)
+
+ assert doc.body == body
+ assert list(doc.args) == args
+
+
+@pytest.mark.parametrize(
+ "elements",
+ [
+ [],
+ [IdDoc("x")],
+ [IdDoc("x"), IdDoc("y")],
+ ],
+)
+def test_tuple_doc(elements):
+ doc = TupleDoc(elements)
+
+ assert list(doc.elements) == elements
+
+
+@pytest.mark.parametrize(
+ "elements",
+ [
+ [],
+ [IdDoc("x")],
+ [IdDoc("x"), IdDoc("y")],
+ ],
+)
+def test_list_doc(elements):
+ doc = ListDoc(elements)
+
+ assert list(doc.elements) == elements
+
+
+@pytest.mark.parametrize(
+ "content",
+ [
+ {},
+ {LiteralDoc("k"): IdDoc("v")},
+ {LiteralDoc("k"): IdDoc("v"), LiteralDoc("k2"): IdDoc("v2")},
+ ],
+)
+def test_dict_doc(content):
+ doc = DictDoc(content)
+
+ assert dict(zip(doc.keys, doc.values)) == content
+
+
+@pytest.mark.parametrize("start", [LiteralDoc(1), None])
+@pytest.mark.parametrize("stop", [LiteralDoc(2), None])
+@pytest.mark.parametrize("step", [LiteralDoc(3), None])
+def test_slice_doc(start, stop, step):
+ doc = SliceDoc(start, stop)
+
+ assert doc.start == start
+ assert doc.stop == stop
+
+
+def test_expr_doc_attr_access():
+ target = IdDoc("x")
+ attr = "test"
+
+ doc = target.attr(attr)
+
+ assert doc.value == target
+ assert doc.name == attr
+
+
+@pytest.mark.parametrize(
+ "indices",
+ [
+ (),
+ LiteralDoc(1),
+ SliceDoc(LiteralDoc(1), LiteralDoc(2)),
+ (LiteralDoc(1),),
+ (LiteralDoc(2), IdDoc("x")),
+ (SliceDoc(LiteralDoc(1), LiteralDoc(2)),),
+ (SliceDoc(LiteralDoc(1)), IdDoc("y")),
+ ],
+)
+def test_expr_doc_get_item(indices):
+ target = IdDoc("x")
+
+ doc = target[indices]
+
+ assert doc.value == target
+ if not isinstance(indices, tuple):
+ indices = (indices,)
+ assert tuple(doc.indices) == indices
+
+
+@pytest.mark.parametrize(
+ "args, kwargs",
+ [
+ ([], {}),
+ ([LiteralDoc("arg")], {}),
+ ([LiteralDoc("arg"), IdDoc("x")], {}),
+ ([], {"x": LiteralDoc("x")}),
+ ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}),
+ ],
+)
+def test_expr_doc_call_with(args, kwargs):
+ target = IdDoc("x")
+
+ doc = target.call(*args, **kwargs)
+
+ assert doc.callee == target
+ assert list(doc.args) == args
+ assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index 55b5e88c88..b65eaa6b98 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -16,8 +16,19 @@
# under the License.
import pytest
+from tvm.script.printer.doc import (
+ CallDoc,
+ DictDoc,
+ IdDoc,
+ LambdaDoc,
+ ListDoc,
+ LiteralDoc,
+ OperationDoc,
+ OperationKind,
+ SliceDoc,
+ TupleDoc,
+)
from tvm.script.printer.doc_printer import to_python_script
-from tvm.script.printer.doc import LiteralDoc
def format_script(s: str) -> str:
@@ -28,7 +39,7 @@ def format_script(s: str) -> str:
non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()]
line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines]
spaces_to_remove = min(line_indents)
- return "\n".join(line[spaces_to_remove:] for line in s.splitlines())
+ return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n"
@pytest.mark.parametrize(
@@ -50,4 +61,367 @@ def format_script(s: str) -> str:
],
)
def test_print_literal_doc(doc, expected):
- assert to_python_script(doc).rstrip("\n") == format_script(expected)
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "name",
+ [
+ "test",
+ "_test",
+ "TestCase",
+ "test_case",
+ "test123",
+ ],
+)
+def test_print_id_doc(name):
+ doc = IdDoc(name)
+ assert to_python_script(doc) == format_script(name)
+
+
+@pytest.mark.parametrize(
+ "attr",
+ [
+ "attr",
+ "_attr",
+ "Attr",
+ "attr_1",
+ ],
+)
+def test_print_attr_doc(attr):
+ doc = IdDoc("x").attr(attr)
+ assert to_python_script(doc) == format_script(f"x.{attr}")
+
+
+@pytest.mark.parametrize(
+ "indices, expected",
+ [
+ (
+ (),
+ "[()]",
+ ),
+ (
+ (LiteralDoc(1),),
+ "[1]",
+ ),
+ (
+ (LiteralDoc(2), IdDoc("x")),
+ "[2, x]",
+ ),
+ (
+ (SliceDoc(LiteralDoc(1), LiteralDoc(2)),),
+ "[1:2]",
+ ),
+ (
+ (SliceDoc(LiteralDoc(1)), IdDoc("y")),
+ "[1:, y]",
+ ),
+ (
+ (SliceDoc(), IdDoc("y")),
+ "[:, y]",
+ ),
+ (
+ (IdDoc("x"), IdDoc("y"), IdDoc("z")),
+ "[x, y, z]",
+ ),
+ ],
+)
+def test_print_index_doc(indices, expected):
+ doc = IdDoc("x")[indices]
+ assert to_python_script(doc) == format_script(f"x{expected}")
+
+
+UNARY_OP_TOKENS = {
+ OperationKind.USub: "-",
+ OperationKind.Invert: "~",
+}
+
+
+@pytest.mark.parametrize(
+ "op_kind, expected_token",
+ list(UNARY_OP_TOKENS.items()),
+ ids=UNARY_OP_TOKENS.keys(),
+)
+def test_print_unary_operation_doc(op_kind, expected_token):
+ doc = OperationDoc(op_kind, [IdDoc("x")])
+ assert to_python_script(doc) == format_script(f"{expected_token}x")
+
+
+BINARY_OP_TOKENS = {
+ OperationKind.Add: "+",
+ OperationKind.Sub: "-",
+ OperationKind.Mult: "*",
+ OperationKind.Div: "/",
+ OperationKind.FloorDiv: "//",
+ OperationKind.Mod: "%",
+ OperationKind.Pow: "**",
+ OperationKind.LShift: "<<",
+ OperationKind.RShift: ">>",
+ OperationKind.BitAnd: "&",
+ OperationKind.BitOr: "|",
+ OperationKind.BitXor: "^",
+ OperationKind.Lt: "<",
+ OperationKind.LtE: "<=",
+ OperationKind.Eq: "==",
+ OperationKind.NotEq: "!=",
+ OperationKind.Gt: ">",
+ OperationKind.GtE: ">=",
+}
+
+
+@pytest.mark.parametrize(
+ "op_kind, expected_token",
+ list(BINARY_OP_TOKENS.items()),
+ ids=BINARY_OP_TOKENS.keys(),
+)
+def test_print_binary_operation_doc(op_kind, expected_token):
+ doc = OperationDoc(op_kind, [IdDoc("x"), IdDoc("y")])
+ assert to_python_script(doc) == format_script(f"x {expected_token} y")
+
+
+SPECIAL_OP_CASES = [
+ (
+ OperationKind.IfThenElse,
+ [LiteralDoc(True), LiteralDoc("true"), LiteralDoc("false")],
+ '"true" if True else "false"',
+ ),
+ (
+ OperationKind.IfThenElse,
+ [IdDoc("x"), LiteralDoc(None), LiteralDoc(1)],
+ "None if x else 1",
+ ),
+]
+
+
+@pytest.mark.parametrize(
+ "op_kind, operands, expected", SPECIAL_OP_CASES, ids=[kind for (kind, *_) in SPECIAL_OP_CASES]
+)
+def test_print_special_operation_doc(op_kind, operands, expected):
+ doc = OperationDoc(op_kind, operands)
+ assert to_python_script(doc) == format_script(expected)
+
+
+def test_operation_doc_test_exhaustive():
+ special_op_covered = {k for k, *_ in SPECIAL_OP_CASES}
+ for op_kind in OperationKind:
+ if OperationKind._UnaryStart < op_kind < OperationKind._UnaryEnd:
+ assert op_kind in UNARY_OP_TOKENS, (
+ f"{op_kind.name} not covered in test_print_unary_operation_doc. "
+ f"Please add the expected token to UNARY_OP_TOKENS"
+ )
+ elif OperationKind._BinaryStart < op_kind < OperationKind._BinaryEnd:
+ assert op_kind in BINARY_OP_TOKENS, (
+ f"{op_kind.name} not covered in test_print_binary_operation_doc. "
+ f"Please add the expected token to BINARY_OP_TOKENS"
+ )
+ elif not op_kind.name.startswith("_"):
+ # Special Op
+ assert op_kind in special_op_covered, (
+ f"{op_kind.name} not covered in test_print_special_operation_doc. "
+ f"Please add the test cases for it to SPECIAL_OP_CASES"
+ )
+
+
+@pytest.mark.parametrize(
+ "args, kwargs, expected",
+ [
+ (
+ (),
+ {},
+ "()",
+ ),
+ (
+ (),
+ {"key0": IdDoc("u")},
+ "(key0=u)",
+ ),
+ (
+ (),
+ {"key0": IdDoc("u"), "key1": IdDoc("v")},
+ "(key0=u, key1=v)",
+ ),
+ (
+ (IdDoc("x"),),
+ {},
+ "(x)",
+ ),
+ (
+ (IdDoc("x"),),
+ {"key0": IdDoc("u")},
+ "(x, key0=u)",
+ ),
+ (
+ (IdDoc("x"),),
+ {"key0": IdDoc("u"), "key1": IdDoc("v")},
+ "(x, key0=u, key1=v)",
+ ),
+ (
+ (IdDoc("x"), (IdDoc("y"))),
+ {},
+ "(x, y)",
+ ),
+ (
+ (IdDoc("x"), (IdDoc("y"))),
+ {"key0": IdDoc("u")},
+ "(x, y, key0=u)",
+ ),
+ (
+ (IdDoc("x"), (IdDoc("y"))),
+ {"key0": IdDoc("u"), "key1": IdDoc("v")},
+ "(x, y, key0=u, key1=v)",
+ ),
+ ],
+)
+def test_print_call_doc(args, kwargs, expected):
+ doc = CallDoc(IdDoc("f"), *args, **kwargs)
+ assert to_python_script(doc) == format_script(f"f{expected}")
+
+
+@pytest.mark.parametrize(
+ "args, expected",
+ [
+ (
+ (),
+ "lambda : 0",
+ ),
+ (
+ (IdDoc("x"),),
+ "lambda x: 0",
+ ),
+ (
+ (IdDoc("x"), IdDoc("y")),
+ "lambda x, y: 0",
+ ),
+ (
+ (IdDoc("x"), IdDoc("y"), IdDoc("z")),
+ "lambda x, y, z: 0",
+ ),
+ ],
+)
+def test_print_lambda_doc(args, expected):
+ doc = LambdaDoc(args, body=LiteralDoc(0))
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "elements, expected",
+ [
+ (
+ (),
+ "[]",
+ ),
+ (
+ [IdDoc("x")],
+ "[x]",
+ ),
+ (
+ [IdDoc("x"), IdDoc("y")],
+ "[x, y]",
+ ),
+ (
+ [IdDoc("x"), IdDoc("y"), IdDoc("z")],
+ "[x, y, z]",
+ ),
+ ],
+)
+def test_print_list_doc(elements, expected):
+ doc = ListDoc(elements)
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "elements, expected",
+ [
+ (
+ (),
+ "()",
+ ),
+ (
+ [IdDoc("x")],
+ "(x,)",
+ ),
+ (
+ [IdDoc("x"), IdDoc("y")],
+ "(x, y)",
+ ),
+ (
+ [IdDoc("x"), IdDoc("y"), IdDoc("z")],
+ "(x, y, z)",
+ ),
+ ],
+)
+def test_print_tuple_doc(elements, expected):
+ doc = TupleDoc(elements)
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "content, expected",
+ [
+ (
+ {},
+ "{}",
+ ),
+ (
+ {LiteralDoc("key_x"): IdDoc("x")},
+ '{"key_x": x}',
+ ),
+ (
+ {LiteralDoc("key_x"): IdDoc("x"), LiteralDoc("key_y"): IdDoc("y")},
+ '{"key_x": x, "key_y": y}',
+ ),
+ (
+ {
+ LiteralDoc("key_x"): IdDoc("x"),
+ LiteralDoc("key_y"): IdDoc("y"),
+ LiteralDoc("key_z"): IdDoc("z"),
+ },
+ '{"key_x": x, "key_y": y, "key_z": z}',
+ ),
+ ],
+)
+def test_print_dict_doc(content, expected):
+ doc = DictDoc(content)
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "slice_doc, expected",
+ [
+ (
+ SliceDoc(),
+ ":",
+ ),
+ (
+ SliceDoc(LiteralDoc(1)),
+ "1:",
+ ),
+ (
+ SliceDoc(None, LiteralDoc(2)),
+ ":2",
+ ),
+ (
+ SliceDoc(LiteralDoc(1), LiteralDoc(2)),
+ "1:2",
+ ),
+ (
+ SliceDoc(None, None, LiteralDoc(3)),
+ "::3",
+ ),
+ (
+ SliceDoc(LiteralDoc(1), None, LiteralDoc(3)),
+ "1::3",
+ ),
+ (
+ SliceDoc(None, LiteralDoc(2), LiteralDoc(3)),
+ ":2:3",
+ ),
+ (
+ SliceDoc(LiteralDoc(1), LiteralDoc(2), LiteralDoc(3)),
+ "1:2:3",
+ ),
+ ],
+)
+def test_print_slice_doc(slice_doc, expected):
+ doc = IdDoc("x")[slice_doc]
+ assert to_python_script(doc) == format_script(f"x[{expected}]")