You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/07/07 05:11:15 UTC
[tvm] branch main updated: [TVMScript] Doc Base Class & DocPrinter Scaffolding (#11971)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9f4bf38b57 [TVMScript] Doc Base Class & DocPrinter Scaffolding (#11971)
9f4bf38b57 is described below
commit 9f4bf38b5766609317e9a52bc60d66679ceddf02
Author: Lite Ye <ye...@gmail.com>
AuthorDate: Thu Jul 7 01:11:10 2022 -0400
[TVMScript] Doc Base Class & DocPrinter Scaffolding (#11971)
This PR adds:
- Doc base class
- DocPrinter base class
- PythonDocPrinter
- LiteralDoc and its support in DocPrinter
Tracking issue: #11912
---
CMakeLists.txt | 1 +
include/tvm/script/printer/doc.h | 165 +++++++++++++++++++++
include/tvm/script/printer/doc_printer.h | 43 ++++++
python/tvm/script/printer/__init__.py | 26 ++++
python/tvm/script/printer/_ffi_api.py | 20 +++
python/tvm/script/printer/doc.py | 49 ++++++
python/tvm/script/printer/doc_printer.py | 39 +++++
src/script/printer/base_doc_printer.cc | 49 ++++++
src/script/printer/base_doc_printer.h | 131 ++++++++++++++++
src/script/printer/doc.cc | 43 ++++++
src/script/printer/python_doc_printer.cc | 70 +++++++++
.../python/unittest/test_tvmscript_printer_doc.py | 33 +++++
.../test_tvmscript_printer_python_doc_printer.py | 53 +++++++
tests/scripts/task_mypy.sh | 3 +
14 files changed, 725 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 306a8be308..46de8f5d07 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -281,6 +281,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/parser/*.cc
src/printer/*.cc
src/support/*.cc
+ src/script/*.cc
)
tvm_file_glob(GLOB CODEGEN_SRCS
diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h
new file mode 100644
index 0000000000..67c27bd45a
--- /dev/null
+++ b/include/tvm/script/printer/doc.h
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_DOC_H_
+#define TVM_SCRIPT_PRINTER_DOC_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/data_type.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief The base class of all Doc nodes.
+ *
+ * Doc is an intermediate representation between IR from TVM
+ * and the TVMScript code.
+ * During printing, the IR graph is first translated into a Doc tree,
+ * then the Doc tree is translated to the target language in
+ * text format.
+ *
+ * \sa Doc
+ */
+class DocNode : public Object {
+ public:
+  // The base node has no fields of its own, so there is nothing to visit.
+  void VisitAttrs(AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "script.printer.Doc";
+  // Declared as a base (non-final) object type so that ExprDocNode and the
+  // concrete doc nodes can derive from it.
+  TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);
+
+ public:
+  virtual ~DocNode() = default;
+};
+
+/*!
+ * \brief Reference type of DocNode.
+ *
+ * The reference is declared non-nullable, and its default constructor is
+ * protected so user code cannot default-construct an empty Doc.
+ *
+ * \sa DocNode
+ */
+class Doc : public ObjectRef {
+ protected:
+  Doc() = default;
+
+ public:
+  virtual ~Doc() = default;
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode);
+};
+
+/*!
+ * \brief The base class of expression doc nodes.
+ *
+ * \sa ExprDoc
+ */
+class ExprDocNode : public DocNode {
+ public:
+  // Adds no fields; forwards to the base so derived nodes can chain to it.
+  void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); }
+
+  static constexpr const char* _type_key = "script.printer.ExprDoc";
+  // Non-final: concrete expression docs (e.g. LiteralDocNode) derive from it.
+  TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode);
+};
+
+/*!
+ * \brief Reference type of ExprDocNode.
+ *
+ * Non-nullable reference; the default constructor is protected so only
+ * concrete subclasses produce instances.
+ *
+ * \sa ExprDocNode
+ */
+class ExprDoc : public Doc {
+ protected:
+  ExprDoc() = default;
+
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode);
+};
+
+/*!
+ * \brief Doc that represents a literal value.
+ *
+ * \sa LiteralDoc
+ */
+class LiteralDocNode : public ExprDocNode {
+ public:
+  /*!
+   * \brief The internal representation of the literal value.
+   *
+   * Possible actual types (as produced by the LiteralDoc factories):
+   * - IntImm: integers (dtype int64) and booleans (dtype bool)
+   * - FloatImm: floating point numbers (dtype float64)
+   * - String: string literals
+   * - undefined (null) ObjectRef: represents None
+   */
+  ObjectRef value;
+
+  // Exposes `value` to reflection (e.g. `doc.value` from Python).
+  void VisitAttrs(AttrVisitor* v) {
+    ExprDocNode::VisitAttrs(v);
+    v->Visit("value", &value);
+  }
+
+  static constexpr const char* _type_key = "script.printer.LiteralDoc";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode);
+};
+
+/*!
+ * \brief Reference type of LiteralDocNode.
+ *
+ * Instances are created through the static factories below; the
+ * constructor taking a raw ObjectRef is protected.
+ *
+ * \sa LiteralDocNode
+ */
+class LiteralDoc : public ExprDoc {
+ protected:
+  explicit LiteralDoc(ObjectRef value);
+
+ public:
+  /*!
+   * \brief Create a LiteralDoc to represent None/null/empty value.
+   */
+  static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
+
+  /*!
+   * \brief Create a LiteralDoc to represent integer.
+   * \param v The integer value. Taken as int64_t so that the whole range of
+   *          the Int(64) IntImm it is stored in can be represented (an `int`
+   *          parameter would narrow 64-bit values, and the FFI conversion to
+   *          `int` rejects values outside the 32-bit range).
+   */
+  static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
+
+  /*!
+   * \brief Create a LiteralDoc to represent boolean.
+   * \param v The boolean value, stored as an IntImm of bool dtype.
+   */
+  static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
+
+  /*!
+   * \brief Create a LiteralDoc to represent float.
+   * \param v The float value, stored as a float64 FloatImm.
+   */
+  static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
+
+  /*!
+   * \brief Create a LiteralDoc to represent string.
+   * \param v The string value.
+   */
+  static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
+};
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_DOC_H_
diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h
new file mode 100644
index 0000000000..6bf502fab9
--- /dev/null
+++ b/include/tvm/script/printer/doc_printer.h
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
+#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
+
+#include <tvm/script/printer/doc.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief Convert Doc into Python script.
+ *
+ * Printer options are taken as plain function arguments so that the
+ * function can be registered with the FFI directly.
+ *
+ * \param doc the doc to be converted
+ * \param indent_spaces the number of spaces used for one level of indentation
+ * \return the Python script text; non-empty output ends with a newline
+ */
+String DocToPythonScript(Doc doc, int indent_spaces = 4);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_
diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py
new file mode 100644
index 0000000000..84ab7b0ba8
--- /dev/null
+++ b/python/tvm/script/printer/__init__.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TVMScript Unified Printer
+
+This package provides a set of APIs to print supported TVM IR into TVMScript
+in a roundtrippable way.
+
+https://github.com/apache/tvm-rfcs/blob/main/rfcs/0074-tvmscript-unified-printer.md
+"""
+
+from . import _ffi_api
diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py
new file mode 100644
index 0000000000..baa639fe2d
--- /dev/null
+++ b/python/tvm/script/printer/_ffi_api.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.script.printer"""
+import tvm._ffi
+
+# Populate this module with the global functions registered in C++ under the
+# "script.printer" prefix (e.g. LiteralDocNone, DocToPythonScript).
+tvm._ffi._init_api("script.printer", __name__)
diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py
new file mode 100644
index 0000000000..f6179d7351
--- /dev/null
+++ b/python/tvm/script/printer/doc.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Doc types for TVMScript Unified Printer"""
+
+import tvm._ffi
+from tvm.runtime import Object
+
+from . import _ffi_api
+
+
+class Doc(Object):
+    """Base class of all Docs.
+
+    Python-side base for the printer's Doc nodes (C++ type key
+    ``script.printer.Doc``).
+    """
+
+
+class ExprDoc(Doc):
+    """Base class of all expression Docs.
+
+    Derives from Doc, mirroring the C++ hierarchy where ExprDocNode extends
+    DocNode, so expression docs can be passed wherever a Doc is expected.
+    """
+
+
+@tvm._ffi.register_object("script.printer.LiteralDoc")
+class LiteralDoc(ExprDoc):
+    """Doc that represents a literal value.
+
+    Accepts ``None``, ``str``, ``float``, ``bool`` or ``int``; any other
+    type raises ``TypeError``.
+    """
+
+    def __init__(self, value):
+        # Select the FFI constructor first, then build the handle once.
+        # ``bool`` is tested before ``int`` because ``bool`` subclasses ``int``.
+        if value is None:
+            ctor, args = _ffi_api.LiteralDocNone, ()  # type: ignore
+        elif isinstance(value, str):
+            ctor, args = _ffi_api.LiteralDocStr, (value,)  # type: ignore
+        elif isinstance(value, float):
+            ctor, args = _ffi_api.LiteralDocFloat, (value,)  # type: ignore
+        elif isinstance(value, bool):
+            ctor, args = _ffi_api.LiteralDocBoolean, (value,)  # type: ignore
+        elif isinstance(value, int):
+            ctor, args = _ffi_api.LiteralDocInt, (value,)  # type: ignore
+        else:
+            raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")
+        self.__init_handle_by_constructor__(ctor, *args)  # type: ignore
diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py
new file mode 100644
index 0000000000..404632b44c
--- /dev/null
+++ b/python/tvm/script/printer/doc_printer.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Functions to print doc into text format"""
+
+from . import _ffi_api
+from .doc import Doc
+
+
+def to_python_script(doc: Doc, indent_spaces: int = 4) -> str:
+    """
+    Render a Doc tree as Python source text.
+
+    Parameters
+    ----------
+    doc : Doc
+        Root of the Doc tree to render
+    indent_spaces : int
+        How many spaces make up one indentation level in the output
+
+    Returns
+    -------
+    script : str
+        The rendered Python text; non-empty output ends with a newline
+    """
+    script = _ffi_api.DocToPythonScript(doc, indent_spaces)  # type: ignore
+    return script
diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc
new file mode 100644
index 0000000000..f6874ba1a2
--- /dev/null
+++ b/src/script/printer/base_doc_printer.cc
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./base_doc_printer.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+DocPrinter::DocPrinter(int indent_spaces)
+    : indent_spaces_(indent_spaces) {
+  // indent_ keeps its in-class default of 0, so printing starts unindented.
+}
+
+void DocPrinter::Append(const Doc& doc) {
+  // Each appended Doc is printed immediately through the type dispatch.
+  PrintDoc(doc);
+}
+
+String DocPrinter::GetString() const {
+  std::string result = output_.str();
+  // Non-empty output is always terminated by a newline.
+  const bool missing_trailing_newline = !result.empty() && result.back() != '\n';
+  if (missing_trailing_newline) {
+    result += '\n';
+  }
+  return String(result);
+}
+
+void DocPrinter::PrintDoc(const Doc& doc) {
+  // Dispatch on the concrete node type; LiteralDoc is the only kind handled so far.
+  const auto* literal_node = doc.as<LiteralDocNode>();
+  if (literal_node == nullptr) {
+    LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
+    throw;
+  }
+  PrintTypedDoc(GetRef<LiteralDoc>(literal_node));
+}
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h
new file mode 100644
index 0000000000..128fcef2ea
--- /dev/null
+++ b/src/script/printer/base_doc_printer.h
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
+#define TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
+
+#include <tvm/runtime/logging.h>
+#include <tvm/script/printer/doc.h>
+#include <tvm/script/printer/doc_printer.h>
+
+#include <memory>
+#include <ostream>
+#include <sstream>
+#include <string>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief DocPrinter is responsible for printing Doc tree into text format
+ * \details This is the base class for translating Doc into string.
+ * Each target language needs to have its subclass of DocPrinter
+ * to define the actual logic of printing Doc.
+ *
+ * \sa Doc
+ */
+class DocPrinter {
+ public:
+  /*!
+   * \brief The constructor of DocPrinter
+   *
+   * \param indent_spaces the number of spaces used for one level of indentation
+   */
+  explicit DocPrinter(int indent_spaces = 4);
+  virtual ~DocPrinter() = default;
+
+  /*!
+   * \brief Append a doc into the final content
+   *
+   * \param doc the Doc to be printed
+   *
+   * \sa GetString
+   */
+  void Append(const Doc& doc);
+
+  /*!
+   * \brief Get the printed string of all Doc appended
+   *
+   * The content of each Doc in the returned string will
+   * appear in the same order as they are appended.
+   *
+   * \sa Append
+   */
+  String GetString() const;
+
+ protected:
+  /*!
+   * \brief Print a Doc into the output stream
+   *
+   * It will dispatch to the PrintTypedDoc method based on
+   * the actual type of Doc.
+   *
+   * \param doc the Doc to be printed
+   * \sa PrintTypedDoc
+   */
+  void PrintDoc(const Doc& doc);
+
+  /*!
+   * \brief Virtual method to print a LiteralDoc
+   */
+  virtual void PrintTypedDoc(const LiteralDoc& doc) = 0;
+
+  /*!
+   * \brief Increase the indent level of any content to be
+   * printed after this call
+   */
+  void IncreaseIndent() { indent_ += indent_spaces_; }
+
+  /*!
+   * \brief Decrease the indent level of any content to be
+   * printed after this call
+   *
+   * Each call must be matched by an earlier IncreaseIndent; an unmatched
+   * call would make the indent negative, which NewLine cannot render.
+   */
+  void DecreaseIndent() {
+    ICHECK_GE(indent_, indent_spaces_)
+        << "DecreaseIndent called without a matching IncreaseIndent";
+    indent_ -= indent_spaces_;
+  }
+
+  /*!
+   * \brief Add a new line into the output stream
+   *
+   * Writes a newline followed by the current indentation and returns
+   * output_ so the caller can continue writing on the new line.
+   *
+   * \sa output_
+   */
+  std::ostream& NewLine() {
+    output_ << "\n";
+    output_ << std::string(indent_, ' ');
+    return output_;
+  }
+
+  /*!
+   * \brief The output stream of printer
+   *
+   * All printed content will be stored in this stream and returned
+   * when GetString is called.
+   *
+   * \sa GetString
+   */
+  std::ostringstream output_;
+
+ private:
+  /*! \brief the number of spaces for one level of indentation */
+  int indent_spaces_ = 4;
+
+  /*! \brief the current indentation, measured in spaces */
+  int indent_ = 0;
+};
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
+
+#endif // TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_
diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc
new file mode 100644
index 0000000000..e54adbd36b
--- /dev/null
+++ b/src/script/printer/doc.cc
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/script/printer/doc.h>
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+LiteralDoc::LiteralDoc(ObjectRef value) {
+  // Allocate the node, move the by-value argument into it, then take ownership.
+  auto node = make_object<LiteralDocNode>();
+  node->value = std::move(value);
+  data_ = std::move(node);
+}
+
+// Register the doc node types with the TVM object system.
+TVM_REGISTER_NODE_TYPE(DocNode);
+TVM_REGISTER_NODE_TYPE(ExprDocNode);
+TVM_REGISTER_NODE_TYPE(LiteralDocNode);
+// FFI constructors used by LiteralDoc.__init__ in python/tvm/script/printer/doc.py.
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
+TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc
new file mode 100644
index 0000000000..cd816e4f70
--- /dev/null
+++ b/src/script/printer/python_doc_printer.cc
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/runtime/registry.h>
+
+#include "../../support/str_escape.h"
+#include "./base_doc_printer.h"
+
+namespace tvm {
+namespace script {
+namespace printer {
+
+/*!
+ * \brief DocPrinter that renders Doc trees as Python source text.
+ */
+class PythonDocPrinter : public DocPrinter {
+ public:
+  /*! \param indent_spaces the number of spaces used for one level of indentation */
+  explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {}
+
+ protected:
+  // NOTE(review): PrintDoc is not redeclared in this class, so this
+  // using-declaration appears redundant; presumably kept for future overloads.
+  using DocPrinter::PrintDoc;
+
+  /*! \brief Print a LiteralDoc as a Python literal (None/True/False/number/string). */
+  void PrintTypedDoc(const LiteralDoc& doc) final;
+};
+
+void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
+  const ObjectRef& value = doc->value;
+  if (!value.defined()) {
+    // An undefined value encodes Python's None.
+    output_ << "None";
+  } else if (const auto* int_imm = value.as<IntImmNode>()) {
+    // Booleans are stored as IntImm with a bool dtype.
+    if (int_imm->dtype.is_bool()) {
+      output_ << (int_imm->value ? "True" : "False");
+    } else {
+      output_ << int_imm->value;
+    }
+  } else if (const auto* float_imm = value.as<FloatImmNode>()) {
+    // TODO(yelite): Make float number printing roundtrippable
+    // (non-finite values are not handled specially yet).
+    // Format into a scratch stream so the precision setting does not leak
+    // into output_.
+    std::ostringstream float_os;
+    float_os.precision(17);
+    float_os << float_imm->value;
+    std::string repr = float_os.str();
+    // An integral value such as 1.0 formats as "1"; append ".0" so the
+    // printed literal still parses as a Python float rather than an int.
+    if (repr.find_first_not_of("-0123456789") == std::string::npos) {
+      repr += ".0";
+    }
+    output_ << repr;
+  } else if (const auto* string_obj = value.as<StringObj>()) {
+    output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\"";
+  } else {
+    LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey();
+  }
+}
+
+String DocToPythonScript(Doc doc, int indent_spaces) {
+  // A fresh local printer per call, so no printing state is shared between calls.
+  PythonDocPrinter python_printer{indent_spaces};
+  python_printer.Append(doc);
+  return python_printer.GetString();
+}
+
+// Called from Python as _ffi_api.DocToPythonScript (see doc_printer.py).
+TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript);
+
+} // namespace printer
+} // namespace script
+} // namespace tvm
diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py
new file mode 100644
index 0000000000..6330d33bf2
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_doc.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tvm.tir import IntImm
+from tvm.script.printer.doc import LiteralDoc
+
+
+@pytest.mark.parametrize(
+    "value",
+    [None, "test", 0, 1, -2, 0.0, 1.5, -1.3, True, False],
+)
+def test_literal_doc_construction(value):
+    """LiteralDoc stores the Python value it was constructed from."""
+    doc = LiteralDoc(value)
+    if not isinstance(value, float):
+        assert doc.value == value
+        return
+    # The stored FloatImm is not directly comparable to a Python float,
+    # so convert it before comparing approximately.
+    assert float(doc.value) == pytest.approx(value)
diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
new file mode 100644
index 0000000000..55b5e88c88
--- /dev/null
+++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+from tvm.script.printer.doc_printer import to_python_script
+from tvm.script.printer.doc import LiteralDoc
+
+
+def format_script(s: str) -> str:
+    """
+    Remove leading and trailing blank lines, and make the minimum indentation 0.
+
+    Blank (empty or whitespace-only) lines are ignored when computing the
+    common indentation; if every line is blank, no indentation is removed.
+    """
+    s = s.strip("\n")
+    lines = s.splitlines()
+    non_empty_lines = [line for line in lines if line and not line.isspace()]
+    line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines]
+    # default=0 avoids a ValueError from min() when there are no non-blank lines.
+    spaces_to_remove = min(line_indents, default=0)
+    return "\n".join(line[spaces_to_remove:] for line in lines)
+
+
+@pytest.mark.parametrize(
+    "doc,expected",
+    [
+        (LiteralDoc(None), "None"),
+        (LiteralDoc(True), "True"),
+        (LiteralDoc(False), "False"),
+        (LiteralDoc("test"), '"test"'),
+        (LiteralDoc(""), '""'),
+        (LiteralDoc('""'), r'"\"\""'),
+        (LiteralDoc("\n\t\\test\r"), r'"\n\t\\test\r"'),
+        # TODO: fix the roundtrippable problem caused by utf8
+        pytest.param(LiteralDoc("\x88"), r'"\x88"', marks=pytest.mark.xfail),
+        (LiteralDoc(0), "0"),
+        (LiteralDoc(-1), "-1"),
+        (LiteralDoc(3.25), "3.25"),
+        (LiteralDoc(-0.5), "-0.5"),
+    ],
+)
+def test_print_literal_doc(doc, expected):
+    """Each LiteralDoc prints as the expected Python literal text."""
+    printed = to_python_script(doc)
+    # The printer terminates non-empty output with a newline; compare without it.
+    assert printed.rstrip("\n") == format_script(expected)
diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh
index 1ef7db5894..f165adfe1b 100755
--- a/tests/scripts/task_mypy.sh
+++ b/tests/scripts/task_mypy.sh
@@ -32,6 +32,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/
echo "Checking MyPy Type defs in the transform package."
mypy --check-untyped-defs python/tvm/tir/transform/
+echo "Checking MyPy Type defs in the tvmscript printer package."
+mypy --check-untyped-defs python/tvm/script/printer
+
echo "Checking MyPy Type defs in the TIR package with unittest"
MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py