You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/26 18:55:16 UTC
[tvm] branch main updated: [TVMScript] Infer T.match_buffer parameters for region (#12890)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd26813723 [TVMScript] Infer T.match_buffer parameters for region (#12890)
fd26813723 is described below
commit fd268137237d2f6fbff4aa4517449284330c3cd8
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Sep 26 13:55:06 2022 -0500
[TVMScript] Infer T.match_buffer parameters for region (#12890)
* [TVMScript] Infer T.match_buffer parameters for region
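When `T.match_buffer` is used to define a view into a region of another buffer,
the `shape` and `dtype` parameters may now be omitted: `shape` defaults to the extents of the matched region, and `dtype` defaults to the dtype of the source buffer.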
* Updated unit test for new behavior
The tests intentionally trigger a failed match caused by a mismatched
`elem_offset`. They must now pass `elem_offset=0` explicitly to trigger the failure,
because `match_buffer` calls that define a view now default `elem_offset`
to a fresh symbolic `Var` (by setting `offset_factor` to 1 when neither is given).
---
python/tvm/script/tir/special_stmt.py | 68 +++++++++++++++++-----
.../python/unittest/test_tir_lower_match_buffer.py | 4 +-
.../python/unittest/test_tvmscript_syntax_sugar.py | 25 ++++++++
3 files changed, 79 insertions(+), 18 deletions(-)
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 15502055b7..7cbf474410 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -121,8 +121,8 @@ class MatchBuffer(SpecialStmt):
def __init__(self):
def match_buffer(
param,
- shape,
- dtype="float32",
+ shape=None,
+ dtype=None,
data=None,
strides=None,
elem_offset=None,
@@ -146,28 +146,64 @@ class MatchBuffer(SpecialStmt):
offset_factor, "offset_factor", self.context.report_error, self.node.span
)
buffer_name: str = self.node.lhs[0].id.name
- buffer = tvm.tir.decl_buffer(
- shape,
- dtype,
- buffer_name,
- data,
- strides,
- elem_offset,
- scope,
- align,
- offset_factor,
- buffer_type,
- axis_separators,
- span=span,
- )
+
if isinstance(param, tvm.tir.Var):
+ if shape is None:
+ self.context.report_error(
+ "Shape must be specified when binding input param",
+ self.node.rhs.span,
+ )
+
+ if dtype is None:
+ dtype = "float32"
+
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ buffer_name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ span=span,
+ )
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer", self.node.rhs.params[0].span
)
self.context.func_buffer_map[param] = buffer
+
elif isinstance(param, BufferSlice):
buffer_region = param.as_buffer_region()
+
+ if shape is None:
+ shape = [dim.extent for dim in buffer_region.region]
+
+ if dtype is None:
+ dtype = buffer_region.buffer.dtype
+
+ if elem_offset is None and offset_factor == 0:
+ offset_factor = 1
+
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ buffer_name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ axis_separators,
+ span=span,
+ )
+
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 93b7caf9cd..6120cf2b67 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -464,7 +464,7 @@ def fail_match_load(a: T.handle) -> None:
with T.block():
T.reads(A[i, j])
T.writes([])
- sub_A = T.match_buffer(A[i, j], ())
+ sub_A = T.match_buffer(A[i, j], (), elem_offset=0)
T.evaluate(sub_A[()])
@@ -475,7 +475,7 @@ def fail_match_store(a: T.handle) -> None:
with T.block():
T.reads([])
T.writes(A[i, j])
- sub_A = T.match_buffer(A[i, j], ())
+ sub_A = T.match_buffer(A[i, j], (), elem_offset=0)
sub_A[()] = 1
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index d955ec0a8c..2a2f7354d7 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -251,6 +251,31 @@ def test_match_buffer_int64():
assert_structural_equal(original, after_roundtrip, True)
+def test_match_buffer_region_has_implicit_shape_dtype():
+ @T.prim_func
+ def explicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]):
+ with T.block():
+ B = T.match_buffer(A[8:16, 32:64], shape=(8, 32), dtype="int32")
+ T.evaluate(0)
+
+ @T.prim_func
+ def implicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]):
+ with T.block():
+ B = T.match_buffer(A[8:16, 32:64])
+ T.evaluate(0)
+
+ assert_structural_equal(explicit_shape_dtype, implicit_shape_dtype)
+
+
+def test_match_buffer_input_requires_shape_arg():
+ with pytest.raises(tvm.error.DiagnosticError):
+
+ @T.prim_func
+ def func(a: T.handle):
+ A = T.match_buffer(a, dtype="int32")
+ T.evaluate(0)
+
+
def test_letstmt_bufferload_without_type_annotation():
# Variable assignment of PrimExpr types uses the dtype of the
# PrimExpr to determine the variable's dtype. Parsing of