You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/27 23:56:20 UTC

[tvm] branch main updated: [TVMScript] StmtDoc Definitions (#12111)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fcec5f4a76 [TVMScript] StmtDoc Definitions (#12111)
fcec5f4a76 is described below

commit fcec5f4a763280dbb42f027654fd47088d0b3c5a
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Wed Jul 27 19:56:15 2022 -0400

    [TVMScript] StmtDoc Definitions (#12111)
    
    This PR adds:
    
    - All StmtDoc subclasses
    - Python bindings for StmtDoc
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
---
 include/tvm/script/printer/doc.h                   | 506 +++++++++++++++++++++
 python/tvm/script/printer/_ffi_api.py              |   2 +-
 python/tvm/script/printer/doc.py                   | 248 ++++++++--
 python/tvm/script/printer/doc_printer.py           |   5 +-
 src/script/printer/doc.cc                          | 170 +++++++
 .../python/unittest/test_tvmscript_printer_doc.py  | 260 +++++++++++
 6 files changed, 1154 insertions(+), 37 deletions(-)

diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
index f3f980e53f..e3dd83743e 100644
--- a/include/tvm/script/printer/doc.h
+++ b/include/tvm/script/printer/doc.h
@@ -119,6 +119,79 @@ class ExprDoc : public Doc {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
 };
 
+/*!
+ * \brief The base class of statement doc.
+ *
+ * \sa StmtDoc
+ */
+class StmtDocNode : public DocNode {
+ public:
+  /*!
+   * \brief The comment of this doc.
+   *
+   * The actual position of the comment depends on the type of Doc
+   * and also the DocPrinter implementation. It could be on the same
+   * line as the statement, or the line above, or inside the statement
+   * if it spans over multiple lines.
+   * */
+  mutable Optional<String> comment{NullOpt};
+
+  void VisitAttrs(AttrVisitor* v) {
+    DocNode::VisitAttrs(v);
+    v->Visit("comment", &comment);
+  }
+
+  static constexpr const char* _type_key = "script.printer.StmtDoc";
+  TVM_DECLARE_BASE_OBJECT_INFO(StmtDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of StmtDocNode.
+ *
+ * \sa StmtDocNode
+ */
+class StmtDoc : public Doc {
+ protected:
+  StmtDoc() = default;
+
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtDoc, Doc, StmtDocNode);
+};
+
+/*!
+ * \brief The container doc that holds a list of StmtDoc.
+ * \note `StmtBlockDoc` is never used in the IR; it is only a temporary container that holds a
+ * list of StmtDoc.
+ * \sa StmtBlockDoc
+ */
+class StmtBlockDocNode : public DocNode {
+ public:
+  /*! \brief The list of statements. */
+  Array<StmtDoc> stmts;
+
+  void VisitAttrs(AttrVisitor* v) {
+    DocNode::VisitAttrs(v);
+    v->Visit("stmts", &stmts);
+  }
+
+  static constexpr const char* _type_key = "script.printer.StmtBlockDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StmtBlockDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of StmtBlockDocNode.
+ * \sa StmtBlockDocNode
+ */
+class StmtBlockDoc : public Doc {
+ public:
+  /*!
+   * \brief Constructor of StmtBlockDoc.
+   * \param stmts The list of statements.
+   */
+  explicit StmtBlockDoc(Array<StmtDoc> stmts);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(StmtBlockDoc, Doc, StmtBlockDocNode);
+};
+
 /*!
  * \brief Doc that represents literal value.
  *
@@ -219,6 +292,7 @@ class IdDoc : public ExprDoc {
    * \param name The name of identifier.
    */
   explicit IdDoc(String name);
+  explicit IdDoc(std::nullptr_t) : ExprDoc(nullptr) {}
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode);
 };
 
@@ -640,6 +714,438 @@ class SliceDoc : public Doc {
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode);
 };
 
+/*!
+ * \brief Doc that represents assign statement.
+ *
+ * \sa AssignDoc
+ */
+class AssignDocNode : public StmtDocNode {
+ public:
+  /*! \brief The left hand side of the assignment */
+  ExprDoc lhs{nullptr};
+  /*!
+   * \brief The right hand side of the assignment.
+   *
+   * If null, this doc represents a declaration, e.g. `A: T.Buffer[(1,2)]`
+   * */
+  Optional<ExprDoc> rhs;
+  /*! \brief The type annotation of this assignment. */
+  Optional<ExprDoc> annotation;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
+    v->Visit("annotation", &annotation);
+  }
+
+  static constexpr const char* _type_key = "script.printer.AssignDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssignDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of AssignDocNode.
+ *
+ * \sa AssignDocNode
+ */
+class AssignDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of AssignDoc.
+   * \param lhs The left hand side of the assignment.
+   * \param rhs The right hand side of the assignment.
+   * \param annotation The type annotation of this assignment.
+   */
+  explicit AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssignDoc, StmtDoc, AssignDocNode);
+};
+
+/*!
+ * \brief Doc that represents if-then-else statement.
+ *
+ * \sa IfDoc
+ */
+class IfDocNode : public StmtDocNode {
+ public:
+  /*! \brief The predicate of the if-then-else statement. */
+  ExprDoc predicate{nullptr};
+  /*! \brief The then branch of the if-then-else statement. */
+  Array<StmtDoc> then_branch;
+  /*! \brief The else branch of the if-then-else statement. */
+  Array<StmtDoc> else_branch;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("predicate", &predicate);
+    v->Visit("then_branch", &then_branch);
+    v->Visit("else_branch", &else_branch);
+  }
+
+  static constexpr const char* _type_key = "script.printer.IfDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IfDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of IfDocNode.
+ *
+ * \sa IfDocNode
+ */
+class IfDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of IfDoc.
+   * \param predicate The predicate of the if-then-else statement.
+   * \param then_branch The then branch of the if-then-else statement.
+   * \param else_branch The else branch of the if-then-else statement.
+   */
+  explicit IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IfDoc, StmtDoc, IfDocNode);
+};
+
+/*!
+ * \brief Doc that represents while statement.
+ *
+ * \sa WhileDoc
+ */
+class WhileDocNode : public StmtDocNode {
+ public:
+  /*! \brief The predicate of the while statement. */
+  ExprDoc predicate{nullptr};
+  /*! \brief The body of the while statement. */
+  Array<StmtDoc> body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("predicate", &predicate);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "script.printer.WhileDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(WhileDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of WhileDocNode.
+ *
+ * \sa WhileDocNode
+ */
+class WhileDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of WhileDoc.
+   * \param predicate The predicate of the while statement.
+   * \param body The body of the while statement.
+   */
+  explicit WhileDoc(ExprDoc predicate, Array<StmtDoc> body);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(WhileDoc, StmtDoc, WhileDocNode);
+};
+
+/*!
+ * \brief Doc that represents for statement.
+ *
+ * Example:
+ * for 'lhs' in 'rhs':
+ *   'body...'
+ *
+ * \sa ForDoc
+ */
+class ForDocNode : public StmtDocNode {
+ public:
+  /*! \brief The left hand side of the assignment of iterating variable. */
+  ExprDoc lhs{nullptr};
+  /*! \brief The right hand side of the assignment of iterating variable. */
+  ExprDoc rhs{nullptr};
+  /*! \brief The body of the for statement. */
+  Array<StmtDoc> body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "script.printer.ForDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ForDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ForDocNode.
+ *
+ * \sa ForDocNode
+ */
+class ForDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of ForDoc.
+   * \param lhs The left hand side of the assignment of iterating variable.
+   * \param rhs The right hand side of the assignment of iterating variable.
+   * \param body The body of the for statement.
+   */
+  explicit ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ForDoc, StmtDoc, ForDocNode);
+};
+
+/*!
+ * \brief Doc that represents special scopes.
+ *
+ * Specifically, this means the with statement in Python:
+ *
+ * with 'rhs' as 'lhs':
+ *   'body...'
+ *
+ * \sa ScopeDoc
+ */
+class ScopeDocNode : public StmtDocNode {
+ public:
+  /*! \brief The name of the scoped variable. */
+  Optional<ExprDoc> lhs{NullOpt};
+  /*! \brief The value of the scoped variable. */
+  ExprDoc rhs{nullptr};
+  /*! \brief The body of the scope doc. */
+  Array<StmtDoc> body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "script.printer.ScopeDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ScopeDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ScopeDocNode.
+ *
+ * \sa ScopeDocNode
+ */
+class ScopeDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of ScopeDoc.
+   * \param lhs The name of the scoped variable.
+   * \param rhs The value of the scoped variable.
+   * \param body The body of the scope doc.
+   */
+  explicit ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body);
+
+  /*!
+   * \brief Constructor of ScopeDoc.
+   * \param rhs The value of the scoped variable.
+   * \param body The body of the scope doc.
+   */
+  explicit ScopeDoc(ExprDoc rhs, Array<StmtDoc> body);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ScopeDoc, StmtDoc, ScopeDocNode);
+};
+
+/*!
+ * \brief Doc that represents an expression as statement.
+ *
+ * \sa ExprStmtDoc
+ */
+class ExprStmtDocNode : public StmtDocNode {
+ public:
+  /*! \brief The expression represented by this doc. */
+  ExprDoc expr{nullptr};
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("expr", &expr);
+  }
+
+  static constexpr const char* _type_key = "script.printer.ExprStmtDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ExprStmtDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ExprStmtDocNode.
+ *
+ * \sa ExprStmtDocNode
+ */
+class ExprStmtDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of ExprStmtDoc.
+   * \param expr The expression represented by this doc.
+   */
+  explicit ExprStmtDoc(ExprDoc expr);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprStmtDoc, StmtDoc, ExprStmtDocNode);
+};
+
+/*!
+ * \brief Doc that represents assert statement.
+ *
+ * \sa AssertDoc
+ */
+class AssertDocNode : public StmtDocNode {
+ public:
+  /*! \brief The expression to test. */
+  ExprDoc test{nullptr};
+  /*! \brief The optional error message when assertion failed. */
+  Optional<ExprDoc> msg{NullOpt};
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("test", &test);
+    v->Visit("msg", &msg);
+  }
+
+  static constexpr const char* _type_key = "script.printer.AssertDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AssertDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of AssertDocNode.
+ *
+ * \sa AssertDocNode
+ */
+class AssertDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of AssertDoc.
+   * \param test The expression to test.
+   * \param msg The optional error message when assertion failed.
+   */
+  explicit AssertDoc(ExprDoc test, Optional<ExprDoc> msg = NullOpt);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AssertDoc, StmtDoc, AssertDocNode);
+};
+
+/*!
+ * \brief Doc that represents return statement.
+ *
+ * \sa ReturnDoc
+ */
+class ReturnDocNode : public StmtDocNode {
+ public:
+  /*! \brief The value to return. */
+  ExprDoc value{nullptr};
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.printer.ReturnDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ReturnDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ReturnDocNode.
+ *
+ * \sa ReturnDocNode
+ */
+class ReturnDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of ReturnDoc.
+   * \param value The value to return.
+   */
+  explicit ReturnDoc(ExprDoc value);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ReturnDoc, StmtDoc, ReturnDocNode);
+};
+
+/*!
+ * \brief Doc that represents function definition.
+ *
+ * \sa FunctionDoc
+ */
+class FunctionDocNode : public StmtDocNode {
+ public:
+  /*! \brief The name of function. */
+  IdDoc name{nullptr};
+  /*!
+   * \brief The arguments of function.
+   *
+   * The `lhs` means argument name,
+   * `annotation` means argument type,
+   * and `rhs` means default value.
+   */
+  Array<AssignDoc> args;
+  /*! \brief Decorators of function. */
+  Array<ExprDoc> decorators;
+  /*! \brief The return type of function. */
+  ExprDoc return_type{nullptr};
+  /*! \brief The body of function. */
+  Array<StmtDoc> body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("args", &args);
+    v->Visit("decorators", &decorators);
+    v->Visit("return_type", &return_type);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "script.printer.FunctionDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of FunctionDocNode.
+ *
+ * \sa FunctionDocNode
+ */
+class FunctionDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of FunctionDoc.
+   * \param name The name of function.
+   * \param args The arguments of function.
+   * \param decorators The decorators of function.
+   * \param return_type The return type of function.
+   * \param body The body of function.
+   */
+  explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
+                       ExprDoc return_type, Array<StmtDoc> body);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode);
+};
+
+/*!
+ * \brief Doc that represents class definition.
+ *
+ * \sa ClassDoc
+ */
+class ClassDocNode : public StmtDocNode {
+ public:
+  /*! \brief The name of class. */
+  IdDoc name{nullptr};
+  /*! \brief Decorators of class. */
+  Array<ExprDoc> decorators;
+  /*! \brief The body of class. */
+  Array<StmtDoc> body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    StmtDocNode::VisitAttrs(v);
+    v->Visit("name", &name);
+    v->Visit("decorators", &decorators);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "script.printer.ClassDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ClassDocNode, StmtDocNode);
+};
+
+/*!
+ * \brief Reference type of ClassDocNode.
+ *
+ * \sa ClassDocNode
+ */
+class ClassDoc : public StmtDoc {
+ public:
+  /*!
+   * \brief Constructor of ClassDoc.
+   * \param name The name of class.
+   * \param decorators The decorators of class.
+   * \param body The body of class.
+   */
+  explicit ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ClassDoc, StmtDoc, ClassDocNode);
+};
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py
index baa639fe2d..944ecad01e 100644
--- a/python/tvm/script/printer/_ffi_api.py
+++ b/python/tvm/script/printer/_ffi_api.py
@@ -17,4 +17,4 @@
 """FFI APIs for tvm.script.printer"""
 import tvm._ffi
 
-tvm._ffi._init_api("script.printer", __name__)
+tvm._ffi._init_api("script.printer", __name__)  # pylint: disable=protected-access
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
index acdb63dcf2..747ffc42f1 100644
--- a/python/tvm/script/printer/doc.py
+++ b/python/tvm/script/printer/doc.py
@@ -16,11 +16,10 @@
 # under the License.
 """Doc types for TVMScript Unified Printer"""
 
-from typing import List, Dict, Tuple, Optional, Union, Sequence
 from enum import IntEnum, unique
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-import tvm._ffi
-import tvm.ir.container
+from tvm._ffi import register_object
 from tvm.runtime import Object
 from tvm.tir import FloatImm, IntImm
 
@@ -47,7 +46,7 @@ class ExprDoc(Object):
         -------
         doc : AttrAccessDoc
         """
-        return _ffi_api.ExprDocAttr(self, name)  # type: ignore
+        return _ffi_api.ExprDocAttr(self, name)  # type: ignore # pylint: disable=no-member
 
     def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc":
         """
@@ -66,7 +65,7 @@ class ExprDoc(Object):
         """
         kwargs_keys = list(kwargs.keys())
         kwargs_values = list(kwargs.values())
-        return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values)  # type: ignore
+        return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values)  # type: ignore # pylint: disable=no-member
 
     _IndexType = Union["ExprDoc", "SliceDoc"]
 
@@ -85,7 +84,7 @@ class ExprDoc(Object):
         """
         if not isinstance(indices, tuple):
             indices = (indices,)
-        return _ffi_api.ExprDocIndex(self, indices)  # type: ignore
+        return _ffi_api.ExprDocIndex(self, indices)  # type: ignore # pylint: disable=no-member
 
     def __iter__(self):
         """
@@ -100,7 +99,34 @@ class ExprDoc(Object):
         raise RuntimeError(f"{self.__class__} cannot be used as iterable.")
 
 
-@tvm._ffi.register_object("script.printer.LiteralDoc")
+class StmtDoc(Doc):
+    """Base class of statement doc"""
+
+    @property
+    def comment(self) -> Optional[str]:
+        # It has to call the dunder method to avoid infinite recursion
+        return self.__getattr__("comment")  # pylint: disable=unnecessary-dunder-call
+
+    @comment.setter
+    def comment(self, value):
+        return _ffi_api.StmtDocSetComment(self, value)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.StmtBlockDoc")
+class StmtBlockDoc(Doc):
+    """The container doc that holds a list of StmtDoc.
+
+    Note: `StmtBlockDoc` is never used in the IR; it is only a temporary container that holds a
+    list of StmtDoc.
+    """
+
+    stmts: Sequence[StmtDoc]
+
+    def __init__(self, stmts: List[StmtDoc]):
+        self.__init_handle_by_constructor__(_ffi_api.StmtBlockDoc, stmts)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.LiteralDoc")
 class LiteralDoc(ExprDoc):
     """Doc that represents literal value"""
 
@@ -108,30 +134,30 @@ class LiteralDoc(ExprDoc):
 
     def __init__(self, value: Union[str, float, bool, int, None]):
         if value is None:
-            self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone)  # type: ignore
+            self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone)  # type: ignore # pylint: disable=no-member
         elif isinstance(value, str):
-            self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value)  # type: ignore
+            self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value)  # type: ignore # pylint: disable=no-member
         elif isinstance(value, float):
-            self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value)  # type: ignore
+            self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value)  # type: ignore # pylint: disable=no-member
         elif isinstance(value, bool):
-            self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value)  # type: ignore
+            self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value)  # type: ignore # pylint: disable=no-member
         elif isinstance(value, int):
-            self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value)  # type: ignore
+            self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value)  # type: ignore # pylint: disable=no-member
         else:
             raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")
 
 
-@tvm._ffi.register_object("script.printer.IdDoc")
+@register_object("script.printer.IdDoc")
 class IdDoc(ExprDoc):
     """Doc that represents identifier"""
 
     name: str
 
     def __init__(self, name: str):
-        self.__init_handle_by_constructor__(_ffi_api.IdDoc, name)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.IdDoc, name)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.AttrAccessDoc")
+@register_object("script.printer.AttrAccessDoc")
 class AttrAccessDoc(ExprDoc):
     """Doc that represents attribute access on an expression"""
 
@@ -139,10 +165,10 @@ class AttrAccessDoc(ExprDoc):
     name: str
 
     def __init__(self, value: ExprDoc, name: str):
-        self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.IndexDoc")
+@register_object("script.printer.IndexDoc")
 class IndexDoc(ExprDoc):
     """Doc that represents index access on an expression"""
 
@@ -150,10 +176,10 @@ class IndexDoc(ExprDoc):
     indices: Sequence[Union[ExprDoc, "SliceDoc"]]
 
     def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]):
-        self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.CallDoc")
+@register_object("script.printer.CallDoc")
 class CallDoc(ExprDoc):
     """Doc that represents function call"""
 
@@ -166,14 +192,18 @@ class CallDoc(ExprDoc):
         kwargs_keys = list(kwargs.keys())
         kwargs_values = list(kwargs.values())
         self.__init_handle_by_constructor__(
-            _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values  # type: ignore
+            _ffi_api.CallDoc,  # type: ignore # pylint: disable=no-member
+            callee,
+            args,
+            kwargs_keys,
+            kwargs_values,
         )
 
 
 @unique
 class OperationKind(IntEnum):
     """
-    This enum represents the kind of operation (operator) in OpeartionDoc
+    This enum represents the kind of operation (operator) in OperationDoc
 
     It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h
     """
@@ -214,7 +244,7 @@ class OperationKind(IntEnum):
     # pylint: enable=invalid-name
 
 
-@tvm._ffi.register_object("script.printer.OperationDoc")
+@register_object("script.printer.OperationDoc")
 class OperationDoc(ExprDoc):
     """
     Doc that represents operation
@@ -227,10 +257,10 @@ class OperationDoc(ExprDoc):
     operands: Sequence[ExprDoc]
 
     def __init__(self, kind: OperationKind, operands: List[ExprDoc]):
-        self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.LambdaDoc")
+@register_object("script.printer.LambdaDoc")
 class LambdaDoc(ExprDoc):
     """Doc that represents lambda function"""
 
@@ -238,30 +268,30 @@ class LambdaDoc(ExprDoc):
     body: ExprDoc
 
     def __init__(self, args: List[IdDoc], body: ExprDoc):
-        self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.TupleDoc")
+@register_object("script.printer.TupleDoc")
 class TupleDoc(ExprDoc):
     """Doc that represents tuple literal"""
 
     elements: Sequence[ExprDoc]
 
     def __init__(self, elements: List[ExprDoc]):
-        self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.ListDoc")
+@register_object("script.printer.ListDoc")
 class ListDoc(ExprDoc):
     """Doc that represents list literal"""
 
     elements: Sequence[ExprDoc]
 
     def __init__(self, elements: List[ExprDoc]):
-        self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.DictDoc")
+@register_object("script.printer.DictDoc")
 class DictDoc(ExprDoc):
     """Doc that represents dict literal"""
 
@@ -271,10 +301,10 @@ class DictDoc(ExprDoc):
     def __init__(self, content: Dict[ExprDoc, ExprDoc]):
         keys = list(content.keys())
         values = list(content.values())
-        self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values)  # type: ignore # pylint: disable=no-member
 
 
-@tvm._ffi.register_object("script.printer.SliceDoc")
+@register_object("script.printer.SliceDoc")
 class SliceDoc(ExprDoc):
     """
     Doc that represents slice in Index expression
@@ -292,4 +322,156 @@ class SliceDoc(ExprDoc):
         stop: Optional[ExprDoc] = None,
         step: Optional[ExprDoc] = None,
     ):
-        self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step)  # type: ignore
+        self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.AssignDoc")
+class AssignDoc(StmtDoc):
+    """Doc that represents assign statement."""
+
+    lhs: ExprDoc
+    rhs: Optional[ExprDoc]
+    annotation: Optional[ExprDoc]
+
+    def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[ExprDoc] = None):
+        self.__init_handle_by_constructor__(
+            _ffi_api.AssignDoc,  # type: ignore # pylint: disable=no-member
+            lhs,
+            rhs,
+            annotation,
+        )
+
+
+@register_object("script.printer.IfDoc")
+class IfDoc(StmtDoc):
+    """Doc that represents if-then-else statement."""
+
+    predicate: ExprDoc
+    then_branch: Sequence[StmtDoc]
+    else_branch: Sequence[StmtDoc]
+
+    def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: List[StmtDoc]):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IfDoc,  # type: ignore # pylint: disable=no-member
+            predicate,
+            then_branch,
+            else_branch,
+        )
+
+
+@register_object("script.printer.WhileDoc")
+class WhileDoc(StmtDoc):
+    """Doc that represents while statement."""
+
+    predicate: ExprDoc
+    body: Sequence[StmtDoc]
+
+    def __init__(self, predicate: ExprDoc, body: List[StmtDoc]):
+        self.__init_handle_by_constructor__(_ffi_api.WhileDoc, predicate, body)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ForDoc")
+class ForDoc(StmtDoc):
+    """Doc that represents for statement."""
+
+    lhs: ExprDoc
+    rhs: ExprDoc
+    body: Sequence[StmtDoc]
+
+    def __init__(self, lhs: ExprDoc, rhs: ExprDoc, body: List[StmtDoc]):
+        self.__init_handle_by_constructor__(_ffi_api.ForDoc, lhs, rhs, body)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ScopeDoc")
+class ScopeDoc(StmtDoc):
+    """
+    Doc that represents special scopes.
+
+    Specifically, this means the with statement in Python:
+
+    with <rhs> as <lhs>:
+        <body...>
+    """
+
+    lhs: Optional[ExprDoc]
+    rhs: ExprDoc
+    body: Sequence[StmtDoc]
+
+    def __init__(self, lhs: Optional[ExprDoc], rhs: ExprDoc, body: List[StmtDoc]):
+        self.__init_handle_by_constructor__(_ffi_api.ScopeDoc, lhs, rhs, body)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ExprStmtDoc")
+class ExprStmtDoc(StmtDoc):
+    """Doc that represents an expression as statement."""
+
+    expr: ExprDoc
+
+    def __init__(self, expr: ExprDoc):
+        self.__init_handle_by_constructor__(_ffi_api.ExprStmtDoc, expr)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.AssertDoc")
+class AssertDoc(StmtDoc):
+    """Doc that represents assert statement."""
+
+    test: ExprDoc
+    msg: Optional[ExprDoc]
+
+    def __init__(self, test: ExprDoc, msg: Optional[ExprDoc] = None):
+        self.__init_handle_by_constructor__(_ffi_api.AssertDoc, test, msg)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.ReturnDoc")
+class ReturnDoc(StmtDoc):
+    """Doc that represents return statement."""
+
+    value: ExprDoc
+
+    def __init__(self, value: ExprDoc):
+        self.__init_handle_by_constructor__(_ffi_api.ReturnDoc, value)  # type: ignore # pylint: disable=no-member
+
+
+@register_object("script.printer.FunctionDoc")
+class FunctionDoc(StmtDoc):
+    """Doc that represents function definition."""
+
+    name: IdDoc
+    args: Sequence[AssignDoc]
+    decorators: Sequence[ExprDoc]
+    return_type: ExprDoc
+    body: Sequence[StmtDoc]
+
+    def __init__(
+        self,
+        name: IdDoc,
+        args: List[AssignDoc],
+        decorators: List[ExprDoc],
+        return_type: ExprDoc,
+        body: List[StmtDoc],
+    ):
+        self.__init_handle_by_constructor__(
+            _ffi_api.FunctionDoc,  # type: ignore # pylint: disable=no-member
+            name,
+            args,
+            decorators,
+            return_type,
+            body,
+        )
+
+
+@register_object("script.printer.ClassDoc")
+class ClassDoc(StmtDoc):
+    """Doc that represents class definition."""
+
+    name: IdDoc
+    decorators: Sequence[ExprDoc]
+    body: Sequence[StmtDoc]
+
+    def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]):
+        self.__init_handle_by_constructor__(
+            _ffi_api.ClassDoc,  # type: ignore # pylint: disable=no-member
+            name,
+            decorators,
+            body,
+        )
diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py
index 404632b44c..1cb56ecbf7 100644
--- a/python/tvm/script/printer/doc_printer.py
+++ b/python/tvm/script/printer/doc_printer.py
@@ -21,8 +21,7 @@ from .doc import Doc
 
 
 def to_python_script(doc: Doc, indent_spaces: int = 4) -> str:
-    """
-    Convert Doc into Python script.
+    """Convert Doc into Python script.
 
     Parameters
     ----------
@@ -36,4 +35,4 @@ def to_python_script(doc: Doc, indent_spaces: int = 4) -> str:
     script : str
         The text representation of Doc in Python syntax
     """
-    return _ffi_api.DocToPythonScript(doc, indent_spaces)  # type: ignore
+    return _ffi_api.DocToPythonScript(doc, indent_spaces)  # type: ignore # pylint: disable=no-member
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
index ed81f9d2dd..bfff0cfad4 100644
--- a/src/script/printer/doc.cc
+++ b/src/script/printer/doc.cc
@@ -16,6 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/logging.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/script/printer/doc.h>
 
@@ -38,6 +40,12 @@ ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_
   return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
 }
 
+StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
+  // Fill a fresh StmtBlockDocNode with the statements, then let this reference own it.
+  auto node = make_object<StmtBlockDocNode>();
+  node->stmts = stmts;
+  data_ = std::move(node);
+}
+
 LiteralDoc::LiteralDoc(ObjectRef value) {
   ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
   n->value = value;
@@ -115,6 +123,99 @@ SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<Exp
   this->data_ = std::move(n);
 }
 
+/*!
+ * \brief Construct an AssignDoc from its target, optional value and optional annotation.
+ *
+ * Fails a CHECK (reported with a "ValueError:" prefix) when both rhs and annotation are
+ * null, or when an annotation is given while lhs is not an IdDoc. Both messages name
+ * AssignDoc so that callers can tell which constructor rejected the input.
+ */
+AssignDoc::AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
+  CHECK(rhs.defined() || annotation.defined())
+      << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc.";
+  CHECK(lhs->IsInstance<IdDocNode>() || !annotation.defined())
+      << "ValueError: annotation of AssignDoc can only be non-null if lhs is an identifier.";
+
+  ObjectPtr<AssignDocNode> n = make_object<AssignDocNode>();
+  n->lhs = lhs;
+  n->rhs = rhs;
+  n->annotation = annotation;
+  this->data_ = std::move(n);
+}
+
+/*!
+ * \brief Construct an IfDoc from a predicate and its two branches.
+ *
+ * Fails a CHECK (reported with a "ValueError:" prefix) when both branches are empty.
+ * The message names IfDoc so that callers can tell which constructor rejected the input.
+ */
+IfDoc::IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) {
+  CHECK(!then_branch.empty() || !else_branch.empty())
+      << "ValueError: At least one of then_branch and else_branch of IfDoc needs to be non-empty.";
+
+  ObjectPtr<IfDocNode> n = make_object<IfDocNode>();
+  n->predicate = predicate;
+  n->then_branch = then_branch;
+  n->else_branch = else_branch;
+  this->data_ = std::move(n);
+}
+
+WhileDoc::WhileDoc(ExprDoc predicate, Array<StmtDoc> body) {
+  // Fill a fresh WhileDocNode with the given predicate and body, then take ownership of it.
+  auto node = make_object<WhileDocNode>();
+  node->predicate = predicate;
+  node->body = body;
+  data_ = std::move(node);
+}
+
+ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) {
+  // Fill a fresh ForDocNode with the given lhs, rhs and body, then take ownership of it.
+  auto node = make_object<ForDocNode>();
+  node->lhs = lhs;
+  node->rhs = rhs;
+  node->body = body;
+  data_ = std::move(node);
+}
+
+ScopeDoc::ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) {
+  // lhs may be null; it is stored as given alongside rhs and body.
+  auto node = make_object<ScopeDocNode>();
+  node->lhs = lhs;
+  node->rhs = rhs;
+  node->body = body;
+  data_ = std::move(node);
+}
+
+// Overload without lhs: equivalent to the three-argument constructor with a null lhs.
+ScopeDoc::ScopeDoc(ExprDoc rhs, Array<StmtDoc> body)
+    : ScopeDoc(Optional<ExprDoc>(NullOpt), rhs, body) {}
+
+ExprStmtDoc::ExprStmtDoc(ExprDoc expr) {
+  // Wrap the expression in a fresh ExprStmtDocNode and take ownership of the node.
+  auto node = make_object<ExprStmtDocNode>();
+  node->expr = expr;
+  data_ = std::move(node);
+}
+
+AssertDoc::AssertDoc(ExprDoc test, Optional<ExprDoc> msg) {
+  // msg may be null; it is stored as given next to the test expression.
+  auto node = make_object<AssertDocNode>();
+  node->test = test;
+  node->msg = msg;
+  data_ = std::move(node);
+}
+
+ReturnDoc::ReturnDoc(ExprDoc value) {
+  // Store the value in a fresh ReturnDocNode and take ownership of the node.
+  auto node = make_object<ReturnDocNode>();
+  node->value = value;
+  data_ = std::move(node);
+}
+
+FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
+                         ExprDoc return_type, Array<StmtDoc> body) {
+  // Copy every argument into the matching field of a fresh FunctionDocNode.
+  auto node = make_object<FunctionDocNode>();
+  node->name = name;
+  node->args = args;
+  node->decorators = decorators;
+  node->return_type = return_type;
+  node->body = body;
+  data_ = std::move(node);
+}
+
+ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
+  // Copy every argument into the matching field of a fresh ClassDocNode.
+  auto node = make_object<ClassDocNode>();
+  node->name = name;
+  node->decorators = decorators;
+  node->body = body;
+  data_ = std::move(node);
+}
+
 TVM_REGISTER_NODE_TYPE(DocNode);
 
 TVM_REGISTER_NODE_TYPE(ExprDocNode);
@@ -125,6 +226,15 @@ TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
     .set_body_method<ExprDoc, ExprDocNode, ExprDoc, Array<ExprDoc>, Array<String>, Array<ExprDoc>>(
         &ExprDocNode::Call);
 
+// StmtDoc: node type plus a global function that overwrites the `comment` field.
+TVM_REGISTER_NODE_TYPE(StmtDocNode);
+TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment")
+    .set_body_typed([](StmtDoc stmt, Optional<String> text) -> void { stmt->comment = text; });
+
+// StmtBlockDoc: node type plus a global constructor function.
+TVM_REGISTER_NODE_TYPE(StmtBlockDocNode);
+TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc")
+    .set_body_typed([](Array<StmtDoc> stmts) -> StmtBlockDoc { return StmtBlockDoc(stmts); });
+
 TVM_REGISTER_NODE_TYPE(LiteralDocNode);
 TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
 TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
@@ -185,6 +295,66 @@ TVM_REGISTER_GLOBAL("script.printer.SliceDoc")
     .set_body_typed([](Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) {
       return SliceDoc(start, stop, step);
     });
+
+// Each statement doc below registers its node type and a global constructor function
+// that forwards its arguments, unchanged, to the C++ constructor of the same name.
+TVM_REGISTER_NODE_TYPE(AssignDocNode);
+TVM_REGISTER_GLOBAL("script.printer.AssignDoc")
+    .set_body_typed(
+        [](ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) -> AssignDoc {
+          return AssignDoc(lhs, rhs, annotation);
+        });
+
+TVM_REGISTER_NODE_TYPE(IfDocNode);
+TVM_REGISTER_GLOBAL("script.printer.IfDoc")
+    .set_body_typed(
+        [](ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) -> IfDoc {
+          return IfDoc(predicate, then_branch, else_branch);
+        });
+
+TVM_REGISTER_NODE_TYPE(WhileDocNode);
+TVM_REGISTER_GLOBAL("script.printer.WhileDoc")
+    .set_body_typed([](ExprDoc predicate, Array<StmtDoc> body) -> WhileDoc {
+      return WhileDoc(predicate, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ForDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ForDoc")
+    .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) -> ForDoc {
+      return ForDoc(lhs, rhs, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ScopeDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ScopeDoc")
+    .set_body_typed([](Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) -> ScopeDoc {
+      return ScopeDoc(lhs, rhs, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ExprStmtDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc")
+    .set_body_typed([](ExprDoc expr) -> ExprStmtDoc { return ExprStmtDoc(expr); });
+
+TVM_REGISTER_NODE_TYPE(AssertDocNode);
+// Global constructor function for AssertDoc. The typed signature is taken from the lambda's
+// parameter list, so a C++ default argument on `msg` would never take effect; a caller that
+// has no message passes a null Optional instead.
+TVM_REGISTER_GLOBAL("script.printer.AssertDoc")
+    .set_body_typed([](ExprDoc test, Optional<ExprDoc> msg) {
+      return AssertDoc(test, msg);
+    });
+
+// Each statement doc below registers its node type and a global constructor function
+// that forwards its arguments, unchanged, to the C++ constructor of the same name.
+TVM_REGISTER_NODE_TYPE(ReturnDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ReturnDoc")
+    .set_body_typed([](ExprDoc value) -> ReturnDoc { return ReturnDoc(value); });
+
+TVM_REGISTER_NODE_TYPE(FunctionDocNode);
+TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
+    .set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
+                       ExprDoc return_type, Array<StmtDoc> body) -> FunctionDoc {
+      return FunctionDoc(name, args, decorators, return_type, body);
+    });
+
+TVM_REGISTER_NODE_TYPE(ClassDocNode);
+TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
+    .set_body_typed([](IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) -> ClassDoc {
+      return ClassDoc(name, decorators, body);
+    });
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
index 4ff6a0f547..040a829010 100644
--- a/tests/python/unittest/test_tvmscript_printer_doc.py
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -14,6 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""
+In this test file, we want to make sure the Python code can construct
+Doc objects, then access and modify their attributes correctly.
+"""
+
 import pytest
 
 from tvm.script.printer.doc import (
@@ -29,6 +34,17 @@ from tvm.script.printer.doc import (
     ListDoc,
     DictDoc,
     SliceDoc,
+    StmtBlockDoc,
+    AssignDoc,
+    IfDoc,
+    WhileDoc,
+    ForDoc,
+    ScopeDoc,
+    ExprStmtDoc,
+    AssertDoc,
+    ReturnDoc,
+    FunctionDoc,
+    ClassDoc,
 )
 
 
@@ -244,3 +260,247 @@ def test_expr_doc_call_with(args, kwargs):
     assert doc.callee == target
     assert list(doc.args) == args
     assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs
+
+
+@pytest.mark.parametrize(
+    "stmts",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_stmt_block_doc(stmts):
+    """StmtBlockDoc exposes the statements it was constructed with, in order."""
+    block = StmtBlockDoc(stmts)
+    assert list(block.stmts) == stmts
+
+
+@pytest.mark.parametrize(
+    "lhs, rhs, annotation",
+    [
+        (IdDoc("x"), IdDoc("y"), None),
+        (IdDoc("x"), None, IdDoc("int")),
+        (IdDoc("x"), IdDoc("y"), IdDoc("int")),
+    ],
+)
+def test_assign_doc(lhs, rhs, annotation):
+    """Accepted lhs/rhs/annotation combinations are readable back from AssignDoc."""
+    assign = AssignDoc(lhs, rhs, annotation)
+
+    assert assign.lhs == lhs
+    assert assign.rhs == rhs
+    assert assign.annotation == annotation
+
+
+@pytest.mark.parametrize(
+    "lhs, rhs, annotation",
+    [
+        (IdDoc("x"), None, None),
+        (TupleDoc([IdDoc("x"), IdDoc("y")]), None, IdDoc("int")),
+        (TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("u"), IdDoc("int")),
+    ],
+)
+def test_invalid_assign_doc(lhs, rhs, annotation):
+    """Rejected combinations raise ValueError with text that mentions AssignDoc."""
+    with pytest.raises(ValueError) as exc_info:
+        AssignDoc(lhs, rhs, annotation)
+    assert "AssignDoc" in str(exc_info.value)
+
+
+@pytest.mark.parametrize(
+    "else_branch",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+@pytest.mark.parametrize(
+    "then_branch",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_if_doc(then_branch, else_branch):
+    """IfDoc rejects two empty branches and otherwise exposes its fields."""
+    condition = IdDoc("x")
+
+    if not (then_branch or else_branch):
+        # Both branches empty: construction must fail, so there is nothing more to check.
+        with pytest.raises(ValueError) as exc_info:
+            IfDoc(condition, then_branch, else_branch)
+        assert "IfDoc" in str(exc_info.value)
+        return
+
+    if_doc = IfDoc(condition, then_branch, else_branch)
+
+    assert if_doc.predicate == condition
+    assert list(if_doc.then_branch) == then_branch
+    assert list(if_doc.else_branch) == else_branch
+
+
+@pytest.mark.parametrize(
+    "body",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_while_doc(body):
+    """WhileDoc exposes the predicate and body it was constructed with."""
+    condition = IdDoc("x")
+    while_doc = WhileDoc(condition, body)
+
+    assert while_doc.predicate == condition
+    assert list(while_doc.body) == body
+
+
+@pytest.mark.parametrize(
+    "body",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_for_doc(body):
+    """ForDoc exposes the lhs, rhs and body it was constructed with."""
+    target, iterable = IdDoc("x"), IdDoc("y")
+    for_doc = ForDoc(target, iterable, body)
+
+    assert for_doc.lhs == target
+    assert for_doc.rhs == iterable
+    assert list(for_doc.body) == body
+
+
+@pytest.mark.parametrize(
+    "lhs",
+    [
+        None,
+        IdDoc("x"),
+    ],
+)
+@pytest.mark.parametrize(
+    "body",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_scope_doc(lhs, body):
+    """ScopeDoc exposes its (possibly None) lhs, its rhs and its body."""
+    scope_rhs = IdDoc("y")
+    scope = ScopeDoc(lhs, scope_rhs, body)
+
+    assert scope.lhs == lhs
+    assert scope.rhs == scope_rhs
+    assert list(scope.body) == body
+
+
+def test_expr_stmt_doc():
+    """ExprStmtDoc exposes the expression it wraps."""
+    wrapped = IdDoc("x")
+    stmt = ExprStmtDoc(wrapped)
+
+    assert stmt.expr == wrapped
+
+
+@pytest.mark.parametrize(
+    "msg",
+    [
+        None,
+        LiteralDoc("msg"),
+    ],
+)
+def test_assert_doc(msg):
+    """AssertDoc exposes its test expression and its (possibly None) message."""
+    condition = IdDoc("x")
+    assert_doc = AssertDoc(condition, msg)
+
+    assert assert_doc.test == condition
+    assert assert_doc.msg == msg
+
+
+def test_return_doc():
+    """ReturnDoc exposes the value it was constructed with."""
+    returned = IdDoc("x")
+    return_doc = ReturnDoc(returned)
+
+    assert return_doc.value == returned
+
+
+@pytest.mark.parametrize(
+    "args",
+    [
+        [],
+        [AssignDoc(IdDoc("x"), None, IdDoc("int"))],
+        [
+            AssignDoc(IdDoc("x"), None, IdDoc("int")),
+            AssignDoc(IdDoc("y"), LiteralDoc(1), IdDoc("int")),
+        ],
+    ],
+)
+@pytest.mark.parametrize(
+    "decorators",
+    [
+        [],
+        [IdDoc("test")],
+        [IdDoc("test"), IdDoc("test2")],
+    ],
+)
+@pytest.mark.parametrize(
+    "body",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_function_doc(args, decorators, body):
+    """FunctionDoc exposes every field it was constructed with."""
+    func_name = IdDoc("name")
+    ret_type = LiteralDoc(None)
+    func = FunctionDoc(func_name, args, decorators, ret_type, body)
+
+    assert func.name == func_name
+    assert list(func.args) == args
+    assert list(func.decorators) == decorators
+    assert func.return_type == ret_type
+    assert list(func.body) == body
+
+
+@pytest.mark.parametrize(
+    "decorators",
+    [
+        [],
+        [IdDoc("test")],
+        [IdDoc("test"), IdDoc("test2")],
+    ],
+)
+@pytest.mark.parametrize(
+    "body",
+    [
+        [],
+        [ExprStmtDoc(IdDoc("x"))],
+        [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+    ],
+)
+def test_class_doc(decorators, body):
+    """ClassDoc exposes the name, decorators and body it was constructed with."""
+    class_name = IdDoc("name")
+    class_doc = ClassDoc(class_name, decorators, body)
+
+    assert class_doc.name == class_name
+    assert list(class_doc.decorators) == decorators
+    assert list(class_doc.body) == body
+
+
+def test_stmt_doc_comment():
+    """A statement doc starts with no comment and keeps the one assigned to it."""
+    stmt = ExprStmtDoc(IdDoc("x"))
+    assert stmt.comment is None
+
+    text = "test comment"
+    stmt.comment = text
+    assert stmt.comment == text