You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/09/26 18:55:16 UTC

[tvm] branch main updated: [TVMScript] Infer T.match_buffer parameters for region (#12890)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd26813723 [TVMScript] Infer T.match_buffer parameters for region (#12890)
fd26813723 is described below

commit fd268137237d2f6fbff4aa4517449284330c3cd8
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Sep 26 13:55:06 2022 -0500

    [TVMScript] Infer T.match_buffer parameters for region (#12890)
    
    * [TVMScript] Infer T.match_buffer parameters for region
    
    When using `T.match_buffer` to define a view into another buffer,
    the `shape` and `dtype` parameters may now be omitted; they default to
    the extents of the matched region and the dtype of the source buffer.
    
    * Updated unit test for new behavior
    
    The test intentionally triggers a failed match caused by a mismatched
    `elem_offset`.  Because `elem_offset` now defaults to a `Var` for
    `match_buffer` calls that represent views, the test must pass an
    explicit `elem_offset` to trigger the failure.
---
 python/tvm/script/tir/special_stmt.py              | 68 +++++++++++++++++-----
 .../python/unittest/test_tir_lower_match_buffer.py |  4 +-
 .../python/unittest/test_tvmscript_syntax_sugar.py | 25 ++++++++
 3 files changed, 79 insertions(+), 18 deletions(-)

diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 15502055b7..7cbf474410 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -121,8 +121,8 @@ class MatchBuffer(SpecialStmt):
     def __init__(self):
         def match_buffer(
             param,
-            shape,
-            dtype="float32",
+            shape=None,
+            dtype=None,
             data=None,
             strides=None,
             elem_offset=None,
@@ -146,28 +146,64 @@ class MatchBuffer(SpecialStmt):
                 offset_factor, "offset_factor", self.context.report_error, self.node.span
             )
             buffer_name: str = self.node.lhs[0].id.name
-            buffer = tvm.tir.decl_buffer(
-                shape,
-                dtype,
-                buffer_name,
-                data,
-                strides,
-                elem_offset,
-                scope,
-                align,
-                offset_factor,
-                buffer_type,
-                axis_separators,
-                span=span,
-            )
+
             if isinstance(param, tvm.tir.Var):
+                if shape is None:
+                    self.context.report_error(
+                        "Shape must be specified when binding input param",
+                        self.node.rhs.span,
+                    )
+
+                if dtype is None:
+                    dtype = "float32"
+
+                buffer = tvm.tir.decl_buffer(
+                    shape,
+                    dtype,
+                    buffer_name,
+                    data,
+                    strides,
+                    elem_offset,
+                    scope,
+                    align,
+                    offset_factor,
+                    buffer_type,
+                    axis_separators,
+                    span=span,
+                )
                 if param not in self.context.func_params:
                     self.context.report_error(
                         "Can not bind non-input param to buffer", self.node.rhs.params[0].span
                     )
                 self.context.func_buffer_map[param] = buffer
+
             elif isinstance(param, BufferSlice):
                 buffer_region = param.as_buffer_region()
+
+                if shape is None:
+                    shape = [dim.extent for dim in buffer_region.region]
+
+                if dtype is None:
+                    dtype = buffer_region.buffer.dtype
+
+                if elem_offset is None and offset_factor == 0:
+                    offset_factor = 1
+
+                buffer = tvm.tir.decl_buffer(
+                    shape,
+                    dtype,
+                    buffer_name,
+                    data,
+                    strides,
+                    elem_offset,
+                    scope,
+                    align,
+                    offset_factor,
+                    buffer_type,
+                    axis_separators,
+                    span=span,
+                )
+
                 self.context.current_block_scope().match_buffers.append(
                     tvm.tir.MatchBufferRegion(buffer, buffer_region)
                 )
diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py
index 93b7caf9cd..6120cf2b67 100644
--- a/tests/python/unittest/test_tir_lower_match_buffer.py
+++ b/tests/python/unittest/test_tir_lower_match_buffer.py
@@ -464,7 +464,7 @@ def fail_match_load(a: T.handle) -> None:
         with T.block():
             T.reads(A[i, j])
             T.writes([])
-            sub_A = T.match_buffer(A[i, j], ())
+            sub_A = T.match_buffer(A[i, j], (), elem_offset=0)
             T.evaluate(sub_A[()])
 
 
@@ -475,7 +475,7 @@ def fail_match_store(a: T.handle) -> None:
         with T.block():
             T.reads([])
             T.writes(A[i, j])
-            sub_A = T.match_buffer(A[i, j], ())
+            sub_A = T.match_buffer(A[i, j], (), elem_offset=0)
             sub_A[()] = 1
 
 
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index d955ec0a8c..2a2f7354d7 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -251,6 +251,31 @@ def test_match_buffer_int64():
     assert_structural_equal(original, after_roundtrip, True)
 
 
+def test_match_buffer_region_has_implicit_shape_dtype():
+    @T.prim_func
+    def explicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]):
+        with T.block():
+            B = T.match_buffer(A[8:16, 32:64], shape=(8, 32), dtype="int32")
+            T.evaluate(0)
+
+    @T.prim_func
+    def implicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]):
+        with T.block():
+            B = T.match_buffer(A[8:16, 32:64])
+            T.evaluate(0)
+
+    assert_structural_equal(explicit_shape_dtype, implicit_shape_dtype)
+
+
+def test_match_buffer_input_requires_shape_arg():
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @T.prim_func
+        def func(a: T.handle):
+            A = T.match_buffer(a, dtype="int32")
+            T.evaluate(0)
+
+
 def test_letstmt_bufferload_without_type_annotation():
     # Variable assignment of PrimExpr types uses the dtype of the
     # PrimExpr to determine the variable's dtype.  Parsing of