You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/04 13:56:26 UTC
(tvm) branch main updated: [TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 31bb4b58fe [TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)
31bb4b58fe is described below
commit 31bb4b58fe1b99ec8c626a7252e159d9d94dd7dd
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Mar 4 07:56:19 2024 -0600
[TVMScript] Infer T.reads() for DeclBuffer nodes (#16663)
Prior to this commit, the automatic `T.reads()` and `T.writes()`
annotations were only generated for buffers appearing as function
arguments, as `T.alloc_buffer` in a `T.block`, or as `T.match_buffer`
in a `T.block`. However, the inferred `T.reads()` annotation for a buffer
defined by the `"tir.BindParams"` pass would be erroneously missing. These
annotations may be required for correct scheduling (see discussion in
[PR#16660](https://github.com/apache/tvm/pull/16660)).
This commit updates the TVMScript parsing to infer `T.reads()` and
`T.writes()` annotations for buffers defined with `DeclBuffer` nodes.
---
src/tir/ir/script/script_complete.cc | 11 +++++++
tests/python/tvmscript/test_tvmscript_complete.py | 36 ++++++++++++++++++-----
2 files changed, 39 insertions(+), 8 deletions(-)
diff --git a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc
index 5ff1c65ca9..e6e942a87b 100644
--- a/src/tir/ir/script/script_complete.cc
+++ b/src/tir/ir/script/script_complete.cc
@@ -99,6 +99,17 @@ class ScriptCompleter : public StmtMutator {
}
}
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ if (buffer_var_map_->count(op->buffer->data)) {
+ return StmtMutator::VisitStmt_(op);
+ } else {
+ buffer_var_map_->Set(op->buffer->data, op->buffer);
+ auto output = StmtMutator::VisitStmt_(op);
+ buffer_var_map_->erase(op->buffer->data);
+ return output;
+ }
+ }
+
bool is_root_block_ = true;
};
diff --git a/tests/python/tvmscript/test_tvmscript_complete.py b/tests/python/tvmscript/test_tvmscript_complete.py
index 2723566d8c..60002dbdb0 100644
--- a/tests/python/tvmscript/test_tvmscript_complete.py
+++ b/tests/python/tvmscript/test_tvmscript_complete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import tvm
+import tvm.testing
from tvm.ir import Range
from tvm.script import tir as T
@@ -336,11 +336,31 @@ def test_complete_alloc_buffer():
)
+def test_access_region_for_decl_buffer():
+ @T.prim_func(private=True)
+ def automatic_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")):
+ B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4])
+ B = T.decl_buffer(4, "int32", data=B_data)
+
+ for i in range(4):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ C[vi] = A[vi] + B[vi]
+
+ @T.prim_func(private=True)
+ def explicit_access_regions(A: T.Buffer(4, "int32"), C: T.Buffer(4, "int32")):
+ B_data = T.allocate_const([1, 2, 3, 4], "int32", extents=[4])
+ B = T.decl_buffer(4, "int32", data=B_data)
+
+ for i in range(4):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ T.reads(A[vi], B[vi])
+ T.writes(C[vi])
+ C[vi] = A[vi] + B[vi]
+
+ tvm.ir.assert_structural_equal(explicit_access_regions, automatic_access_regions)
+
+
if __name__ == "__main__":
- test_complete_matmul()
- test_complete_matmul_original()
- test_complete_with_root()
- test_complete_part_region()
- test_complete_buffer_indices()
- test_complete_match_buffer()
- test_complete_alloc_buffer()
+ tvm.testing.main()