You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2024/03/14 10:51:46 UTC

(tvm) branch main updated: [TIR] Improve well-formed check's handling of match buffer (#16655)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 695f958bc9 [TIR] Improve well-formed check's handling of match buffer (#16655)
695f958bc9 is described below

commit 695f958bc9ef40e625a84ad9355df2e75e6498a0
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Thu Mar 14 05:51:39 2024 -0500

    [TIR] Improve well-formed check's handling of match buffer (#16655)
    
    * [TIR] Improve well-formed check's handling of match buffer
    
    - The `T.match_buffer` at the start of a function may contain repeated
      use of the same data var.  For example, a function that must accept
      two `DLTensor` objects with the same backing allocation.
    
    - The `"buffer_bind_scope"` is an older style of match buffer, and may
      be the point of definition for variables.
    
    * Improved comment, added context.pop_back()
---
 src/tir/analysis/verify_well_formed.cc             |   1 +
 src/tir/ir/tir_visitor_with_path.cc                |  78 +++++------
 src/tir/ir/tir_visitor_with_path.h                 |  43 ++++++
 .../test_tir_analysis_verify_well_formed.py        | 149 +++++++++++++++++++++
 4 files changed, 228 insertions(+), 43 deletions(-)

diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc
index 943a119711..c001d35054 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -228,6 +228,7 @@ class UndefinedVarVerifier : public Verifier<UndefinedVarVerifier> {
   using Verifier::Verifier;
 
  private:
+  using Verifier::Visit;
   void Visit(const PrimFunc& prim_func, ObjectPath path) override {
     Verifier::Visit(prim_func, path);
     redefine_allowed_within_function_.clear();
diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc
index a80f2300e2..37b3ce55a2 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -78,47 +78,22 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
   // variable has occurred.  Therefore, to ensure that we only avoid
   // duplicate calls to VisitVarDef, these semantics need to be
   // checked.
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_params;
   std::vector<std::variant<DefContext<Var>, DefContext<Buffer>>> context;
 
   auto ppath = path->Attr("params");
   for (size_t i = 0; i < func->params.size(); i++) {
     context.push_back(WithDef(func->params[i], ppath->ArrayIndex(i)));
-    defined_params.insert(func->params[i]);
   }
 
-  auto try_visit_implicit_var_def = [this, &defined_params, &context](const PrimExpr& expr,
-                                                                      ObjectPath path) {
-    if (auto opt = expr.as<Var>()) {
-      auto var = opt.value();
-      if (!defined_params.count(var)) {
-        context.push_back(WithDef(var, path));
-        defined_params.insert(var);
-      }
-    }
-  };
-  auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](const Array<PrimExpr>& arr,
-                                                                        ObjectPath path) {
-    for (size_t i = 0; i < arr.size(); i++) {
-      try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
-    }
-  };
-
   auto buffer_map_path = path->Attr("buffer_map");
   for (size_t i = 0; i < func->params.size(); i++) {
     if (auto opt = func->buffer_map.Get(func->params[i])) {
       auto buf = opt.value();
       auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
 
-      // A buffer in the buffer_map always defines its data pointer
-      context.push_back(WithDef(buf->data, buf_path->Attr("data")));
-
-      // But other implicit definitions only apply if they weren't
-      // provided as explicit parameters, and they weren't defined
-      // implicitly by any previous buffer.
-      try_visit_implicit_var_def_array(buf->shape, buf_path->Attr("shape"));
-      try_visit_implicit_var_def_array(buf->strides, buf_path->Attr("strides"));
-      try_visit_implicit_var_def(buf->elem_offset, buf_path->Attr("elem_offset"));
+      for (auto& def : WithMatchBufferDefs(buf, buf_path)) {
+        context.push_back(std::move(def));
+      }
     }
   }
 
@@ -127,7 +102,7 @@ void TIRVisitorWithPath::Visit(const PrimFunc& func, ObjectPath path) {
   for (size_t i = 0; i < func->params.size(); i++) {
     if (auto opt = func->buffer_map.Get(func->params[i])) {
       auto buf_path = buffer_map_path->MapValue(ppath->ArrayIndex(i));
-      EnterDef(opt.value(), buf_path);
+      context.push_back(WithDef(opt.value(), buf_path));
     }
   }
 
@@ -199,16 +174,40 @@ void TIRVisitorWithPath::VisitStmt_(const LetStmtNode* op, ObjectPath path) {
 void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, ObjectPath path) {
   Visit(op->value, path->Attr("value"));
 
-  std::optional<DefContext<IterVar>> context = std::nullopt;
+  std::vector<std::variant<DefContext<IterVar>, DefContext<Var>>> context;
   if (auto iter_var = op->node.as<IterVar>();
       iter_var && (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread)) {
     // Some attributes serve as a source of definition for the
     // tir::Var they annotate.
-    context = WithDef(iter_var.value(), path->Attr("node"));
+    context.push_back(WithDef(iter_var.value(), path->Attr("node")));
+
+  } else if (op->attr_key == attr::buffer_bind_scope) {
+    // The `attr::buffer_bind_scope` attribute defines a view into an
+    // existing buffer, similar to the newer
+    // `BlockNode::match_buffers` field.  It requires the buffer being
+    // viewed to be defined prior to the attribute.  The
+    // `attr::buffer_bind_scope` is the point of definition for the
+    // `tir::Buffer buffer_view`, its `tir::Var` data pointer, and any
+    // symbolic shapes used within `buffer_view` that are not already
+    // defined.
+    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
+    ICHECK_EQ(arr.size(), 2U);
+    Buffer buffer_view = Downcast<Buffer>(arr[0]);
+    Buffer orig_buffer = Downcast<Buffer>(arr[1]);
+    Visit(orig_buffer, path->Attr("node")->ArrayIndex(1));
+
+    for (auto& var : WithMatchBufferDefs(buffer_view, path->Attr("node")->ArrayIndex(0))) {
+      context.push_back(std::move(var));
+    }
+
   } else if (auto expr = op->node.as<PrimExpr>()) {
     Visit(expr.value(), path->Attr("node"));
   }
   Visit(op->body, path->Attr("body"));
+
+  while (context.size()) {
+    context.pop_back();
+  }
 }
 
 void TIRVisitorWithPath::VisitStmt_(const ForNode* op, ObjectPath path) {
@@ -250,7 +249,8 @@ void TIRVisitorWithPath::VisitStmt_(const BufferStoreNode* op, ObjectPath path)
 void TIRVisitorWithPath::VisitStmt_(const BufferRealizeNode* op, ObjectPath path) {
   Visit(op->condition, path->Attr("condition"));
   Visit(op->bounds, path->Attr("bounds"));
-  auto context = WithDef(op->buffer, path->Attr("buffer"));
+  auto context = WithDefIfUndefined(op->buffer->data, path->Attr("buffer")->Attr("data"));
+  Visit(op->buffer, path->Attr("buffer"));
   Visit(op->body, path->Attr("body"));
 }
 
@@ -318,18 +318,10 @@ void TIRVisitorWithPath::VisitStmt_(const BlockNode* op, ObjectPath path) {
     for (size_t i = 0; i < op->match_buffers.size(); i++) {
       auto buf = op->match_buffers[i]->buffer;
       auto buffer_path = match_path->ArrayIndex(i)->Attr("buffer");
-      auto buffer_strides_path = buffer_path->Attr("strides");
-      context.push_back(WithDef(buf->data, buffer_path->Attr("data")));
-      // Define buffer strides and elem_offset if they are vars
-      if (const auto* v = buf->elem_offset.as<VarNode>()) {
-        context.push_back(WithDef(GetRef<Var>(v), buffer_path->Attr("elem_offset")));
-      }
-      for (size_t i = 0; i < buf->strides.size(); ++i) {
-        if (const auto* v = buf->strides[i].as<VarNode>()) {
-          context.push_back(WithDef(GetRef<Var>(v), buffer_strides_path->ArrayIndex(i)));
-        }
+
+      for (auto& def : WithMatchBufferDefs(buf, buffer_path)) {
+        context.push_back(std::move(def));
       }
-      context.push_back(WithDef(buf, buffer_path));
     }
   }
 
diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h
index dd0da1fe77..1ae6df58f7 100644
--- a/src/tir/ir/tir_visitor_with_path.h
+++ b/src/tir/ir/tir_visitor_with_path.h
@@ -29,7 +29,10 @@
 #include <tvm/tir/stmt_functor.h>
 
 #include <exception>
+#include <optional>
+#include <unordered_set>
 #include <utility>
+#include <vector>
 
 namespace tvm {
 namespace tir {
@@ -173,6 +176,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
       // construction of the DefContext and the destruction, we avoid
       // this case and allow the first error to propagate upward.
       if (self_ && std::uncaught_exceptions() == uncaught_exceptions_) {
+        self_->in_scope_definitions_.erase(obj_);
         self_->ExitDef(obj_, path_);
       }
     }
@@ -182,6 +186,7 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
 
     DefContext(TIRVisitorWithPath* self, T obj, ObjectPath path)
         : self_(self), obj_(obj), path_(path), uncaught_exceptions_(std::uncaught_exceptions()) {
+      self_->in_scope_definitions_.insert(obj_);
       self_->EnterDef(obj_, path_);
     }
 
@@ -203,6 +208,44 @@ class TIRVisitorWithPath : protected ExprFunctor<void(const PrimExpr&, ObjectPat
   DefContext<T> WithDef(T obj, ObjectPath path) {
     return DefContext(this, obj, path);
   }
+
+  /* \brief Utility to track the scope of a node's definition, if the node is not already defined. */
+  template <typename T>
+  std::optional<DefContext<T>> WithDefIfUndefined(T obj, ObjectPath path) {
+    if (in_scope_definitions_.count(obj)) {
+      return std::nullopt;
+    } else {
+      return WithDef(obj, path);
+    }
+  }
+
+  std::vector<DefContext<Var>> WithMatchBufferDefs(Buffer buf, ObjectPath path) {
+    std::vector<DefContext<Var>> context;
+
+    auto try_visit_implicit_var_def = [this, &context](const PrimExpr& expr, ObjectPath path) {
+      if (auto opt = expr.as<Var>()) {
+        auto var = opt.value();
+        if (auto var_def = WithDefIfUndefined(var, path)) {
+          context.push_back(std::move(var_def).value());
+        }
+      }
+    };
+    auto try_visit_implicit_var_def_array = [&try_visit_implicit_var_def](
+                                                const Array<PrimExpr>& arr, ObjectPath path) {
+      for (size_t i = 0; i < arr.size(); i++) {
+        try_visit_implicit_var_def(arr[i], path->ArrayIndex(i));
+      }
+    };
+
+    try_visit_implicit_var_def(buf->data, path->Attr("data"));
+    try_visit_implicit_var_def_array(buf->shape, path->Attr("shape"));
+    try_visit_implicit_var_def_array(buf->strides, path->Attr("strides"));
+    try_visit_implicit_var_def(buf->elem_offset, path->Attr("elem_offset"));
+
+    return context;
+  }
+
+  std::unordered_set<ObjectRef, ObjectPtrHash, ObjectPtrEqual> in_scope_definitions_;
 };
 
 }  // namespace tir
diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
index 8c153afc9d..a1b3bee1b2 100644
--- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
@@ -199,5 +199,154 @@ def test_reuse_of_env_thread_across_functions_is_ill_formed():
         tvm.tir.analysis.verify_well_formed(mod)
 
 
+def test_multiple_buffer_arguments_may_share_allocation():
+    """T.match_buffer may re-use a data argument
+
+    Like the shape/strides/elem_offset fields in a buffer, the first
+    occurrence of a `buffer->data` field defines it, and the
+    later occurrences are usages of that definition.
+    """
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A_handle: T.handle, B_handle: T.handle):
+            A = T.match_buffer(A_handle, [256], "float32")
+            B = T.match_buffer(B_handle, [256], "float32", data=A.data)
+
+            pass
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_bind_scope_defines_buffer_obj():
+    """The "buffer_bind_scope" attribute defines a buffer view"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A: T.Buffer([256, 256], "float32")):
+
+            for tile_i, tile_j in T.grid(16, 16):
+                B = T.Buffer([16, 16], "float32")
+                T.attr(
+                    [B, A],
+                    "buffer_bind_scope",
+                    T.tvm_tuple(
+                        tile_i * 16,
+                        16,
+                        tile_j * 16,
+                        16,
+                        dtype="handle",
+                    ),
+                )
+                for i, j in T.grid(16, 16):
+                    B[i, j] = 0.0
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_bind_scope_defines_symbolic_variables():
+    """The "buffer_bind_scope" attribute may define symbolic variables"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A: T.Buffer([256, 256], "int32")):
+
+            for tile_i, tile_j in T.grid(16, 16):
+                elem_offset = T.int32()
+                B = T.Buffer([16, 16], "int32", elem_offset=elem_offset)
+                T.attr(
+                    [B, A],
+                    "buffer_bind_scope",
+                    T.tvm_tuple(
+                        tile_i * 16,
+                        16,
+                        tile_j * 16,
+                        16,
+                        dtype="handle",
+                    ),
+                )
+                for i, j in T.grid(16, 16):
+                    B[i, j] = elem_offset
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_block_match_buffer_defines_buffer_obj():
+    """In a block, T.match_buffer defines a buffer view"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A: T.Buffer([256, 256], "float32")):
+            for iters in T.grid(16, 16, 16, 16):
+                with T.block("compute"):
+                    tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
+                    B = T.match_buffer(
+                        A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
+                        dtype="float32",
+                    )
+                    B[i, j] = 0.0
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_block_match_buffer_defines_symbolic_variables():
+    """In a block, T.match_buffer may define symbolic variables"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A: T.Buffer([256, 256], "int32")):
+
+            for iters in T.grid(16, 16, 16, 16):
+                with T.block("compute"):
+                    tile_i, tile_j, i, j = T.axis.remap("SSSS", iters)
+
+                    elem_offset = T.int32()
+                    B = T.match_buffer(
+                        A[tile_i * 16 : (tile_i + 1) * 16, tile_j * 16 : (tile_j + 1) * 16],
+                        dtype="float32",
+                        elem_offset=elem_offset,
+                    )
+
+                    B[i, j] = elem_offset
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_realize_on_external_buffer_is_annotation():
+    """A T.realize statement on an existing buffer annotates the region used"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func(A: T.Buffer(256, "int32")):
+            T.realize(A[0:16], "global")
+
+            for i in range(16):
+                A[i] = 1
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
+def test_buffer_realize_is_allocation():
+    """A T.realize statement on a fresh buffer allocates the buffer"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def func():
+            A = T.Buffer(256, "int32")
+            T.realize(A[0:16], "global")
+
+            for i in range(16):
+                A[i] = 1
+
+    tvm.tir.analysis.verify_well_formed(mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()