You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2024/03/14 10:51:46 UTC
(tvm) branch main updated: [TIR] Improve well-formed check's handling of match buffer (#16655)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 695f958bc9 [TIR] Improve well-formed check's handling of match buffer (#16655)
695f958bc9 is described below
commit 695f958bc9ef40e625a84ad9355df2e75e6498a0
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Mar 14 05:51:39 2024 -0500
[TIR] Improve well-formed check's handling of match buffer (#16655)
* [TIR] Improve well-formed check's handling of match buffer
- The `T.match_buffer` calls at the start of a function may reuse the
same data var for more than one buffer. For example, a function may
need to accept two `DLTensor` objects that share the same backing allocation.
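- The `"buffer_bind_scope"` attribute is an older style of match buffer,
and may be the point of definition for variables (the buffer view's data
pointer and any symbolic shape variables that are not already defined).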
* Improved comment; added explicit `context.pop_back()` calls to release definitions at the end of the `AttrStmt` scope
---
src/tir/analysis/verify_well_formed.cc | 1 +
src/tir/ir/tir_visitor_with_path.cc | 78 +++++------
src/tir/ir/tir_visitor_with_path.h | 43 ++++++
.../test_tir_analysis_verify_well_formed.py | 149 +++++++++++++++++++++
4 files changed, 228 insertions(+), 43 deletions(-)
diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc
index 943a119711..c001d35054 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
using Verifier::Verifier;
private:
+ using Verifier::Visit;
void Visit(const PrimFunc& prim_func, ObjectPath path) override {
Verifier::Visit(prim_func, path);
redefine_allowed_within_function_.clear();
diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc
index a80f2300e2..37b3ce55a2 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
// variable has occurred. Therefore, to ensure that we only avoid
// duplicate calls to VisitVarDef, these semantics need to be
// checked.
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params;
std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context;
auto ppath = path->Attr("params");
for (size_t i = 0; i < func->params.size(); i++) {
context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i)));
- defined_params.insert(func->params[i]);
}
- auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr,
- ObjectPath path) {
- if (auto opt = expr.as<Var>()) {
- auto var = opt.value();
- if (!defined_params.count(var)) {
- context.push_back(WithDef(var, path));
- defined_params.insert(var);
- }
- }
- };
- auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr,
- ObjectPath path) {
- for (size_t i = 0; i < arr.size(); i++) {
- try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
- }
- };
-
auto buffer_map_path = path->Attr("buffer_map");
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf = opt.value();
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
- // A buffer in the buffer_map always defines its data pointer
- context.push_back(WithDef(buf->data, buf_path->Attr("data")));
-
- // But other implicit definitions only apply if they weren't
- // provided as explicit parameters, and they weren't defined
- // implicitly by any previous buffer.
- try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape"));
- try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides"));
- try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset"));
+ for (auto& def : WithMatchBufferDefs(buf, buf_path)) {
+ context.push_back(std::move(def));
+ }
}
}
@@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
for (size_t i = 0; i < func->params.size(); i++) {
if (auto opt = func->buffer_map.Get(func->params[i])) {
auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
- EnterDef(opt.value(), buf_path);
+ context.push_back(WithDef(opt.value(), buf_path));
}
}
@@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) {
void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) {
Visit(op->value, path->Attr("value"));
- std::optional<DefContext<IterVar>> context = std::nullopt;
+ std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context;
if (auto iter_var = op->node.as<IterVar>();
iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) {
// Some attributes serve as a source of definition for the
// tir::Var they annotate.
- context = WithDef(iter_var.value(), path->Attr("node"));
+ context.push_back(WithDef(iter_var.value(), path->Attr("node")));
+
+ } else if (op->attr_key == attr::buffer_bind_scope) {
+ // The `attr::buffer_bind_scope` attribute defines a view into an
+ // existing buffer, similar to the newer
+ // `BlockNode::match_buffers` field. It requires the buffer being
+ // viewed to be defined prior to the attribute. The
+ // `attr::buffer_bind_scope` is the point of definition for the
+ // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
+ // symbolic shapes used within `buffer_view that are not already
+ // defined.
+ Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+ ICHECK_EQ(arr.size(), 2U);
+ Buffer buffer_view = Downcast<Buffer>(arr[0]);
+ Buffer orig_buffer = Downcast<Buffer>(arr[1]);
+ Visit(orig_buffer, path->Attr("node")->ArrayIndex(1));
+
+ for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) {
+ context.push_back(std::move(var));
+ }
+
} else if (auto expr = op->node.as<PrimExpr>()) {
Visit(expr.value(), path->Attr("node"));
}
Visit(op->body, path->Attr("body"));
+
+ while (context.size()) {
+ context.pop_back();
+ }
}
void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) {
@@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path)
void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) {
Visit(op->condition, path->Attr("condition"));
Visit(op->bounds, path->Attr("bounds"));
- auto context = WithDef(op->buffer, path->Attr("buffer"));
+ auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data"));
+ Visit(op->buffer, path->Attr("buffer"));
Visit(op->body, path->Attr("body"));
}
@@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
for (size_t i = 0; i < op->match_buffers.size(); i++) {
auto buf = op->match_buffers[i]->buffer;
auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
- auto buffer_strides_path = buffer_path->Attr("strides");
- context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
- // Define buffer strides and elem_offset if they are vars
- if (const auto* v = buf->elem_offset.as<VarNode>()) {
- context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
- }
- for (size_t i = 0; i < buf->strides.size(); ++i) {
- if (const auto* v = buf->strides[i].as<VarNode>()) {
- context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
- }
+
+ for (auto& def : WithMatchBufferDefs(buf, buffer_path)) {
+ context.push_back(std::move(def));
}
- context.push_back(WithDef(buf, buffer_path));
}
}
diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h
index dd0da1fe77..1ae6df58f7 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -29,7 +29,10 @@
#include <tvm/tir/stmt_functor.h>
#include <exception>
+#include <optional>
+#include <unordered_set>
#include <utility>
+#include <vector>
namespace tvm {
namespace tir {
@@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
// construction of the DefContext and the destruction, we avoid
// this case and allow the first error to propagate upward.
if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) {
+ self_->in_scope_definitions_.erase(obj_);
self_->ExitDef(obj_, path_);
}
}
@@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path)
: self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) {
+ self_->in_scope_definitions_.insert(obj_);
self_->EnterDef(obj_, path_);
}
@@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
DefContext<T> WithDef(T obj, ObjectPath path) {
return DefContext(this, obj, path);
}
+
+ /* \brief Utility to track the scope of a node's definition. */
+ template <typename T>
+ std::optional<DefContext<T>> WithDefIfUndefined(T obj, ObjectPath path) {
+ if (in_scope_definitions_.count(obj)) {
+ return std::nullopt;
+ } else {
+ return WithDef(obj, path);
+ }
+ }
+
+ std::vector<DefContext<Var>> WithMatchBufferDefs(Buffer buf, ObjectPath path) {
+ std::vector<DefContext<Var>> context;
+
+ auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) {
+ if (auto opt = expr.as<Var>()) {
+ auto var = opt.value();
+ if (auto var_def = WithDefIfUndefined(var, path)) {
+ context.push_back(std::move(var_def).value());
+ }
+ }
+ };
+ auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](
+ const Array<PrimExpr>& arr, ObjectPath path) {
+ for (size_t i = 0; i < arr.size(); i++) {
+ try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
+ }
+ };
+
+ try_visit_implicit_var_def(buf->data, path->Attr("data"));
+ try_visit_implicit_var_def_array(buf->shape, path->Attr("shape"));
+ try_visit_implicit_var_def_array(buf->strides, path->Attr("strides"));
+ try_visit_implicit_var_def(buf->elem_offset, path->Attr("elem_offset"));
+
+ return context;
+ }
+
+ std::unordered_set<ObjectRef, ObjectPtrHash, ObjectPtrEqual> in_scope_definitions_;
};
} // namespace tir
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index 8c153afc9d..a1b3bee1b2 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -199,5 +199,154 @@ def test_reuse_of_env_thread_across_functions_is_ill_formed():
tvm.tir.analysis.verify_well_formed(mod)
+def test_multiple_buffer_arguments_may_share_allocation():
+ """T.match_buffer may re-use a data argument
+
+ Like the shape/strides/elem_offset fields in a buffer, the first
+ occurrence of a `buffer->data` field defines it, and the
+ occurrences are usages of that definition.
+ """
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A_handle: T.handle, B_handle: T.handle):
+ A = T.match_buffer(A_handle, [256], "float32")
+ B = T.match_buffer(B_handle, [256], "float32", data=A.data)
+
+ pass
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_bind_scope_defines_buffer_obj():
+ """The "buffer_bind_scope" attribute defines a buffer view"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A: T.Buffer([256, 256], "float32")):
+
+ for tile_i, tile_j in T.grid(16, 16):
+ B = T.Buffer([16, 16], "float32")
+ T.attr(
+ [B, A],
+ "buffer_bind_scope",
+ T.tvm_tuple(
+ tile_i * 16,
+ 16,
+ tile_j * 16,
+ 16,
+ dtype="handle",
+ ),
+ )
+ for i, j in T.grid(16, 16):
+ B[i, j] = 0.0
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_bind_scope_defines_symbolic_variables():
+ """The "buffer_bind_scope" attribute may define symbolic variables"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A: T.Buffer([256, 256], "int32")):
+
+ for tile_i, tile_j in T.grid(16, 16):
+ elem_offset = T.int32()
+ B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
+ T.attr(
+ [B, A],
+ "buffer_bind_scope",
+ T.tvm_tuple(
+ tile_i * 16,
+ 16,
+ tile_j * 16,
+ 16,
+ dtype="handle",
+ ),
+ )
+ for i, j in T.grid(16, 16):
+ B[i, j] = elem_offset
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_block_match_buffer_defines_buffer_obj():
+ """In a block, T.match_buffer defines a buffer view"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A: T.Buffer([256, 256], "float32")):
+ for iters in T.grid(16, 16, 16, 16):
+ with T.block("compute"):
+ tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
+ B = T.match_buffer(
+ A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
+ dtype="float32",
+ )
+ B[i, j] = 0.0
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_block_match_buffer_defines_symbolic_variables():
+ """In a block, T.match_buffer may define symbolic variables"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A: T.Buffer([256, 256], "int32")):
+
+ for iters in T.grid(16, 16, 16, 16):
+ with T.block("compute"):
+ tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
+
+ elem_offset = T.int32()
+ B = T.match_buffer(
+ A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
+ dtype="float32",
+ elem_offset=elem_offset,
+ )
+
+ B[i, j] = elem_offset
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_realize_on_external_buffer_is_annotation():
+ """A T.realize statement on an existing buffer annotates the region used"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func(A: T.Buffer(256, "int32")):
+ T.realize(A[0:16], "global")
+
+ for i in range(16):
+ A[i] = 1
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_realize_is_allocation():
+ """A T.realize statement on an fresh buffer allocates the buffer"""
+
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def func():
+ A = T.Buffer(256, "int32")
+ T.realize(A[0:16], "global")
+
+ for i in range(16):
+ A[i] = 1
+
+ tvm.tir.analysis.verify_well_formed(mod)
+
+
if __name__ == "__main__":
tvm.testing.main()