You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/04 13:56:26 UTC

(tvm) branch main updated: [TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 31bb4b58fe [TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)
31bb4b58fe is described below

commit 31bb4b58fe1b99ec8c626a7252e159d9d94dd7dd
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Mar 4 07:56:19 2024 -0600

    [TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)
    
    Prior to this commit, the automatic `T.reads()` and `T.writes()`
    annotations were only generated for buffers appearing as function
    arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
    in a `T.block`.  As a result, the inferred `T.reads()` annotation was
    erroneously missing for any buffer defined by the `"tir.BindParams"`
    pass.  These annotations may be required for correct scheduling (see
    discussion in [PR#16660](https://github.com/apache/tvm/pull/16660)).
    
    This commit updates the TVMScript parsing to infer `T.reads()` and
    `T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
---
 src/tir/ir/script/script_complete.cc              | 11 +++++++
 tests/python/tvmscript/test_tvmscript_complete.py | 36 ++++++++++++++++++-----
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 5ff1c65ca9..e6e942a87b 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -99,6 +99,17 @@ class ScriptCompleter : public StmtMutator {
     }
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {  // enables T.reads/T.writes inference
+    if (buffer_var_map_->count(op->buffer->data)) {  // already mapped: keep existing entry
+      return StmtMutator::VisitStmt_(op);
+    } else {
+      buffer_var_map_->Set(op->buffer->data, op->buffer);  // visible while visiting this node
+      auto output = StmtMutator::VisitStmt_(op);
+      buffer_var_map_->erase(op->buffer->data);  // drop binding once out of scope
+      return output;
+    }
+  }
+
   bool is_root_block_ = true;
 };
 
diff --git a/tests/python/tvmscript/test_tvmscript_complete.py b/tests/python/tvmscript/test_tvmscript_complete.py
index 2723566d8c..60002dbdb0 100644
--- a/tests/python/tvmscript/test_tvmscript_complete.py
+++ b/tests/python/tvmscript/test_tvmscript_complete.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import tvm
+import tvm.testing
 from tvm.ir import Range
 from tvm.script import tir as T
 
@@ -336,11 +336,31 @@ def test_complete_alloc_buffer():
     )
 
 
+def test_access_region_for_decl_buffer():  # T.reads/T.writes inferred for T.decl_buffer
+    @T.prim_func(private=True)
+    def automatic_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")):
+        B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4])
+        B = T.decl_buffer(4, "int32", data=B_data)  # B is defined only by a DeclBuffer
+
+        for i in range(4):
+            with T.block("compute"):
+                vi = T.axis.remap("S", [i])
+                C[vi] = A[vi] + B[vi]  # no explicit T.reads/T.writes here
+
+    @T.prim_func(private=True)
+    def explicit_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")):
+        B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4])
+        B = T.decl_buffer(4, "int32", data=B_data)
+
+        for i in range(4):
+            with T.block("compute"):
+                vi = T.axis.remap("S", [i])
+                T.reads(A[vi], B[vi])  # expected: B appears in the inferred reads
+                T.writes(C[vi])
+                C[vi] = A[vi] + B[vi]
+
+    tvm.ir.assert_structural_equal(explicit_access_regions, automatic_access_regions)  # same annotations
+
+
 if __name__ == "__main__":
-    test_complete_matmul()
-    test_complete_matmul_original()
-    test_complete_with_root()
-    test_complete_part_region()
-    test_complete_buffer_indices()
-    test_complete_match_buffer()
-    test_complete_alloc_buffer()
+    tvm.testing.main()  # replaces the hand-maintained list of test calls