You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/28 05:33:49 UTC
[tvm] branch main updated: [TVMScript] StmtDoc Printing (#12112)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1131c92233 [TVMScript] StmtDoc Printing (#12112)
1131c92233 is described below
commit 1131c922332ab05e985fb2f40d1c4fdb9fe51b05
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Thu Jul 28 01:33:44 2022 -0400
[TVMScript] StmtDoc Printing (#12112)
This PR adds:
- StmtDoc Printing in PythonDocPrinter
Tracking issue: https://github.com/apache/tvm/issues/11912
---
src/script/printer/base_doc_printer.cc | 22 +
src/script/printer/base_doc_printer.h | 63 ++-
src/script/printer/python_doc_printer.cc | 212 ++++++-
.../test_tvmscript_printer_python_doc_printer.py | 615 ++++++++++++++++++++-
4 files changed, 906 insertions(+), 6 deletions(-)
diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc
index 42d3f2d8f3..4129152129 100644
--- a/src/script/printer/base_doc_printer.cc
+++ b/src/script/printer/base_doc_printer.cc
@@ -58,6 +58,28 @@ void DocPrinter::PrintDoc(const Doc& doc) {
PrintTypedDoc(GetRef<DictDoc>(doc_node));
} else if (const auto* doc_node = doc.as<SliceDocNode>()) {
PrintTypedDoc(GetRef<SliceDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
+ PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<AssignDocNode>()) {
+ PrintTypedDoc(GetRef<AssignDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<IfDocNode>()) {
+ PrintTypedDoc(GetRef<IfDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<WhileDocNode>()) {
+ PrintTypedDoc(GetRef<WhileDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ForDocNode>()) {
+ PrintTypedDoc(GetRef<ForDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
+ PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
+ PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<AssertDocNode>()) {
+ PrintTypedDoc(GetRef<AssertDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
+ PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
+ PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
+ } else if (const auto* doc_node = doc.as<ClassDocNode>()) {
+ PrintTypedDoc(GetRef<ClassDoc>(doc_node));
} else {
LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
throw;
diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h
index d5bfdcd94c..8633dd0ded 100644
--- a/src/script/printer/base_doc_printer.h
+++ b/src/script/printer/base_doc_printer.h
@@ -84,22 +84,22 @@ class DocPrinter {
virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;
/*!
- * \brief Virtual method to print a IdDoc
+ * \brief Virtual method to print an IdDoc
*/
virtual void PrintTypedDoc(const IdDoc& doc) = 0;
/*!
- * \brief Virtual method to print a AttrAccessDoc
+ * \brief Virtual method to print an AttrAccessDoc
*/
virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0;
/*!
- * \brief Virtual method to print a IndexDoc
+ * \brief Virtual method to print an IndexDoc
*/
virtual void PrintTypedDoc(const IndexDoc& doc) = 0;
/*!
- * \brief Virtual method to print a OperationDoc
+ * \brief Virtual method to print an OperationDoc
*/
virtual void PrintTypedDoc(const OperationDoc& doc) = 0;
@@ -133,6 +133,61 @@ class DocPrinter {
*/
virtual void PrintTypedDoc(const SliceDoc& doc) = 0;
+ /*!
+ * \brief Virtual method to print a StmtBlockDoc
+ */
+ virtual void PrintTypedDoc(const StmtBlockDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an AssignDoc
+ */
+ virtual void PrintTypedDoc(const AssignDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an IfDoc
+ */
+ virtual void PrintTypedDoc(const IfDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a WhileDoc
+ */
+ virtual void PrintTypedDoc(const WhileDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a ForDoc
+ */
+ virtual void PrintTypedDoc(const ForDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a ScopeDoc
+ */
+ virtual void PrintTypedDoc(const ScopeDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an ExprStmtDoc
+ */
+ virtual void PrintTypedDoc(const ExprStmtDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print an AssertDoc
+ */
+ virtual void PrintTypedDoc(const AssertDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a ReturnDoc
+ */
+ virtual void PrintTypedDoc(const ReturnDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a FunctionDoc
+ */
+ virtual void PrintTypedDoc(const FunctionDoc& doc) = 0;
+
+ /*!
+ * \brief Virtual method to print a ClassDoc
+ */
+ virtual void PrintTypedDoc(const ClassDoc& doc) = 0;
+
/*!
* \brief Increase the indent level of any content to be
* printed after this call
diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc
index 5c7b048f81..f44577ff80 100644
--- a/src/script/printer/python_doc_printer.cc
+++ b/src/script/printer/python_doc_printer.cc
@@ -16,11 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/
-
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
+#include <tvm/script/printer/doc.h>
+
+#include <algorithm>
+#include <string>
#include "../../support/str_escape.h"
+#include "../../support/utils.h"
#include "./base_doc_printer.h"
namespace tvm {
@@ -45,8 +49,21 @@ class PythonDocPrinter : public DocPrinter {
void PrintTypedDoc(const DictDoc& doc) final;
void PrintTypedDoc(const TupleDoc& doc) final;
void PrintTypedDoc(const SliceDoc& doc) final;
+ void PrintTypedDoc(const StmtBlockDoc& doc) final;
+ void PrintTypedDoc(const AssignDoc& doc) final;
+ void PrintTypedDoc(const IfDoc& doc) final;
+ void PrintTypedDoc(const WhileDoc& doc) final;
+ void PrintTypedDoc(const ForDoc& doc) final;
+ void PrintTypedDoc(const ExprStmtDoc& doc) final;
+ void PrintTypedDoc(const AssertDoc& doc) final;
+ void PrintTypedDoc(const ReturnDoc& doc) final;
+ void PrintTypedDoc(const ScopeDoc& doc) final;
+ void PrintTypedDoc(const FunctionDoc& doc) final;
+ void PrintTypedDoc(const ClassDoc& doc) final;
private:
+  /*! \brief Emit a bare line break, with no indentation written after it. */
+  void NewLineWithoutIndent() { output_ << '\n'; }
+
template <typename DocType>
void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
bool is_first = true;
@@ -59,6 +76,65 @@ class PythonDocPrinter : public DocPrinter {
PrintDoc(doc);
}
}
+
+  /*!
+   * \brief Print a sequence of statements one indentation level deeper than the
+   *        current line, each starting on a fresh line.
+   *
+   * An empty sequence is printed as `pass` so that the enclosing compound
+   * statement (if/while/for/with/def/class) remains syntactically valid Python.
+   */
+  void PrintIndentedBlock(const Array<StmtDoc>& docs) {
+    IncreaseIndent();
+    if (docs.empty()) {
+      NewLine();
+      output_ << "pass";
+    } else {
+      for (const StmtDoc& stmt : docs) {
+        NewLine();
+        PrintDoc(stmt);
+      }
+    }
+    DecreaseIndent();
+  }
+
+  /*!
+   * \brief Print each decorator as `@<expr>` on its own line, in array order
+   *        (the first element becomes the outermost decorator).
+   */
+  void PrintDecorators(const Array<ExprDoc>& decorators) {
+    for (size_t i = 0; i < decorators.size(); ++i) {
+      output_ << "@";
+      PrintDoc(decorators[i]);
+      NewLine();
+    }
+  }
+
+  /*!
+   * \brief Append the statement's comment, if it has one, as a trailing
+   *        ` # <comment>` on the current line.
+   *
+   * A trailing comment cannot span lines, so a comment containing '\n' is
+   * rejected with a ValueError-prefixed CHECK failure.
+   */
+  void MaybePrintCommentInline(const StmtDoc& stmt) {
+    if (!stmt->comment.defined()) {
+      return;
+    }
+    const std::string& comment = stmt->comment.value();
+    CHECK(comment.find('\n') == std::string::npos)
+        << "ValueError: the comment string of " << stmt->GetTypeKey() << " cannot have newline.";
+    output_ << " # " << comment;
+  }
+
+  /*!
+   * \brief Print the statement's comment, if it has one, ahead of the
+   *        statement: each comment line becomes `# <line>` followed by
+   *        NewLine(), so multi-line comments are supported.
+   */
+  void MaybePrintCommentWithNewLine(const StmtDoc& stmt) {
+    if (!stmt->comment.defined()) {
+      return;
+    }
+    for (const std::string& line : support::Split(stmt->comment.value(), '\n')) {
+      output_ << "# " << line;
+      NewLine();
+    }
+  }
+
+  /*!
+   * \brief Print `comment` as a triple-quoted docstring, one indentation level
+   *        deeper than the current line.
+   *
+   * Empty comment lines are written as a bare "\n" (no indentation) so the
+   * docstring carries no whitespace-only lines.
+   * NOTE(review): a comment that itself contains `"""` is not escaped and would
+   * terminate the docstring early — confirm callers never pass one.
+   */
+  void PrintBlockComment(const String& comment) {
+    IncreaseIndent();
+    NewLine() << "\"\"\"";
+    for (const std::string& line : support::Split(comment, '\n')) {
+      if (!line.empty()) {
+        NewLine() << line;
+      } else {
+        output_ << "\n";
+      }
+    }
+    NewLine() << "\"\"\"";
+    DecreaseIndent();
+  }
};
void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
@@ -260,6 +336,140 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
}
}
+/*!
+ * \brief Print the statements of a StmtBlockDoc, separated by line breaks.
+ *
+ * No line break follows the last statement. Every other statement printer
+ * leaves the cursor at the end of its own text, and PrintIndentedBlock starts
+ * each statement with NewLine(). A trailing break here would therefore leave a
+ * dangling indented line, which shows up as an extra blank line when the block
+ * is nested inside a compound statement's body.
+ */
+void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
+  for (size_t i = 0; i < doc->stmts.size(); ++i) {
+    if (i > 0) {
+      NewLine();
+    }
+    PrintDoc(doc->stmts[i]);
+  }
+}
+
+/*!
+ * \brief Print `<lhs>[: <annotation>][ = <rhs>]` with an optional trailing comment.
+ *
+ * Handling of a tuple target depends on its length:
+ *   - Two or more elements print without the outer parentheses (`x, y = z`).
+ *   - A one-element tuple keeps its trailing comma (`x, = z`). Dropping the
+ *     comma would turn an unpacking assignment into a plain binding.
+ *   - Any other target, including the empty tuple, is delegated to the regular
+ *     expression printer, so no bare ` = z` is produced.
+ */
+void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
+  const auto* tuple_doc = doc->lhs.as<TupleDocNode>();
+  if (tuple_doc != nullptr && tuple_doc->elements.size() == 1) {
+    PrintDoc(tuple_doc->elements[0]);
+    output_ << ",";
+  } else if (tuple_doc != nullptr && tuple_doc->elements.size() > 1) {
+    PrintJoinedDocs(tuple_doc->elements, ", ");
+  } else {
+    PrintDoc(doc->lhs);
+  }
+
+  if (doc->annotation.defined()) {
+    output_ << ": ";
+    PrintDoc(doc->annotation.value());
+  }
+  if (doc->rhs.defined()) {
+    output_ << " = ";
+    PrintDoc(doc->rhs.value());
+  }
+  MaybePrintCommentInline(doc);
+}
+
+/*!
+ * \brief Print an `if <predicate>:` statement with its then-branch and, when
+ *        the else-branch has statements, an `else:` branch.
+ *
+ * A comment is printed on its own line(s) above the `if`. An empty then-branch
+ * becomes `pass`; an empty else-branch is omitted entirely.
+ */
+void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+  output_ << "if ";
+  PrintDoc(doc->predicate);
+  output_ << ":";
+  PrintIndentedBlock(doc->then_branch);
+
+  if (doc->else_branch.empty()) {
+    return;
+  }
+  NewLine();
+  output_ << "else:";
+  PrintIndentedBlock(doc->else_branch);
+}
+
+/*!
+ * \brief Print `while <predicate>:` followed by the indented loop body.
+ *        A comment is printed on its own line(s) above the loop header.
+ */
+void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+
+  // Header: `while <predicate>:`
+  output_ << "while ";
+  PrintDoc(doc->predicate);
+  output_ << ":";
+
+  // Body (printed as `pass` when empty).
+  PrintIndentedBlock(doc->body);
+}
+
+/*!
+ * \brief Print `for <lhs> in <rhs>:` followed by the indented loop body.
+ *        A comment is printed on its own line(s) above the loop header.
+ */
+void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+
+  // Header: `for <lhs> in <rhs>:`
+  output_ << "for ";
+  PrintDoc(doc->lhs);
+  output_ << " in ";
+  PrintDoc(doc->rhs);
+  output_ << ":";
+
+  // Body (printed as `pass` when empty).
+  PrintIndentedBlock(doc->body);
+}
+
+/*!
+ * \brief Print a `with` statement, `with <rhs>[ as <lhs>]:`, followed by the
+ *        indented body. A comment is printed on its own line(s) above it.
+ */
+void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
+  MaybePrintCommentWithNewLine(doc);
+  output_ << "with ";
+  PrintDoc(doc->rhs);
+  if (doc->lhs.defined()) {
+    // The binding target is optional: `with ctx():` is valid Python.
+    output_ << " as ";
+    PrintDoc(doc->lhs.value());
+  }
+  output_ << ":";
+  PrintIndentedBlock(doc->body);
+}
+
+/*! \brief Print an expression used as a statement, plus an optional trailing comment. */
+void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
+  const auto& expr = doc->expr;
+  PrintDoc(expr);
+  MaybePrintCommentInline(doc);
+}
+
+/*!
+ * \brief Print `assert <test>` and, when a message is present, `, <msg>`,
+ *        followed by an optional trailing comment.
+ */
+void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
+  output_ << "assert ";
+  PrintDoc(doc->test);
+
+  const bool has_message = doc->msg.defined();
+  if (has_message) {
+    output_ << ", ";
+    PrintDoc(doc->msg.value());
+  }
+
+  MaybePrintCommentInline(doc);
+}
+
+/*! \brief Print `return <value>` followed by an optional trailing comment. */
+void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
+  const auto& value = doc->value;
+  output_ << "return ";
+  PrintDoc(value);
+  MaybePrintCommentInline(doc);
+}
+
+/*!
+ * \brief Print a function definition in this order:
+ *          1. decorators;
+ *          2. `def <name>(<args>) -> <return_type>:`;
+ *          3. the comment, when present, as a docstring;
+ *          4. the indented body;
+ *          5. a final bare line break.
+ */
+void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
+  // Arguments are printed inside the signature, where an inline `# ...`
+  // comment would swallow the rest of the `def` line.
+  for (const AssignDoc& arg_doc : doc->args) {
+    ICHECK(!arg_doc->comment.defined()) << "Function arg cannot have comment attached to them.";
+  }
+
+  PrintDecorators(doc->decorators);
+
+  output_ << "def ";
+  PrintDoc(doc->name);
+  output_ << "(";
+  PrintJoinedDocs(doc->args, ", ");
+  output_ << ") -> ";
+  PrintDoc(doc->return_type);
+  output_ << ":";
+
+  if (doc->comment.defined()) {
+    PrintBlockComment(doc->comment.value());
+  }
+  PrintIndentedBlock(doc->body);
+  NewLineWithoutIndent();
+}
+
+/*!
+ * \brief Print a class definition in this order:
+ *          1. decorators;
+ *          2. `class <name>:`;
+ *          3. the comment, when present, as a docstring;
+ *          4. the indented body;
+ *          5. a final bare line break.
+ */
+void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
+  PrintDecorators(doc->decorators);
+
+  output_ << "class ";
+  PrintDoc(doc->name);
+  output_ << ":";
+
+  const bool has_docstring = doc->comment.defined();
+  if (has_docstring) {
+    PrintBlockComment(doc->comment.value());
+  }
+  PrintIndentedBlock(doc->body);
+  NewLineWithoutIndent();
+}
+
String DocToPythonScript(Doc doc, int indent_spaces) {
PythonDocPrinter printer(indent_spaces);
printer.Append(doc);
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
index b65eaa6b98..523f62d8b5 100644
--- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -15,18 +15,30 @@
# specific language governing permissions and limitations
# under the License.
import pytest
+import itertools
from tvm.script.printer.doc import (
+ AssertDoc,
+ AssignDoc,
CallDoc,
+ ClassDoc,
DictDoc,
+ ExprStmtDoc,
+ ForDoc,
+ FunctionDoc,
IdDoc,
+ IfDoc,
LambdaDoc,
ListDoc,
LiteralDoc,
OperationDoc,
OperationKind,
+ ReturnDoc,
+ ScopeDoc,
SliceDoc,
+ StmtBlockDoc,
TupleDoc,
+ WhileDoc,
)
from tvm.script.printer.doc_printer import to_python_script
@@ -36,10 +48,19 @@ def format_script(s: str) -> str:
Remove leading and trailing blank lines, and make the minimum idention 0
"""
s = s.strip("\n")
+
non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()]
+ if not non_empty_lines:
+ # no actual content
+ return "\n"
+
line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines]
spaces_to_remove = min(line_indents)
- return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n"
+
+ cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines())
+ if not cleaned_lines.endswith("\n"):
+ cleaned_lines += "\n"
+ return cleaned_lines
@pytest.mark.parametrize(
@@ -59,6 +80,7 @@ def format_script(s: str) -> str:
(LiteralDoc(3.25), "3.25"),
(LiteralDoc(-0.5), "-0.5"),
],
+ ids=itertools.count(),
)
def test_print_literal_doc(doc, expected):
assert to_python_script(doc) == format_script(expected)
@@ -73,6 +95,7 @@ def test_print_literal_doc(doc, expected):
"test_case",
"test123",
],
+ ids=itertools.count(),
)
def test_print_id_doc(name):
doc = IdDoc(name)
@@ -87,6 +110,7 @@ def test_print_id_doc(name):
"Attr",
"attr_1",
],
+ ids=itertools.count(),
)
def test_print_attr_doc(attr):
doc = IdDoc("x").attr(attr)
@@ -125,6 +149,7 @@ def test_print_attr_doc(attr):
"[x, y, z]",
),
],
+ ids=itertools.count(),
)
def test_print_index_doc(indices, expected):
doc = IdDoc("x")[indices]
@@ -271,6 +296,7 @@ def test_operation_doc_test_exhaustive():
"(x, y, key0=u, key1=v)",
),
],
+ ids=itertools.count(),
)
def test_print_call_doc(args, kwargs, expected):
doc = CallDoc(IdDoc("f"), *args, **kwargs)
@@ -297,6 +323,7 @@ def test_print_call_doc(args, kwargs, expected):
"lambda x, y, z: 0",
),
],
+ ids=itertools.count(),
)
def test_print_lambda_doc(args, expected):
doc = LambdaDoc(args, body=LiteralDoc(0))
@@ -323,6 +350,7 @@ def test_print_lambda_doc(args, expected):
"[x, y, z]",
),
],
+ ids=itertools.count(),
)
def test_print_list_doc(elements, expected):
doc = ListDoc(elements)
@@ -349,6 +377,7 @@ def test_print_list_doc(elements, expected):
"(x, y, z)",
),
],
+ ids=itertools.count(),
)
def test_print_tuple_doc(elements, expected):
doc = TupleDoc(elements)
@@ -379,6 +408,7 @@ def test_print_tuple_doc(elements, expected):
'{"key_x": x, "key_y": y, "key_z": z}',
),
],
+ ids=itertools.count(),
)
def test_print_dict_doc(content, expected):
doc = DictDoc(content)
@@ -421,7 +451,590 @@ def test_print_dict_doc(content, expected):
"1:2:3",
),
],
+ ids=itertools.count(),
)
def test_print_slice_doc(slice_doc, expected):
doc = IdDoc("x")[slice_doc]
assert to_python_script(doc) == format_script(f"x[{expected}]")
+
+
+# StmtBlockDoc output is compared after strip() on both sides, so line breaks
+# before the first or after the last statement are not checked by this test.
+@pytest.mark.parametrize(
+ "stmts, expected",
+ [
+ (
+ [],
+ "",
+ ),
+ (
+ [ExprStmtDoc(IdDoc("x"))],
+ "x",
+ ),
+ (
+ [ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
+ """
+ x
+ y
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_stmt_block_doc(stmts, expected):
+ doc = StmtBlockDoc(stmts)
+ assert to_python_script(doc).strip() == format_script(expected).strip()
+
+
+# AssignDoc(lhs, rhs, annotation): either rhs or annotation may be None. A tuple
+# target is printed without its outer parentheses; nested tuples keep theirs.
+@pytest.mark.parametrize(
+ "doc, expected",
+ [
+ (
+ AssignDoc(IdDoc("x"), IdDoc("y"), None),
+ "x = y",
+ ),
+ (
+ AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+ "x: int = y",
+ ),
+ (
+ AssignDoc(IdDoc("x"), None, IdDoc("int")),
+ "x: int",
+ ),
+ (
+ AssignDoc(TupleDoc([IdDoc("x"), IdDoc("y")]), IdDoc("z"), None),
+ "x, y = z",
+ ),
+ (
+ AssignDoc(TupleDoc([IdDoc("x"), TupleDoc([IdDoc("y"), IdDoc("z")])]), IdDoc("z"), None),
+ "x, (y, z) = z",
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_assign_doc(doc, expected):
+ assert to_python_script(doc) == format_script(expected)
+
+
+# An empty then-branch is printed as `pass`; an empty else-branch is omitted.
+@pytest.mark.parametrize(
+ "then_branch, else_branch, expected",
+ [
+ (
+ [ExprStmtDoc(IdDoc("x"))],
+ [],
+ """
+ if pred:
+ x
+ """,
+ ),
+ (
+ [],
+ [ExprStmtDoc(IdDoc("y"))],
+ """
+ if pred:
+ pass
+ else:
+ y
+ """,
+ ),
+ (
+ [ExprStmtDoc(IdDoc("x"))],
+ [ExprStmtDoc(IdDoc("y"))],
+ """
+ if pred:
+ x
+ else:
+ y
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_if_doc(then_branch, else_branch, expected):
+ doc = IfDoc(IdDoc("pred"), then_branch, else_branch)
+ assert to_python_script(doc) == format_script(expected)
+
+
+# An empty loop body is printed as `pass`.
+@pytest.mark.parametrize(
+ "body, expected",
+ [
+ (
+ [ExprStmtDoc(IdDoc("x"))],
+ """
+ while pred:
+ x
+ """,
+ ),
+ (
+ [],
+ """
+ while pred:
+ pass
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_while_doc(body, expected):
+ doc = WhileDoc(IdDoc("pred"), body)
+ assert to_python_script(doc) == format_script(expected)
+
+
+# ForDoc(lhs, rhs, body) prints `for lhs in rhs:`; an empty body becomes `pass`.
+@pytest.mark.parametrize(
+ "body, expected",
+ [
+ (
+ [ExprStmtDoc(IdDoc("x"))],
+ """
+ for x in y:
+ x
+ """,
+ ),
+ (
+ [],
+ """
+ for x in y:
+ pass
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_for_doc(body, expected):
+ doc = ForDoc(IdDoc("x"), IdDoc("y"), body)
+ assert to_python_script(doc) == format_script(expected)
+
+
+# ScopeDoc(lhs, rhs, body) prints `with rhs as lhs:`; a None lhs drops the
+# `as` clause and an empty body becomes `pass`.
+@pytest.mark.parametrize(
+ "lhs, body, expected",
+ [
+ (
+ IdDoc("c"),
+ [ExprStmtDoc(IdDoc("x"))],
+ """
+ with context() as c:
+ x
+ """,
+ ),
+ (
+ IdDoc("c"),
+ [],
+ """
+ with context() as c:
+ pass
+ """,
+ ),
+ (
+ None,
+ [],
+ """
+ with context():
+ pass
+ """,
+ ),
+ (
+ None,
+ [ExprStmtDoc(IdDoc("x"))],
+ """
+ with context():
+ x
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_scope_doc(lhs, body, expected):
+ doc = ScopeDoc(lhs, CallDoc(IdDoc("context")), body)
+ assert to_python_script(doc) == format_script(expected)
+
+
+def test_print_expr_stmt_doc():
+ doc = ExprStmtDoc(CallDoc(IdDoc("f"), IdDoc("x")))
+ assert to_python_script(doc) == format_script("f(x)")
+
+
+# The optional assertion message is printed after a comma when present.
+@pytest.mark.parametrize(
+ "msg, expected",
+ [
+ (
+ None,
+ """
+ assert True
+ """,
+ ),
+ (
+ LiteralDoc("test message"),
+ """
+ assert True, "test message"
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_assert_doc(msg, expected):
+ test = LiteralDoc(True)
+
+ doc = AssertDoc(test, msg)
+
+ assert to_python_script(doc) == format_script(expected)
+
+
+# ReturnDoc always prints its value, including an explicit `None`.
+@pytest.mark.parametrize(
+ "value, expected",
+ [
+ (
+ LiteralDoc(None),
+ """
+ return None
+ """,
+ ),
+ (
+ IdDoc("x"),
+ """
+ return x
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_return_doc(value, expected):
+ doc = ReturnDoc(value)
+ assert to_python_script(doc) == format_script(expected)
+
+
+# Function args are AssignDocs: `annotation` gives `x: int` and `rhs` gives the
+# default value. Decorators are printed in list order, the first one outermost.
+@pytest.mark.parametrize(
+ "args, decorators, body, expected",
+ [
+ (
+ [],
+ [],
+ [],
+ """
+ def func() -> None:
+ pass
+ """,
+ ),
+ (
+ [AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int"))],
+ [],
+ [],
+ """
+ def func(x: int) -> None:
+ pass
+ """,
+ ),
+ (
+ [AssignDoc(IdDoc("x"), rhs=LiteralDoc(1), annotation=IdDoc("int"))],
+ [],
+ [],
+ """
+ def func(x: int = 1) -> None:
+ pass
+ """,
+ ),
+ (
+ [],
+ [IdDoc("wrap")],
+ [],
+ """
+ @wrap
+ def func() -> None:
+ pass
+ """,
+ ),
+ (
+ [],
+ [IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+ [],
+ """
+ @wrap_outter
+ @wrap_inner
+ def func() -> None:
+ pass
+ """,
+ ),
+ (
+ [
+ AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")),
+ AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")),
+ ],
+ [IdDoc("wrap")],
+ [],
+ """
+ @wrap
+ def func(x: int, y: int = 1) -> None:
+ pass
+ """,
+ ),
+ (
+ [
+ AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")),
+ AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")),
+ ],
+ [IdDoc("wrap")],
+ [
+ AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, [IdDoc("x"), LiteralDoc(1)])),
+ AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, [IdDoc("y"), LiteralDoc(1)])),
+ ],
+ """
+ @wrap
+ def func(x: int, y: int = 1) -> None:
+ y = x + 1
+ y = y - 1
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_function_doc(args, decorators, body, expected):
+ doc = FunctionDoc(IdDoc("func"), args, decorators, LiteralDoc(None), body)
+ assert to_python_script(doc) == format_script(expected)
+
+
+def get_func_doc_for_class(name):
+ args = [
+ AssignDoc(IdDoc("x"), rhs=None, annotation=IdDoc("int")),
+ AssignDoc(IdDoc("y"), rhs=LiteralDoc(1), annotation=IdDoc("int")),
+ ]
+ body = [
+ AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Add, [IdDoc("x"), LiteralDoc(1)])),
+ AssignDoc(IdDoc("y"), OperationDoc(OperationKind.Sub, [IdDoc("y"), LiteralDoc(1)])),
+ ]
+ return FunctionDoc(
+ name=IdDoc(name),
+ args=args,
+ decorators=[IdDoc("wrap")],
+ return_type=LiteralDoc(None),
+ body=body,
+ )
+
+
+# Every FunctionDoc ends with a bare line break, so a function in a class body is
+# followed by a blank line; this is why the expected strings end with one.
+@pytest.mark.parametrize(
+ "decorators, body, expected",
+ [
+ (
+ [],
+ [],
+ """
+ class TestClass:
+ pass
+ """,
+ ),
+ (
+ [IdDoc("wrap")],
+ [],
+ """
+ @wrap
+ class TestClass:
+ pass
+ """,
+ ),
+ (
+ [IdDoc("wrap_outter"), IdDoc("wrap_inner")],
+ [],
+ """
+ @wrap_outter
+ @wrap_inner
+ class TestClass:
+ pass
+ """,
+ ),
+ (
+ [IdDoc("wrap")],
+ [get_func_doc_for_class("f1")],
+ """
+ @wrap
+ class TestClass:
+ @wrap
+ def f1(x: int, y: int = 1) -> None:
+ y = x + 1
+ y = y - 1
+
+ """,
+ ),
+ (
+ [IdDoc("wrap")],
+ [get_func_doc_for_class("f1"), get_func_doc_for_class("f2")],
+ """
+ @wrap
+ class TestClass:
+ @wrap
+ def f1(x: int, y: int = 1) -> None:
+ y = x + 1
+ y = y - 1
+
+ @wrap
+ def f2(x: int, y: int = 1) -> None:
+ y = x + 1
+ y = y - 1
+
+ """,
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_class_doc(decorators, body, expected):
+ doc = ClassDoc(IdDoc("TestClass"), decorators, body)
+ assert to_python_script(doc) == format_script(expected)
+
+
+# Comment placement depends on the statement kind:
+#   - simple statements (assign, expr, assert, return) take an inline `# ...`;
+#   - compound statements (if, while, for, with) take `# ...` lines above;
+#   - functions and classes turn the comment into a docstring, with empty
+#     comment lines kept blank.
+@pytest.mark.parametrize(
+ "doc, comment, expected",
+ [
+ (
+ AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+ "comment",
+ """
+ x: int = y # comment
+ """,
+ ),
+ (
+ IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], [ExprStmtDoc(IdDoc("z"))]),
+ "comment",
+ """
+ # comment
+ if x:
+ y
+ else:
+ z
+ """,
+ ),
+ (
+ IfDoc(IdDoc("x"), [ExprStmtDoc(IdDoc("y"))], [ExprStmtDoc(IdDoc("z"))]),
+ "comment line 1\ncomment line 2",
+ """
+ # comment line 1
+ # comment line 2
+ if x:
+ y
+ else:
+ z
+ """,
+ ),
+ (
+ WhileDoc(
+ LiteralDoc(True),
+ [
+ AssignDoc(IdDoc("x"), IdDoc("y")),
+ ],
+ ),
+ "comment",
+ """
+ # comment
+ while True:
+ x = y
+ """,
+ ),
+ (
+ ForDoc(IdDoc("x"), IdDoc("y"), []),
+ "comment",
+ """
+ # comment
+ for x in y:
+ pass
+ """,
+ ),
+ (
+ ScopeDoc(IdDoc("x"), IdDoc("y"), []),
+ "comment",
+ """
+ # comment
+ with y as x:
+ pass
+ """,
+ ),
+ (
+ ExprStmtDoc(IdDoc("x")),
+ "comment",
+ """
+ x # comment
+ """,
+ ),
+ (
+ AssertDoc(LiteralDoc(True)),
+ "comment",
+ """
+ assert True # comment
+ """,
+ ),
+ (
+ ReturnDoc(LiteralDoc(1)),
+ "comment",
+ """
+ return 1 # comment
+ """,
+ ),
+ (
+ get_func_doc_for_class("f"),
+ "comment",
+ '''
+ @wrap
+ def f(x: int, y: int = 1) -> None:
+ """
+ comment
+ """
+ y = x + 1
+ y = y - 1
+ ''',
+ ),
+ (
+ get_func_doc_for_class("f"),
+ "comment line 1\n\ncomment line 3",
+ '''
+ @wrap
+ def f(x: int, y: int = 1) -> None:
+ """
+ comment line 1
+
+ comment line 3
+ """
+ y = x + 1
+ y = y - 1
+ ''',
+ ),
+ (
+ ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]),
+ "comment",
+ '''
+ @wrap
+ class TestClass:
+ """
+ comment
+ """
+ pass
+ ''',
+ ),
+ (
+ ClassDoc(IdDoc("TestClass"), decorators=[IdDoc("wrap")], body=[]),
+ "comment line 1\n\ncomment line 3",
+ '''
+ @wrap
+ class TestClass:
+ """
+ comment line 1
+
+ comment line 3
+ """
+ pass
+ ''',
+ ),
+ ],
+ ids=itertools.count(),
+)
+def test_print_doc_comment(doc, comment, expected):
+ doc.comment = comment
+ assert to_python_script(doc) == format_script(expected)
+
+
+@pytest.mark.parametrize(
+ "doc",
+ [
+ AssignDoc(IdDoc("x"), IdDoc("y"), IdDoc("int")),
+ ExprStmtDoc(IdDoc("x")),
+ AssertDoc(IdDoc("x")),
+ ReturnDoc(IdDoc("x")),
+ ],
+)
+def test_print_invalid_multiline_doc_comment(doc):
+ doc.comment = "1\n2"
+ with pytest.raises(ValueError) as e:
+ to_python_script(doc)
+ assert "cannot have newline" in str(e.value)