You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text rendering.
Posted to commits@tvm.apache.org by wu...@apache.org on 2021/12/08 15:31:03 UTC
[tvm] branch main updated: [TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e8889ae [TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
e8889ae is described below
commit e8889ae0e3adc9cc9b02952d7ac4aac33e9f66b1
Author: Yuanjing Shi <yj...@shingjan.me>
AuthorDate: Wed Dec 8 07:30:27 2021 -0800
[TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
---
docker/install/ubuntu_install_python_package.sh | 2 +-
python/gen_requirements.py | 2 +-
python/tvm/script/parser.py | 61 +++++++++++++++---
python/tvm/script/tir/__init__.py | 2 +-
python/tvm/script/tir/ty.py | 73 ++++++++++++++++++++++
.../python/unittest/test_tvmscript_syntax_sugar.py | 45 +++++++++++++
tests/scripts/task_ci_setup.sh | 2 +-
7 files changed, 174 insertions(+), 13 deletions(-)
diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh
index fb0f596..8b79455 100755
--- a/docker/install/ubuntu_install_python_package.sh
+++ b/docker/install/ubuntu_install_python_package.sh
@@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
- synr==0.5.0 \
+ synr==0.6.0 \
six \
tornado
diff --git a/python/gen_requirements.py b/python/gen_requirements.py
index b4f3907..bcd8ccd 100755
--- a/python/gen_requirements.py
+++ b/python/gen_requirements.py
@@ -255,7 +255,7 @@ CONSTRAINTS = [
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
- ("synr", "==0.5.0"),
+ ("synr", "==0.6.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 6cb22ae..0132025 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -32,6 +32,7 @@ from tvm import IRModule
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
+from tvm.tir import buffer
from tvm.tir.function import PrimFunc
from . import _ffi_api
from . import tir
@@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
ast.BuiltinOp.Not: tvm.tir.Not,
}
- def __init__(self, base_lienno, tir_namespace):
+ def __init__(self, base_lineno, tir_namespace):
self.context = None
- self.base_lineno = base_lienno
+ self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
@@ -249,7 +250,7 @@ class TVMScriptParser(Transformer):
func : Function
The function that provides the signature
- node_call: ast.Call
+ node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
The AST call node that calls into the function.
Returns
@@ -257,12 +258,15 @@ class TVMScriptParser(Transformer):
arg_list : list
The parsed positional argument.
"""
- assert isinstance(node_call, ast.Call)
+ assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
# collect arguments
args = [self.transform(arg) for arg in node_call.params]
- kw_args = {
- self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
- }
+ if isinstance(node_call, ast.TypeApply):
+ kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
+ else:
+ kw_args = {
+ self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
+ }
# get the name and parameter list of func
if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
func_name, param_list = func.signature()
@@ -276,6 +280,7 @@ class TVMScriptParser(Transformer):
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
pos_only, kwargs, varargs = param_list
internal_args = list()
+
for i, arg_name in enumerate(pos_only):
internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
for i, arg_info in enumerate(kwargs):
@@ -439,8 +444,22 @@ class TVMScriptParser(Transformer):
# add parameters of function
for arg in node.params:
- arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
- self.context.update_symbol(arg.name, arg_var, node)
+ # Note that this case is for T.match_buffer syntax sugar
+ if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
+ result = self.handle_match_buffer_type(arg.ty, arg.name)
+ if not isinstance(result, buffer.Buffer):
+ self.report_error(
+ "The result type of evaluating TypeCall and TypeApply stmt"
+ f" is wrong: {type(result)}. It should be a Buffer",
+ node.span,
+ )
+ arg_name_with_handle = arg.name + "_handle"
+ arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
+ self.context.func_buffer_map[arg_var] = result
+ self.context.update_symbol(arg.name, result, node)
+ else:
+ arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
+ self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
if not check_decorator(node.decorators):
@@ -1110,6 +1129,30 @@ class TVMScriptParser(Transformer):
"""
return node.value
+ def transform_TypeTuple(self, node):
+ """Tuple value visitor for types.
+
+ Mostly used in `transform_TypeCall` and `transform_TypeApply`.
+ """
+ return [self.transform(value) for value in node.values]
+
+ def handle_match_buffer_type(self, node, buffer_name):
+ """special function to handle syntax sugar for match buffer.
+
+ This method is for buffer declarations in the function parameters.
+ """
+ func = self.transform(node.func_name)
+ assert isinstance(func, SpecialStmt)
+
+ # parse args and kwargs for TypeCall and TypeApply
+ arg_list = self.parse_arg_list(func, node)
+ # Note that the third element in arg_list would always be the 'name'
+ # TODO: This index is hardcoded as a workaround. Better to make it programmatic
+ if arg_list[2] is None:
+ arg_list[2] = buffer_name
+ buf = func.handle(node, self.context, arg_list, node.func_name.span)
+ return buf
+
def transform_Return(self, node):
self.report_error(
"TVM script does not support return statements. Instead the last statement in any "
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 6aa7eb3..472b3de 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -18,6 +18,6 @@
# Type system
from .ty import int8, int16, int32, int64, float16, float32, float64
-from .ty import boolean, handle, Ptr, Tuple
+from .ty import boolean, handle, Ptr, Tuple, Buffer
from .prim_func import prim_func
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index 9140310..2808e7a 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -21,6 +21,7 @@ a wrapper for uniform Type system in IR
"""
# pylint: disable=invalid-name
import tvm
+from .special_stmt import SpecialStmt, convert_to_int
class TypeGeneric: # pylint: disable=too-few-public-methods
@@ -67,6 +68,75 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
+class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method
+ """TVM script typing class for uniform Type objects"""
+
+ def __init__(self, vtype):
+ def match_buffer_syntax_sugar(
+ shape,
+ dtype: str = "float32",
+ name: str = None,
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="global",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error, self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error, self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ return buffer
+
+ self.type = vtype
+ super().__init__(match_buffer_syntax_sugar, def_symbol=True)
+
+ def __call__(
+ self,
+ shape,
+ dtype="float32",
+ *,
+ name: str = None,
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="global",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ """
+ This function is for Buffer(...) syntax sugar.
+ """
+ pass # pylint: disable=unnecessary-pass
+
+ def __getitem__(self, args):
+ """
+ This function is for Buffer[...] syntax sugar
+ Note that args is the list of all arguments
+ """
+ pass # pylint: disable=unnecessary-pass
+
+
int8 = ConcreteType("int8")
int16 = ConcreteType("int16")
int32 = ConcreteType("int32")
@@ -78,3 +148,6 @@ boolean = ConcreteType("bool")
handle = ConcreteType("handle")
Ptr = GenericPtrType()
Tuple = GenericTupleType()
+# we don't have 'buffer' type on the cpp side
+# thus 'handle' is used here for convenience's sake
+Buffer = GenericBufferType("handle")
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index b8d1232..0d4c833 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -101,5 +101,50 @@ def test_syntax_sugar_fail():
check_error(loop_syntax_sugar_fail, 3)
+# match buffer - use kwargs
+@T.prim_func
+def elementwise_handle(
+ a: T.handle,
+ b: T.handle,
+) -> None:
+ A = T.match_buffer(a, (128, 128, 128, 128))
+ B = T.match_buffer(b, (128, 128, 128, 128))
+ for i, j, k, l in T.grid(128, 128, 128, 128):
+ with T.block("B"):
+ vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+ B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
+
+
+# match buffer - use buffer with kwargs
+@T.prim_func
+def elementwise_buffer_kwargs(
+ a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
+ b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
+) -> None:
+ for i, j, k, l in T.grid(128, 128, 128, 128):
+ with T.block("B"):
+ vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+ b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
+
+
+# match buffer - use buffer without kwargs
+@T.prim_func
+def elementwise_buffer_no_kwargs(
+ a: T.Buffer[(128, 128, 128, 128), "float32"],
+ b: T.Buffer[(128, 128, 128, 128), "float32"],
+) -> None:
+ for i, j, k, l in T.grid(128, 128, 128, 128):
+ with T.block("B"):
+ vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
+ b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
+
+
+def test_match_buffer_syntax_sugar():
+ # with kwargs
+ assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
+ # without kwargs
+ assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh
index dfd2a32..323bc07 100755
--- a/tests/scripts/task_ci_setup.sh
+++ b/tests/scripts/task_ci_setup.sh
@@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}
-python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0
+python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0
# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().