You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/08 15:31:03 UTC

[tvm] branch main updated: [TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e8889ae  [TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
e8889ae is described below

commit e8889ae0e3adc9cc9b02952d7ac4aac33e9f66b1
Author: Yuanjing Shi <yj...@shingjan.me>
AuthorDate: Wed Dec 8 07:30:27 2021 -0800

    [TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
---
 docker/install/ubuntu_install_python_package.sh    |  2 +-
 python/gen_requirements.py                         |  2 +-
 python/tvm/script/parser.py                        | 61 +++++++++++++++---
 python/tvm/script/tir/__init__.py                  |  2 +-
 python/tvm/script/tir/ty.py                        | 73 ++++++++++++++++++++++
 .../python/unittest/test_tvmscript_syntax_sugar.py | 45 +++++++++++++
 tests/scripts/task_ci_setup.sh                     |  2 +-
 7 files changed, 174 insertions(+), 13 deletions(-)

diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh
index fb0f596..8b79455 100755
--- a/docker/install/ubuntu_install_python_package.sh
+++ b/docker/install/ubuntu_install_python_package.sh
@@ -36,6 +36,6 @@ pip3 install \
     pytest-xdist \
     requests \
     scipy \
-    synr==0.5.0 \
+    synr==0.6.0 \
     six \
     tornado
diff --git a/python/gen_requirements.py b/python/gen_requirements.py
index b4f3907..bcd8ccd 100755
--- a/python/gen_requirements.py
+++ b/python/gen_requirements.py
@@ -255,7 +255,7 @@ CONSTRAINTS = [
     ("sphinx_autodoc_annotation", None),
     ("sphinx_gallery", None),
     ("sphinx_rtd_theme", None),
-    ("synr", "==0.5.0"),
+    ("synr", "==0.6.0"),
     ("tensorflow", None),
     ("tensorflow-estimator", None),
     ("tflite", None),
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 6cb22ae..0132025 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -32,6 +32,7 @@ from tvm import IRModule
 from tvm._ffi.base import TVMError
 from tvm.ir import GlobalVar
 from tvm.ir.function import BaseFunc
+from tvm.tir import buffer
 from tvm.tir.function import PrimFunc
 from . import _ffi_api
 from . import tir
@@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
         ast.BuiltinOp.Not: tvm.tir.Not,
     }
 
-    def __init__(self, base_lienno, tir_namespace):
+    def __init__(self, base_lineno, tir_namespace):
         self.context = None
 
-        self.base_lineno = base_lienno
+        self.base_lineno = base_lineno
         self.current_lineno = 0
         self.current_col_offset = 0
         self.tir_namespace = tir_namespace
@@ -249,7 +250,7 @@ class TVMScriptParser(Transformer):
         func : Function
             The function that provides the signature
 
-        node_call: ast.Call
+        node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
             The AST call node that calls into the function.
 
         Returns
@@ -257,12 +258,15 @@ class TVMScriptParser(Transformer):
         arg_list : list
             The parsed positional argument.
         """
-        assert isinstance(node_call, ast.Call)
+        assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
         # collect arguments
         args = [self.transform(arg) for arg in node_call.params]
-        kw_args = {
-            self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
-        }
+        if isinstance(node_call, ast.TypeApply):
+            kw_args = {}  # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
+        else:
+            kw_args = {
+                self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
+            }
         # get the name and parameter list of func
         if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
             func_name, param_list = func.signature()
@@ -276,6 +280,7 @@ class TVMScriptParser(Transformer):
         reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
         pos_only, kwargs, varargs = param_list
         internal_args = list()
+
         for i, arg_name in enumerate(pos_only):
             internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
         for i, arg_info in enumerate(kwargs):
@@ -439,8 +444,22 @@ class TVMScriptParser(Transformer):
 
         # add parameters of function
         for arg in node.params:
-            arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
-            self.context.update_symbol(arg.name, arg_var, node)
+            # Note that this case is for T.match_buffer syntax sugar
+            if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
+                result = self.handle_match_buffer_type(arg.ty, arg.name)
+                if not isinstance(result, buffer.Buffer):
+                    self.report_error(
+                        "The result type of evaluating TypeCall and TypeApply stmt"
+                        f" is wrong: {type(result)}. It should be a Buffer",
+                        node.span,
+                    )
+                arg_name_with_handle = arg.name + "_handle"
+                arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
+                self.context.func_buffer_map[arg_var] = result
+                self.context.update_symbol(arg.name, result, node)
+            else:
+                arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
+                self.context.update_symbol(arg.name, arg_var, node)
             self.context.func_params.append(arg_var)
 
         if not check_decorator(node.decorators):
@@ -1110,6 +1129,30 @@ class TVMScriptParser(Transformer):
         """
         return node.value
 
+    def transform_TypeTuple(self, node):
+        """Tuple value visitor for types.
+
+        Mostly used in `transform_TypeCall` and `transform_TypeApply`.
+        """
+        return [self.transform(value) for value in node.values]
+
+    def handle_match_buffer_type(self, node, buffer_name):
+        """special function to handle syntax sugar for match buffer.
+
+        This method is for buffer declarations in the function parameters.
+        """
+        func = self.transform(node.func_name)
+        assert isinstance(func, SpecialStmt)
+
+        # parse args and kwargs for TypeCall and TypeApply
+        arg_list = self.parse_arg_list(func, node)
+        # Note that the third element in arg_list would always be the 'name'
+        # TODO: This index is hardcoded as a workaround. Better to make it programmatic
+        if arg_list[2] is None:
+            arg_list[2] = buffer_name
+        buf = func.handle(node, self.context, arg_list, node.func_name.span)
+        return buf
+
     def transform_Return(self, node):
         self.report_error(
             "TVM script does not support return statements. Instead the last statement in any "
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 6aa7eb3..472b3de 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -18,6 +18,6 @@
 
 # Type system
 from .ty import int8, int16, int32, int64, float16, float32, float64
-from .ty import boolean, handle, Ptr, Tuple
+from .ty import boolean, handle, Ptr, Tuple, Buffer
 
 from .prim_func import prim_func
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index 9140310..2808e7a 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -21,6 +21,7 @@ a wrapper for uniform Type system in IR
 """
 # pylint: disable=invalid-name
 import tvm
+from .special_stmt import SpecialStmt, convert_to_int
 
 
 class TypeGeneric:  # pylint: disable=too-few-public-methods
@@ -67,6 +68,75 @@ class GenericTupleType(TypeGeneric):  # pylint: disable=abstract-method
         return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
 
 
+class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods, abstract-method
+    """TVM script typing class for uniform Type objects"""
+
+    def __init__(self, vtype):
+        def match_buffer_syntax_sugar(
+            shape,
+            dtype: str = "float32",
+            name: str = None,
+            data=None,
+            strides=None,
+            elem_offset=None,
+            scope="global",
+            align=-1,
+            offset_factor=0,
+            buffer_type="default",
+            span=None,
+        ):
+            if strides is None:
+                strides = []
+            align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+            offset_factor = convert_to_int(
+                offset_factor, "offset_factor", self.context.report_error, self.node.span
+            )
+            buffer = tvm.tir.decl_buffer(
+                shape,
+                dtype,
+                name,
+                data,
+                strides,
+                elem_offset,
+                scope,
+                align,
+                offset_factor,
+                buffer_type,
+                span=span,
+            )
+            return buffer
+
+        self.type = vtype
+        super().__init__(match_buffer_syntax_sugar, def_symbol=True)
+
+    def __call__(
+        self,
+        shape,
+        dtype="float32",
+        *,
+        name: str = None,
+        data=None,
+        strides=None,
+        elem_offset=None,
+        scope="global",
+        align=-1,
+        offset_factor=0,
+        buffer_type="default",
+        span=None,
+    ):
+        """
+        This function is for Buffer(...) syntax sugar.
+        """
+        pass  # pylint: disable=unnecessary-pass
+
+    def __getitem__(self, args):
+        """
+        This function is for Buffer[...] syntax sugar
+        Note that args is the list of all arguments
+        """
+        pass  # pylint: disable=unnecessary-pass
+
+
 int8 = ConcreteType("int8")
 int16 = ConcreteType("int16")
 int32 = ConcreteType("int32")
@@ -78,3 +148,6 @@ boolean = ConcreteType("bool")
 handle = ConcreteType("handle")
 Ptr = GenericPtrType()
 Tuple = GenericTupleType()
+# we don't have 'buffer' type on the cpp side
+# thus 'handle' is used here for convenience's sake
+Buffer = GenericBufferType("handle")
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index b8d1232..0d4c833 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -101,5 +101,50 @@ def test_syntax_sugar_fail():
     check_error(loop_syntax_sugar_fail, 3)
 
 
+# match buffer - use kwargs
+@T.prim_func
+def elementwise_handle(
+    a: T.handle,
+    b: T.handle,
+) -> None:
+    A = T.match_buffer(a, (128, 128, 128, 128))
+    B = T.match_buffer(b, (128, 128, 128, 128))
+    for i, j, k, l in T.grid(128, 128, 128, 128):
+        with T.block("B"):
+            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+# match buffer - use buffer with kwargs
+@T.prim_func
+def elementwise_buffer_kwargs(
+    a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
+    b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
+) -> None:
+    for i, j, k, l in T.grid(128, 128, 128, 128):
+        with T.block("B"):
+            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+            b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
+
+
+# match buffer - use buffer without kwargs
+@T.prim_func
+def elementwise_buffer_no_kwargs(
+    a: T.Buffer[(128, 128, 128, 128), "float32"],
+    b: T.Buffer[(128, 128, 128, 128), "float32"],
+) -> None:
+    for i, j, k, l in T.grid(128, 128, 128, 128):
+        with T.block("B"):
+            vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+            b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
+
+
+def test_match_buffer_syntax_sugar():
+    # with kwargs
+    assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
+    # without kwargs
+    assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh
index dfd2a32..323bc07 100755
--- a/tests/scripts/task_ci_setup.sh
+++ b/tests/scripts/task_ci_setup.sh
@@ -30,7 +30,7 @@ set -o pipefail
 #
 echo "Addtiional setup in" ${CI_IMAGE_NAME}
 
-python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0
+python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0
 
 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
 # Jenkinsfile. We expect config.cmake to be present from pack_lib().