You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/27 23:56:20 UTC
[tvm] branch main updated: [TVMScript] StmtDoc Definitions (#12111)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fcec5f4a76 [TVMScript] StmtDoc Definitions (#12111)
fcec5f4a76 is described below
commit fcec5f4a763280dbb42f027654fd47088d0b3c5a
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Wed Jul 27 19:56:15 2022 -0400
[TVMScript] StmtDoc Definitions (#12111)
This PR adds:
- All StmtDoc subclasses
- Python bindings for StmtDoc
Tracking issue: https://github.com/apache/tvm/issues/11912
---
include/tvm/script/printer/doc.h | 506 +++++++++++++++++++++
python/tvm/script/printer/_ffi_api.py | 2 +-
python/tvm/script/printer/doc.py | 248 ++++++++--
python/tvm/script/printer/doc_printer.py | 5 +-
src/script/printer/doc.cc | 170 +++++++
.../python/unittest/test_tvmscript_printer_doc.py | 260 +++++++++++
6 files changed, 1154 insertions(+), 37 deletions(-)
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index f3f980e53f..e3dd83743e 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -119,6 +119,79 @@ class ExprDoc : public Doc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
};
+/*!
+ * \brief The base class of all statement docs.
+ *
+ * \sa StmtDoc
+ */
+class StmtDocNode : public DocNode {
+ public:
+ /*!
+ * \brief The optional comment attached to this statement doc.
+ *
+ * The actual position of the comment depends on the type of Doc
+ * and also the DocPrinter implementation. It could be on the same
+ * line as the statement, or the line above, or inside the statement
+ * if it spans over multiple lines.
+ */
+ mutable Optional<String> comment{NullOpt}; // NOTE(review): presumably mutable so it can be set on an existing node (e.g. via StmtDocSetComment) — confirm
+
+ void VisitAttrs(AttrVisitor* v) {
+ DocNode::VisitAttrs(v);
+ v->Visit("comment", &comment);
+ }
+
+ static constexpr const char* _type_key = "script.printer.StmtDoc";
+ TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of StmtDocNode; not constructed directly (protected ctor), use a subclass.
+ *
+ * \sa StmtDocNode
+ */
+class StmtDoc : public Doc {
+ protected:
+ StmtDoc() = default;
+
+ public:
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtDoc, Doc, StmtDocNode);
+};
+
+/*!
+ * \brief The container doc that holds a list of StmtDoc.
+ * \note `StmtBlockDoc` is never used in the IR; it is only a temporary container for holding a
+ * list of StmtDoc.
+ * \sa StmtBlockDoc
+ */
+class StmtBlockDocNode : public DocNode { // a DocNode, not a StmtDocNode, so it has no comment field
+ public:
+ /*! \brief The statements held by this container, in order. */
+ Array<StmtDoc> stmts;
+
+ void VisitAttrs(AttrVisitor* v) {
+ DocNode::VisitAttrs(v);
+ v->Visit("stmts", &stmts);
+ }
+
+ static constexpr const char* _type_key = "script.printer.StmtBlockDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of StmtBlockDocNode.
+ * \sa StmtBlockDocNode
+ */
+class StmtBlockDoc : public Doc {
+ public:
+ /*!
+ * \brief Constructor of StmtBlockDoc.
+ * \param stmts The list of statements.
+ */
+ explicit StmtBlockDoc(Array<StmtDoc> stmts);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtBlockDoc, Doc, StmtBlockDocNode);
+};
+
/*!
* \brief Doc that represents literal value.
*
@@ -219,6 +292,7 @@ class IdDoc : public ExprDoc {
* \param name The name of identifier.
*/
explicit IdDoc(String name);
+ explicit IdDoc(std::nullptr_t) : ExprDoc(nullptr) {} // null reference, used for member defaults such as `IdDoc name{nullptr}`
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode);
};
@@ -640,6 +714,438 @@ class SliceDoc : public Doc {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode);
};
+/*!
+ * \brief Doc that represents an assign statement, e.g. `lhs: annotation = rhs`.
+ *
+ * \sa AssignDoc
+ */
+class AssignDocNode : public StmtDocNode {
+ public:
+ /*! \brief The left hand side of the assignment. */
+ ExprDoc lhs{nullptr};
+ /*!
+ * \brief The right hand side of the assignment.
+ *
+ * If null, this doc represents a declaration, e.g. `A: T.Buffer[(1,2)]`.
+ */
+ Optional<ExprDoc> rhs;
+ /*! \brief The type annotation; the AssignDoc constructor only accepts it when lhs is an IdDoc. */
+ Optional<ExprDoc> annotation;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("lhs", &lhs);
+ v->Visit("rhs", &rhs);
+ v->Visit("annotation", &annotation);
+ }
+
+ static constexpr const char* _type_key = "script.printer.AssignDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of AssignDocNode.
+ *
+ * \sa AssignDocNode
+ */
+class AssignDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of AssignDoc; rhs and annotation must not both be NullOpt.
+ * \param lhs The left hand side of the assignment.
+ * \param rhs The right hand side of the assignment, or NullOpt for a declaration.
+ * \param annotation The type annotation; may only be set when lhs is an IdDoc.
+ */
+ explicit AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssignDoc, StmtDoc, AssignDocNode);
+};
+
+/*!
+ * \brief Doc that represents an if-then-else statement.
+ *
+ * \sa IfDoc
+ */
+class IfDocNode : public StmtDocNode {
+ public:
+ /*! \brief The predicate of the if-then-else statement. */
+ ExprDoc predicate{nullptr};
+ /*! \brief The then branch of the if-then-else statement. */
+ Array<StmtDoc> then_branch;
+ /*! \brief The else branch of the if-then-else statement. */
+ Array<StmtDoc> else_branch;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("predicate", &predicate);
+ v->Visit("then_branch", &then_branch);
+ v->Visit("else_branch", &else_branch);
+ }
+
+ static constexpr const char* _type_key = "script.printer.IfDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of IfDocNode.
+ *
+ * \sa IfDocNode
+ */
+class IfDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of IfDoc; then_branch and else_branch must not both be empty.
+ * \param predicate The predicate of the if-then-else statement.
+ * \param then_branch The then branch of the if-then-else statement.
+ * \param else_branch The else branch of the if-then-else statement.
+ */
+ explicit IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IfDoc, StmtDoc, IfDocNode);
+};
+
+/*!
+ * \brief Doc that represents a while statement: `while predicate: body...`.
+ *
+ * \sa WhileDoc
+ */
+class WhileDocNode : public StmtDocNode {
+ public:
+ /*! \brief The predicate of the while statement. */
+ ExprDoc predicate{nullptr};
+ /*! \brief The body of the while statement. */
+ Array<StmtDoc> body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("predicate", &predicate);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.WhileDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of WhileDocNode.
+ *
+ * \sa WhileDocNode
+ */
+class WhileDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of WhileDoc.
+ * \param predicate The predicate of the while statement.
+ * \param body The body of the while statement.
+ */
+ explicit WhileDoc(ExprDoc predicate, Array<StmtDoc> body);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WhileDoc, StmtDoc, WhileDocNode);
+};
+
+/*!
+ * \brief Doc that represents a for statement.
+ *
+ * Example:
+ * for 'lhs' in 'rhs':
+ * 'body...'
+ *
+ * \sa ForDoc
+ */
+class ForDocNode : public StmtDocNode {
+ public:
+ /*! \brief The loop target (iterating variable), i.e. `lhs` in `for lhs in rhs`. */
+ ExprDoc lhs{nullptr};
+ /*! \brief The iterable being looped over, i.e. `rhs` in `for lhs in rhs`. */
+ ExprDoc rhs{nullptr};
+ /*! \brief The body of the for statement. */
+ Array<StmtDoc> body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("lhs", &lhs);
+ v->Visit("rhs", &rhs);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ForDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ForDocNode.
+ *
+ * \sa ForDocNode
+ */
+class ForDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of ForDoc.
+ * \param lhs The loop target (iterating variable).
+ * \param rhs The iterable being looped over.
+ * \param body The body of the for statement.
+ */
+ explicit ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ForDoc, StmtDoc, ForDocNode);
+};
+
+/*!
+ * \brief Doc that represents special scopes.
+ *
+ * Specifically, this means the with statement in Python:
+ *
+ * with 'rhs' as 'lhs':
+ * 'body...'
+ *
+ * \sa ScopeDoc
+ */
+class ScopeDocNode : public StmtDocNode {
+ public:
+ /*! \brief The `as` target of the with statement; NullOpt when there is none. */
+ Optional<ExprDoc> lhs{NullOpt};
+ /*! \brief The expression after `with`, i.e. `rhs` in `with rhs as lhs`. */
+ ExprDoc rhs{nullptr};
+ /*! \brief The body of the scope doc. */
+ Array<StmtDoc> body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("lhs", &lhs);
+ v->Visit("rhs", &rhs);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ScopeDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ScopeDocNode.
+ *
+ * \sa ScopeDocNode
+ */
+class ScopeDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of ScopeDoc.
+ * \param lhs The `as` target of the with statement, or NullOpt for none.
+ * \param rhs The expression after `with`.
+ * \param body The body of the scope doc.
+ */
+ explicit ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body);
+
+ /*!
+ * \brief Constructor of ScopeDoc without an `as` target (lhs is set to NullOpt).
+ * \param rhs The expression after `with`.
+ * \param body The body of the scope doc.
+ */
+ explicit ScopeDoc(ExprDoc rhs, Array<StmtDoc> body);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ScopeDoc, StmtDoc, ScopeDocNode);
+};
+
+/*!
+ * \brief Doc that represents an expression used as a statement.
+ *
+ * \sa ExprStmtDoc
+ */
+class ExprStmtDocNode : public StmtDocNode {
+ public:
+ /*! \brief The expression represented by this doc. */
+ ExprDoc expr{nullptr};
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("expr", &expr);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ExprStmtDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ExprStmtDocNode.
+ *
+ * \sa ExprStmtDocNode
+ */
+class ExprStmtDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of ExprStmtDoc.
+ * \param expr The expression represented by this doc.
+ */
+ explicit ExprStmtDoc(ExprDoc expr);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprStmtDoc, StmtDoc, ExprStmtDocNode);
+};
+
+/*!
+ * \brief Doc that represents an assert statement: `assert test, msg`.
+ *
+ * \sa AssertDoc
+ */
+class AssertDocNode : public StmtDocNode {
+ public:
+ /*! \brief The expression to test. */
+ ExprDoc test{nullptr};
+ /*! \brief The optional error message when the assertion fails. */
+ Optional<ExprDoc> msg{NullOpt};
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("test", &test);
+ v->Visit("msg", &msg);
+ }
+
+ static constexpr const char* _type_key = "script.printer.AssertDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of AssertDocNode.
+ *
+ * \sa AssertDocNode
+ */
+class AssertDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of AssertDoc.
+ * \param test The expression to test.
+ * \param msg The optional error message when the assertion fails.
+ */
+ explicit AssertDoc(ExprDoc test, Optional<ExprDoc> msg = NullOpt);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode);
+};
+
+/*!
+ * \brief Doc that represents a return statement: `return value`.
+ *
+ * \sa ReturnDoc
+ */
+class ReturnDocNode : public StmtDocNode {
+ public:
+ /*! \brief The value to return. */
+ ExprDoc value{nullptr};
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("value", &value);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ReturnDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ReturnDocNode.
+ *
+ * \sa ReturnDocNode
+ */
+class ReturnDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of ReturnDoc.
+ * \param value The value to return.
+ */
+ explicit ReturnDoc(ExprDoc value);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReturnDoc, StmtDoc, ReturnDocNode);
+};
+
+/*!
+ * \brief Doc that represents a function definition.
+ *
+ * \sa FunctionDoc
+ */
+class FunctionDocNode : public StmtDocNode {
+ public:
+ /*! \brief The name of function. */
+ IdDoc name{nullptr};
+ /*!
+ * \brief The arguments of function.
+ *
+ * The `lhs` means argument name,
+ * `annotation` means argument type,
+ * and `rhs` means default value.
+ */
+ Array<AssignDoc> args;
+ /*! \brief Decorators of function. */
+ Array<ExprDoc> decorators;
+ /*! \brief The return type of function. */
+ ExprDoc return_type{nullptr};
+ /*! \brief The body of function. */
+ Array<StmtDoc> body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("args", &args);
+ v->Visit("decorators", &decorators);
+ v->Visit("return_type", &return_type);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.FunctionDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of FunctionDocNode.
+ *
+ * \sa FunctionDocNode
+ */
+class FunctionDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of FunctionDoc.
+ * \param name The name of function.
+ * \param args The arguments of function.
+ * \param decorators The decorators of function.
+ * \param return_type The return type of function.
+ * \param body The body of function.
+ */
+ explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
+ ExprDoc return_type, Array<StmtDoc> body);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode);
+};
+
+/*!
+ * \brief Doc that represents a class definition.
+ *
+ * \sa ClassDoc
+ */
+class ClassDocNode : public StmtDocNode {
+ public:
+ /*! \brief The name of class. */
+ IdDoc name{nullptr};
+ /*! \brief Decorators of class. */
+ Array<ExprDoc> decorators;
+ /*! \brief The body of class. */
+ Array<StmtDoc> body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ StmtDocNode::VisitAttrs(v);
+ v->Visit("name", &name);
+ v->Visit("decorators", &decorators);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "script.printer.ClassDoc";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ClassDocNode.
+ *
+ * \sa ClassDocNode
+ */
+class ClassDoc : public StmtDoc {
+ public:
+ /*!
+ * \brief Constructor of ClassDoc.
+ * \param name The name of class.
+ * \param decorators The decorators of class.
+ * \param body The body of class.
+ */
+ explicit ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body);
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
+};
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py
index baa639fe2d..944ecad01e 100644
--- a/python/tvm/script/printer/_ffi_api.py
+++ b/python/tvm/script/printer/_ffi_api.py
@@ -17,4 +17,4 @@
"""FFI APIs for tvm.script.printer"""
import tvm._ffi
-tvm._ffi._init_api("script.printer", __name__)
+tvm._ffi._init_api("script.printer", __name__) # pylint: disable=protected-access
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index acdb63dcf2..747ffc42f1 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -16,11 +16,10 @@
# under the License.
"""Doc types for TVMScript Unified Printer"""
-from typing import List, Dict, Tuple, Optional, Union, Sequence
from enum import IntEnum, unique
+from typing import Dict, List, Optional, Sequence, Tuple, Union
-import tvm._ffi
-import tvm.ir.container
+from tvm._ffi import register_object
from tvm.runtime import Object
from tvm.tir import FloatImm, IntImm
@@ -47,7 +46,7 @@ class ExprDoc(Object):
-------
doc : AttrAccessDoc
"""
- return _ffi_api.ExprDocAttr(self, name) # type: ignore
+ return _ffi_api.ExprDocAttr(self, name) # type: ignore # pylint: disable=no-member
def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc":
"""
@@ -66,7 +65,7 @@ class ExprDoc(Object):
"""
kwargs_keys = list(kwargs.keys())
kwargs_values = list(kwargs.values())
- return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore
+ return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore # pylint: disable=no-member
_IndexType = Union["ExprDoc", "SliceDoc"]
@@ -85,7 +84,7 @@ class ExprDoc(Object):
"""
if not isinstance(indices, tuple):
indices = (indices,)
- return _ffi_api.ExprDocIndex(self, indices) # type: ignore
+ return _ffi_api.ExprDocIndex(self, indices) # type: ignore # pylint: disable=no-member
def __iter__(self):
"""
@@ -100,7 +99,34 @@ class ExprDoc(Object):
raise RuntimeError(f"{self.__class__} cannot be used as iterable.")
-@tvm._ffi.register_object("script.printer.LiteralDoc")
+class StmtDoc(Doc):
+ """Base class of statement doc"""
+
+ @property
+ def comment(self) -> Optional[str]:
+ # It has to call the dunder method to avoid infinite recursion
+ return self.__getattr__("comment") # pylint: disable=unnecessary-dunder-call
+
+ @comment.setter
+ def comment(self, value):
+ return _ffi_api.StmtDocSetComment(self, value) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.StmtBlockDoc")
+class StmtBlockDoc(Doc):
+ """The container doc that holds a list of StmtDoc.
+
+ Note: `StmtBlockDoc` is never used in the IR; it is only a temporary container for holding a
+ list of StmtDoc.
+ """
+
+ stmts: Sequence[StmtDoc]
+
+ def __init__(self, stmts: List[StmtDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.StmtBlockDoc, stmts) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.LiteralDoc")
class LiteralDoc(ExprDoc):
"""Doc that represents literal value"""
@@ -108,30 +134,30 @@ class LiteralDoc(ExprDoc):
def __init__(self, value: Union[str, float, bool, int, None]):
if value is None:
- self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore # pylint: disable=no-member
elif isinstance(value, str):
- self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore # pylint: disable=no-member
elif isinstance(value, float):
- self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore # pylint: disable=no-member
elif isinstance(value, bool):
- self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore # pylint: disable=no-member
elif isinstance(value, int):
- self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore # pylint: disable=no-member
else:
raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")
-@tvm._ffi.register_object("script.printer.IdDoc")
+@register_object("script.printer.IdDoc")
class IdDoc(ExprDoc):
"""Doc that represents identifier"""
name: str
def __init__(self, name: str):
- self.__init_handle_by_constructor__(_ffi_api.IdDoc, name) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.IdDoc, name) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.AttrAccessDoc")
+@register_object("script.printer.AttrAccessDoc")
class AttrAccessDoc(ExprDoc):
"""Doc that represents attribute access on an expression"""
@@ -139,10 +165,10 @@ class AttrAccessDoc(ExprDoc):
name: str
def __init__(self, value: ExprDoc, name: str):
- self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.IndexDoc")
+@register_object("script.printer.IndexDoc")
class IndexDoc(ExprDoc):
"""Doc that represents index access on an expression"""
@@ -150,10 +176,10 @@ class IndexDoc(ExprDoc):
indices: Sequence[Union[ExprDoc, "SliceDoc"]]
def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]):
- self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.CallDoc")
+@register_object("script.printer.CallDoc")
class CallDoc(ExprDoc):
"""Doc that represents function call"""
@@ -166,14 +192,18 @@ class CallDoc(ExprDoc):
kwargs_keys = list(kwargs.keys())
kwargs_values = list(kwargs.values())
self.__init_handle_by_constructor__(
- _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore
+ _ffi_api.CallDoc, # type: ignore # pylint: disable=no-member
+ callee,
+ args,
+ kwargs_keys,
+ kwargs_values,
)
@unique
class OperationKind(IntEnum):
"""
- This enum represents the kind of operation (operator) in OpeartionDoc
+ This enum represents the kind of operation (operator) in OperationDoc
It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h
"""
@@ -214,7 +244,7 @@ class OperationKind(IntEnum):
# pylint: enable=invalid-name
-@tvm._ffi.register_object("script.printer.OperationDoc")
+@register_object("script.printer.OperationDoc")
class OperationDoc(ExprDoc):
"""
Doc that represents operation
@@ -227,10 +257,10 @@ class OperationDoc(ExprDoc):
operands: Sequence[ExprDoc]
def __init__(self, kind: OperationKind, operands: List[ExprDoc]):
- self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.LambdaDoc")
+@register_object("script.printer.LambdaDoc")
class LambdaDoc(ExprDoc):
"""Doc that represents lambda function"""
@@ -238,30 +268,30 @@ class LambdaDoc(ExprDoc):
body: ExprDoc
def __init__(self, args: List[IdDoc], body: ExprDoc):
- self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.TupleDoc")
+@register_object("script.printer.TupleDoc")
class TupleDoc(ExprDoc):
"""Doc that represents tuple literal"""
elements: Sequence[ExprDoc]
def __init__(self, elements: List[ExprDoc]):
- self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.ListDoc")
+@register_object("script.printer.ListDoc")
class ListDoc(ExprDoc):
"""Doc that represents list literal"""
elements: Sequence[ExprDoc]
def __init__(self, elements: List[ExprDoc]):
- self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.DictDoc")
+@register_object("script.printer.DictDoc")
class DictDoc(ExprDoc):
"""Doc that represents dict literal"""
@@ -271,10 +301,10 @@ class DictDoc(ExprDoc):
def __init__(self, content: Dict[ExprDoc, ExprDoc]):
keys = list(content.keys())
values = list(content.values())
- self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values) # type: ignore # pylint: disable=no-member
-@tvm._ffi.register_object("script.printer.SliceDoc")
+@register_object("script.printer.SliceDoc")
class SliceDoc(ExprDoc):
"""
Doc that represents slice in Index expression
@@ -292,4 +322,156 @@ class SliceDoc(ExprDoc):
stop: Optional[ExprDoc] = None,
step: Optional[ExprDoc] = None,
):
- self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore
+ self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.AssignDoc")
+class AssignDoc(StmtDoc):
+ """Doc that represents an assign statement, or a declaration when rhs is None."""
+
+ lhs: ExprDoc
+ rhs: Optional[ExprDoc] # None means a declaration such as `A: T.Buffer[(1,2)]`
+ annotation: Optional[ExprDoc] # only allowed when lhs is an IdDoc
+
+ def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[ExprDoc] = None):
+ self.__init_handle_by_constructor__(
+ _ffi_api.AssignDoc, # type: ignore # pylint: disable=no-member
+ lhs,
+ rhs,
+ annotation,
+ )
+
+
+@register_object("script.printer.IfDoc")
+class IfDoc(StmtDoc):
+ """Doc that represents an if-then-else statement."""
+
+ predicate: ExprDoc
+ then_branch: Sequence[StmtDoc]
+ else_branch: Sequence[StmtDoc] # may be empty, but not at the same time as then_branch
+
+ def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: List[StmtDoc]):
+ self.__init_handle_by_constructor__(
+ _ffi_api.IfDoc, # type: ignore # pylint: disable=no-member
+ predicate,
+ then_branch,
+ else_branch,
+ )
+
+
+@register_object("script.printer.WhileDoc")
+class WhileDoc(StmtDoc):
+ """Doc that represents a while statement: `while predicate: body...`."""
+
+ predicate: ExprDoc
+ body: Sequence[StmtDoc]
+
+ def __init__(self, predicate: ExprDoc, body: List[StmtDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.WhileDoc, predicate, body) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ForDoc")
+class ForDoc(StmtDoc):
+ """Doc that represents a for statement: `for lhs in rhs: body...`."""
+
+ lhs: ExprDoc # the loop target (iterating variable)
+ rhs: ExprDoc # the iterable being looped over
+ body: Sequence[StmtDoc]
+
+ def __init__(self, lhs: ExprDoc, rhs: ExprDoc, body: List[StmtDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.ForDoc, lhs, rhs, body) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ScopeDoc")
+class ScopeDoc(StmtDoc):
+ """
+ Doc that represents special scopes.
+
+ Specifically, this means the with statement in Python:
+
+ with <rhs> as <lhs>:
+ <body...>
+ """
+
+ lhs: Optional[ExprDoc] # None when the with statement has no `as` target
+ rhs: ExprDoc
+ body: Sequence[StmtDoc]
+
+ def __init__(self, lhs: Optional[ExprDoc], rhs: ExprDoc, body: List[StmtDoc]):
+ self.__init_handle_by_constructor__(_ffi_api.ScopeDoc, lhs, rhs, body) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ExprStmtDoc")
+class ExprStmtDoc(StmtDoc):
+ """Doc that represents an expression used as a statement."""
+
+ expr: ExprDoc
+
+ def __init__(self, expr: ExprDoc):
+ self.__init_handle_by_constructor__(_ffi_api.ExprStmtDoc, expr) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.AssertDoc")
+class AssertDoc(StmtDoc):
+ """Doc that represents an assert statement: `assert test, msg`."""
+
+ test: ExprDoc
+ msg: Optional[ExprDoc] # optional error message when the assertion fails
+
+ def __init__(self, test: ExprDoc, msg: Optional[ExprDoc] = None):
+ self.__init_handle_by_constructor__(_ffi_api.AssertDoc, test, msg) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ReturnDoc")
+class ReturnDoc(StmtDoc):
+ """Doc that represents a return statement: `return value`."""
+
+ value: ExprDoc
+
+ def __init__(self, value: ExprDoc):
+ self.__init_handle_by_constructor__(_ffi_api.ReturnDoc, value) # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.FunctionDoc")
+class FunctionDoc(StmtDoc):
+ """Doc that represents a function definition."""
+
+ name: IdDoc
+ args: Sequence[AssignDoc] # lhs = argument name, annotation = its type, rhs = its default value
+ decorators: Sequence[ExprDoc]
+ return_type: ExprDoc
+ body: Sequence[StmtDoc]
+
+ def __init__(
+ self,
+ name: IdDoc,
+ args: List[AssignDoc],
+ decorators: List[ExprDoc],
+ return_type: ExprDoc,
+ body: List[StmtDoc],
+ ):
+ self.__init_handle_by_constructor__(
+ _ffi_api.FunctionDoc, # type: ignore # pylint: disable=no-member
+ name,
+ args,
+ decorators,
+ return_type,
+ body,
+ )
+
+
+@register_object("script.printer.ClassDoc")
+class ClassDoc(StmtDoc):
+ """Doc that represents a class definition."""
+
+ name: IdDoc
+ decorators: Sequence[ExprDoc]
+ body: Sequence[StmtDoc]
+
+ def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]):
+ self.__init_handle_by_constructor__(
+ _ffi_api.ClassDoc, # type: ignore # pylint: disable=no-member
+ name,
+ decorators,
+ body,
+ )
diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py
index 404632b44c..1cb56ecbf7 100644
--- a/python/tvm/script/printer/doc_printer.py
+++ b/python/tvm/script/printer/doc_printer.py
@@ -21,8 +21,7 @@ from .doc import Doc
def to_python_script(doc: Doc, indent_spaces: int = 4) -> str:
- """
- Convert Doc into Python script.
+ """Convert Doc into Python script.
Parameters
----------
@@ -36,4 +35,4 @@ def to_python_script(doc: Doc, indent_spaces: int = 4) -> str:
script : str
The text representation of Doc in Python syntax
"""
- return _ffi_api.DocToPythonScript(doc, indent_spaces) # type: ignore
+ return _ffi_api.DocToPythonScript(doc, indent_spaces) # type: ignore # pylint: disable=no-member
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index ed81f9d2dd..bfff0cfad4 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/script/printer/doc.h>
@@ -38,6 +40,12 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
}
+StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
+ ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
+ n->stmts = std::move(stmts); // by-value sink parameter: move instead of an extra refcount inc/dec
+ this->data_ = std::move(n);
+}
+
LiteralDoc::LiteralDoc(ObjectRef value) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
@@ -115,6 +123,99 @@ SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<Exp
this->data_ = std::move(n);
}
+/*!
+ * \brief Build an assignment doc: `lhs = rhs`, `lhs: annotation`, or `lhs: annotation = rhs`.
+ *
+ * Both error messages name AssignDoc explicitly so that callers (and the Python tests,
+ * which match on "AssignDoc") can tell which constructor rejected the arguments without
+ * relying on a C++ backtrace being present in the error text.
+ */
+AssignDoc::AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
+  // A bare `lhs` with neither a value nor an annotation is not an assignment.
+  CHECK(rhs.defined() || annotation.defined())
+      << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc.";
+  // An annotation is only accepted when lhs is an IdDoc (e.g. not a tuple target).
+  CHECK(lhs->IsInstance<IdDocNode>() || !annotation.defined())
+      << "ValueError: annotation can only be non-null for AssignDoc if lhs is an identifier.";
+
+  ObjectPtr<AssignDocNode> n = make_object<AssignDocNode>();
+  n->lhs = lhs;
+  n->rhs = rhs;
+  n->annotation = annotation;
+  this->data_ = std::move(n);
+}
+
+/*!
+ * \brief Build an if statement doc from a predicate and its two branches.
+ *
+ * Either branch may be empty, but not both. The error message names IfDoc so the
+ * failing constructor is identifiable from the message alone (the Python test matches
+ * on "IfDoc").
+ */
+IfDoc::IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) {
+  CHECK(!then_branch.empty() || !else_branch.empty())
+      << "ValueError: At least one of the then branch or else branch needs to be non-empty "
+         "for IfDoc.";
+
+  ObjectPtr<IfDocNode> n = make_object<IfDocNode>();
+  n->predicate = predicate;
+  n->then_branch = then_branch;
+  n->else_branch = else_branch;
+  this->data_ = std::move(n);
+}
+
+WhileDoc::WhileDoc(ExprDoc predicate, Array<StmtDoc> body) {
+  // Store the loop condition and body as given; an empty body is not rejected here.
+  auto node = make_object<WhileDocNode>();
+  node->body = body;
+  node->predicate = predicate;
+  data_ = std::move(node);
+}
+
+ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) {
+  // Copy lhs, rhs and body into a new ForDocNode unchanged; no validation is performed.
+  auto node = make_object<ForDocNode>();
+  node->body = body;
+  node->rhs = rhs;
+  node->lhs = lhs;
+  data_ = std::move(node);
+}
+
+ScopeDoc::ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) {
+  // lhs is optional; rhs and body are stored as given.
+  auto node = make_object<ScopeDocNode>();
+  node->body = body;
+  node->rhs = rhs;
+  node->lhs = lhs;
+  data_ = std::move(node);
+}
+
+// Convenience overload without lhs: forwards to the three-argument constructor with
+// lhs left null, instead of duplicating the node setup.
+ScopeDoc::ScopeDoc(ExprDoc rhs, Array<StmtDoc> body) : ScopeDoc(NullOpt, rhs, body) {}
+
+ExprStmtDoc::ExprStmtDoc(ExprDoc expr) {
+  // A statement consisting of a single expression.
+  auto node = make_object<ExprStmtDocNode>();
+  node->expr = expr;
+  data_ = std::move(node);
+}
+
+AssertDoc::AssertDoc(ExprDoc test, Optional<ExprDoc> msg) {
+  // msg may be null, meaning the assertion carries no message.
+  auto node = make_object<AssertDocNode>();
+  node->msg = msg;
+  node->test = test;
+  data_ = std::move(node);
+}
+
+ReturnDoc::ReturnDoc(ExprDoc value) {
+  // Store the returned expression in a new ReturnDocNode.
+  auto node = make_object<ReturnDocNode>();
+  node->value = value;
+  data_ = std::move(node);
+}
+
+FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
+                         ExprDoc return_type, Array<StmtDoc> body) {
+  // Parameters are AssignDocs so each one can carry an annotation and/or a default value.
+  auto node = make_object<FunctionDocNode>();
+  node->name = name;
+  node->decorators = decorators;
+  node->args = args;
+  node->return_type = return_type;
+  node->body = body;
+  data_ = std::move(node);
+}
+
+ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
+  // Copy the class name, decorators and body statements into a new ClassDocNode.
+  auto node = make_object<ClassDocNode>();
+  node->body = body;
+  node->decorators = decorators;
+  node->name = name;
+  data_ = std::move(node);
+}
+
TVM_REGISTER_NODE_TYPE(DocNode);
TVM_REGISTER_NODE_TYPE(ExprDocNode);
@@ -125,6 +226,15 @@ TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
.set_body_method<ExprDoc, ExprDocNode, ExprDoc, Array<ExprDoc>, Array<String>, Array<ExprDoc>>(
&ExprDocNode::Call);
+TVM_REGISTER_NODE_TYPE(StmtDocNode);
+// `comment` is declared mutable on StmtDocNode, so it can be assigned through the
+// const node reached via the StmtDoc reference.
+TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment")
+    .set_body_typed([](StmtDoc stmt, Optional<String> new_comment) -> void {
+      stmt->comment = new_comment;
+    });
+
+TVM_REGISTER_NODE_TYPE(StmtBlockDocNode);
+TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc")
+    .set_body_typed([](Array<StmtDoc> stmt_list) -> StmtBlockDoc {
+      return StmtBlockDoc(stmt_list);
+    });
+
TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
@@ -185,6 +295,66 @@ TVM_REGISTER_GLOBAL("script.printer.SliceDoc")
.set_body_typed([](Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) {
return SliceDoc(start, stop, step);
});
+
+// FFI factories for the statement docs; each forwards its arguments to the constructor.
+TVM_REGISTER_NODE_TYPE(AssignDocNode);
+TVM_REGISTER_GLOBAL("script.printer.AssignDoc")
+    .set_body_typed([](ExprDoc target, Optional<ExprDoc> value,
+                       Optional<ExprDoc> type_annotation) -> AssignDoc {
+      return AssignDoc(target, value, type_annotation);
+    });
+
+TVM_REGISTER_NODE_TYPE(IfDocNode);
+TVM_REGISTER_GLOBAL("script.printer.IfDoc")
+    .set_body_typed([](ExprDoc cond, Array<StmtDoc> then_stmts,
+                       Array<StmtDoc> else_stmts) -> IfDoc {
+      return IfDoc(cond, then_stmts, else_stmts);
+    });
+
+TVM_REGISTER_NODE_TYPE(WhileDocNode);
+TVM_REGISTER_GLOBAL("script.printer.WhileDoc")
+    .set_body_typed([](ExprDoc cond, Array<StmtDoc> loop_body) -> WhileDoc {
+      return WhileDoc(cond, loop_body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ForDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ForDoc")
+    .set_body_typed([](ExprDoc target, ExprDoc iterable, Array<StmtDoc> loop_body) -> ForDoc {
+      return ForDoc(target, iterable, loop_body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ScopeDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ScopeDoc")
+    .set_body_typed([](Optional<ExprDoc> bound, ExprDoc scope_expr,
+                       Array<StmtDoc> scope_body) -> ScopeDoc {
+      return ScopeDoc(bound, scope_expr, scope_body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ExprStmtDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc")
+    .set_body_typed([](ExprDoc wrapped) -> ExprStmtDoc { return ExprStmtDoc(wrapped); });
+
+TVM_REGISTER_NODE_TYPE(AssertDocNode);
+// No default argument for `msg`: set_body_typed fixes the arity from the lambda's type and
+// ignores C++ default arguments. FFI callers therefore always pass `msg`, possibly null,
+// and a `= NullOpt` here would only suggest an optional parameter that does not exist.
+TVM_REGISTER_GLOBAL("script.printer.AssertDoc")
+    .set_body_typed([](ExprDoc test, Optional<ExprDoc> msg) { return AssertDoc(test, msg); });
+
+TVM_REGISTER_NODE_TYPE(ReturnDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ReturnDoc")
+    .set_body_typed([](ExprDoc returned) -> ReturnDoc { return ReturnDoc(returned); });
+
+TVM_REGISTER_NODE_TYPE(FunctionDocNode);
+TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
+    .set_body_typed([](IdDoc fn_name, Array<AssignDoc> params, Array<ExprDoc> fn_decorators,
+                       ExprDoc ret_type, Array<StmtDoc> fn_body) -> FunctionDoc {
+      return FunctionDoc(fn_name, params, fn_decorators, ret_type, fn_body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ClassDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
+    .set_body_typed([](IdDoc cls_name, Array<ExprDoc> cls_decorators,
+                       Array<StmtDoc> cls_body) -> ClassDoc {
+      return ClassDoc(cls_name, cls_decorators, cls_body);
+    });
+
} // namespace printer
} // namespace script
} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index 4ff6a0f547..040a829010 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -14,6 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""
+In this test file, we want to make sure the Python code can construct
+Doc objects, then access and modify their attributes correctly.
+"""
+
import pytest
from tvm.script.printer.doc import (
@@ -29,6 +34,17 @@ from tvm.script.printer.doc import (
ListDoc,
DictDoc,
SliceDoc,
+ StmtBlockDoc,
+ AssignDoc,
+ IfDoc,
+ WhileDoc,
+ ForDoc,
+ ScopeDoc,
+ ExprStmtDoc,
+ AssertDoc,
+ ReturnDoc,
+ FunctionDoc,
+ ClassDoc,
)
@@ -244,3 +260,247 @@ def test_expr_doc_call_with(args, kwargs):
assert doc.callee == target
assert list(doc.args) == args
assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs
+
+
+@pytest.mark.parametrize(
+ "stmts",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_stmt_block_doc(stmts):
+ doc = StmtBlockDoc(stmts)
+
+ assert list(doc.stmts) == stmts
+
+
+@pytest.mark.parametrize(
+ "lhs, rhs, annotation",
+ [
+ (IdDoc("x"), IdDoc("y"), None),
+ (IdDoc("x"), None, IdDoc("int")),
+ (IdDoc("x"), IdDoc("y"), IdDoc("int")),
+ ],
+)
+def test_assign_doc(lhs, rhs, annotation):
+ doc = AssignDoc(lhs, rhs, annotation)
+
+ assert doc.lhs == lhs
+ assert doc.rhs == rhs
+ assert doc.annotation == annotation
+
+
+@pytest.mark.parametrize(
+ "lhs, rhs, annotation",
+ [
+ (IdDoc("x"), None, None),
+ (TupleDoc([IdDoc("x"), IdDoc("y")]), None, IdDoc("int")),
+ (TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("u"), IdDoc("int")),
+ ],
+)
+def test_invalid_assign_doc(lhs, rhs, annotation):
+ with pytest.raises(ValueError) as e:
+ AssignDoc(lhs, rhs, annotation)
+ assert "AssignDoc" in str(e.value)
+
+
+@pytest.mark.parametrize(
+ "else_branch",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+@pytest.mark.parametrize(
+ "then_branch",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_if_doc(then_branch, else_branch):
+ predicate = IdDoc("x")
+
+ if not then_branch and not else_branch:
+ with pytest.raises(ValueError) as e:
+ IfDoc(predicate, then_branch, else_branch)
+ assert "IfDoc" in str(e.value)
+ return
+ else:
+ doc = IfDoc(predicate, then_branch, else_branch)
+
+ assert doc.predicate == predicate
+ assert list(doc.then_branch) == then_branch
+ assert list(doc.else_branch) == else_branch
+
+
+@pytest.mark.parametrize(
+ "body",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_while_doc(body):
+ predicate = IdDoc("x")
+
+ doc = WhileDoc(predicate, body)
+
+ assert doc.predicate == predicate
+ assert list(doc.body) == body
+
+
+@pytest.mark.parametrize(
+ "body",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_for_doc(body):
+ lhs = IdDoc("x")
+ rhs = IdDoc("y")
+
+ doc = ForDoc(lhs, rhs, body)
+
+ assert doc.lhs == lhs
+ assert doc.rhs == rhs
+ assert list(doc.body) == body
+
+
+@pytest.mark.parametrize(
+ "lhs",
+ [
+ None,
+ IdDoc("x"),
+ ],
+)
+@pytest.mark.parametrize(
+ "body",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_scope_doc(lhs, body):
+ rhs = IdDoc("y")
+
+ doc = ScopeDoc(lhs, rhs, body)
+
+ assert doc.lhs == lhs
+ assert doc.rhs == rhs
+ assert list(doc.body) == body
+
+
+def test_expr_stmt_doc():
+ expr = IdDoc("x")
+
+ doc = ExprStmtDoc(expr)
+
+ assert doc.expr == expr
+
+
+@pytest.mark.parametrize(
+ "msg",
+ [
+ None,
+ LiteralDoc("msg"),
+ ],
+)
+def test_assert_doc(msg):
+ test = IdDoc("x")
+
+ doc = AssertDoc(test, msg)
+
+ assert doc.test == test
+ assert doc.msg == msg
+
+
+def test_return_doc():
+ value = IdDoc("x")
+
+ doc = ReturnDoc(value)
+
+ assert doc.value == value
+
+
+@pytest.mark.parametrize(
+ "args",
+ [
+ [],
+ [AssignDoc(IdDoc("x"), None, IdDoc("int"))],
+ [
+ AssignDoc(IdDoc("x"), None, IdDoc("int")),
+ AssignDoc(IdDoc("y"), LiteralDoc(1), IdDoc("int")),
+ ],
+ ],
+)
+@pytest.mark.parametrize(
+ "decorators",
+ [
+ [],
+ [IdDoc("test")],
+ [IdDoc("test"), IdDoc("test2")],
+ ],
+)
+@pytest.mark.parametrize(
+ "body",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_function_doc(args, decorators, body):
+ name = IdDoc("name")
+ return_type = LiteralDoc(None)
+
+ doc = FunctionDoc(name, args, decorators, return_type, body)
+
+ assert doc.name == name
+ assert list(doc.args) == args
+ assert list(doc.decorators) == decorators
+ assert doc.return_type == return_type
+ assert list(doc.body) == body
+
+
+@pytest.mark.parametrize(
+ "decorators",
+ [
+ [],
+ [IdDoc("test")],
+ [IdDoc("test"), IdDoc("test2")],
+ ],
+)
+@pytest.mark.parametrize(
+ "body",
+ [
+ [],
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ ],
+)
+def test_class_doc(decorators, body):
+ name = IdDoc("name")
+
+ doc = ClassDoc(name, decorators, body)
+
+ assert doc.name == name
+ assert list(doc.decorators) == decorators
+ assert list(doc.body) == body
+
+
+def test_stmt_doc_comment():
+ doc = ExprStmtDoc(IdDoc("x"))
+ assert doc.comment is None
+
+ comment = "test comment"
+ doc.comment = comment
+ assert doc.comment == comment