You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/30 20:29:34 UTC

[GitHub] [tvm] vinx13 commented on a diff in pull request #12412: [TIR, TVMScript] Update printer / parser to make T.allocate return buffer var

vinx13 commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958892025


##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
         @T.prim_func
         def main():
             for i in T.unroll(2):
-                with T.allocate([16], "float32", "global") as buf:
+                with T.allocate([16], "float32", "global") as buf_data:
+                    buf = T.buffer_decl(shape=[16], dtype="float32", scope="global", data=buf_data)
                     buf[0] = 0.0
 
     @tvm.script.ir_module
     class expected:
         @T.prim_func
         def main():
-            with T.allocate([16], "float32", "global") as buf1:
+            with T.allocate([16], "float32", "global") as buf1_data:

Review Comment:
   Yes. The reason I kept `T.allocate` here is that the pass also needs some updates before we can use `T.decl_buffer`. (The scope of this PR is only to update existing TVM scripts, without touching the related passes, to minimize changes.)



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {
+    if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+      if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+        // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
+        // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to
+        // delegate the printing of the current node to `DeclBufferNode` while maintaining the
+        // same value of `current_num_` and `num_child_`.
+        return Print(op->body);

Review Comment:
   That's correct. That is why I check `buffer_usage.empty()` above: it makes sure `T.buffer_decl` is not needed.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {

Review Comment:
   Usage in `DeclBuffer` is excluded from the result of `BufferUsageFinder`.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {

Review Comment:
   Nice catch! Indeed, it's clearer without the `decl_buffer` argument.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org