You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/30 20:08:52 UTC

[GitHub] [tvm] Lunderberg commented on a diff in pull request #12412: [TIR, TVMScript] Update printer / parser to make T.allocate return buffer var

Lunderberg commented on code in PR #12412:
URL: https://github.com/apache/tvm/pull/12412#discussion_r958848078


##########
src/printer/tvmscript_printer.cc:
##########
@@ -100,13 +100,21 @@ class BufferUsageFinder : public StmtExprVisitor {
     StmtExprVisitor::VisitStmt_(op);
   }
 
+  void VisitStmt_(const DeclBufferNode* op) final {
+    buffers_declared_.insert(op->buffer.get());

Review Comment:
   Should we also track which buffers have gone out of scope?  If I'm understanding it correctly, a single `DeclBufferNode` would currently also allow the buffer to be used outside of `DeclBufferNode::body`, whereas I'd expect the declaration to apply only within the scope of the node.



##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -242,7 +266,8 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
             T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 112, 12, placeholder_d_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
             T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 112, placeholder_global_2[0], dtype="handle"))
             T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle"))
-            placeholder_d_global_3 = T.allocate([32], "uint8", "global")
+            placeholder_d_global_3_data = T.allocate([32], "uint8", "global")

Review Comment:
   Same question here, whether we can use a single `T.decl_buffer` call.



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {

Review Comment:
   Why is the `decl_buffer` argument needed?  It looks like this pattern only applies when the `DeclBufferNode` is the immediate child of `AllocateNode`, so we could pull that part of the check into this function.  I'm thinking something like the following:
   
   ```c++
   bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
       const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
       if(!decl_buffer) {
           return false;
       }
       // Continue as normal from here.
   }
   ```
   
   (P.S. I really like this as a way to provide the cleaner TVMScript syntax without immediately requiring additional TIR changes, and I'm glad that it gets rid of the `FindAllocateUsage`.  That felt like a hack as I was putting it in.)



##########
tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py:
##########
@@ -56,8 +56,8 @@ def main(placeholder_6: T.Buffer[(192,), "int8"], ethosu_conv2d_1: T.Buffer[(512
         placeholder_8 = T.buffer_decl([1], "uint8")
         placeholder_5 = T.buffer_decl([1], "uint8")
         # body
-        ethosu_conv2d_2 = T.allocate([1024], "uint8", "global")
-        ethosu_conv2d_3 = T.allocate([2048], "uint8", "global")
+        ethosu_conv2d_2 = T.decl_buffer([1024], "uint8", scope="global")

Review Comment:
   Can we remove the `scope="global"` parameter since it matches the default?



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {

Review Comment:
   What is the benefit of the check on `buffer_usage.empty()`?  It looks like it would prevent the Allocate/DeclBuffer pattern from being printed whenever the buffer var is used.



##########
tests/python/contrib/test_ethosu/test_copy_compute_reordering.py:
##########
@@ -40,14 +40,14 @@ def main() -> None:
         buffer9 = T.buffer_decl([32], "uint8")
         buffer10 = T.buffer_decl([2048], "int8")
         # body
-        p1 = T.allocate([128], "uint8", "global")
-        p2 = T.allocate([112], "uint8", "global")
-        p3 = T.allocate([112], "uint8", "global")
-        p4 = T.allocate([32], "uint8", "global")
-        p5 = T.allocate([32], "uint8", "global")
-        p6 = T.allocate([32], "uint8", "global")
-        p7 = T.allocate([112], "uint8", "global")
-        p8 = T.allocate([32], "uint8", "global")
+        p1 = T.decl_buffer([128], "uint8", scope="global")

Review Comment:
   Since `"global"` is the default value for scope, can we remove the `scope = "global"` parameter?  It looks like it was only present before because there was no default scope for `allocate()`.



##########
tests/python/contrib/test_ethosu/test_merge_constants.py:
##########
@@ -44,8 +44,10 @@ def main(buffer2: T.Buffer[(128,), "uint8"], buffer3: T.Buffer[(32,), "uint8"])
             buffer1 = T.buffer_decl([8192], "int8")
             buffer10 = T.buffer_decl([2048], "int8")
             # body
-            p1 = T.allocate([128], "uint8", "global")
-            p4 = T.allocate([32], "uint8", "global")
+            p1_data = T.allocate([128], "uint8", "global")

Review Comment:
   Can this use `T.decl_buffer`, here and lower in the file?



##########
tests/python/unittest/test_tir_renew_defs.py:
##########
@@ -135,7 +135,8 @@ def test_undefined_buffer():
     @T.prim_func
     def access_alloc():
         # Buffer A should be remapped
-        A = T.allocate([128], "float16", "global")
+        A_data = T.allocate([128], "float16", "global")

Review Comment:
   Can this be done through `T.decl_buffer` without the `T.allocate` statement?



##########
tests/python/contrib/test_ethosu/test_hoist_allocates.py:
##########
@@ -227,13 +244,20 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,),
             T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data)
             T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data)
             # body
-            placeholder_global = T.allocate([128], "uint8", "global")
-            placeholder_global_1 = T.allocate([112], "uint8", "global")
-            placeholder_global_2 = T.allocate([112], "uint8", "global")
-            placeholder_d_global = T.allocate([32], "uint8", "global")
-            placeholder_d_global_1 = T.allocate([32], "uint8", "global")
-            placeholder_d_global_2 = T.allocate([32], "uint8", "global")
-            placeholder_global_3 = T.allocate([112], "uint8", "global")
+            placeholder_global_data = T.allocate([128], "uint8", "global")

Review Comment:
   Do these require the separate `T.allocate` call?   This section looks like it matches the allocate/decl_buffer pattern.



##########
tests/python/unittest/test_tir_transform_unroll_loop.py:
##########
@@ -117,16 +117,19 @@ class before:
         @T.prim_func
         def main():
             for i in T.unroll(2):
-                with T.allocate([16], "float32", "global") as buf:
+                with T.allocate([16], "float32", "global") as buf_data:
+                    buf = T.buffer_decl(shape=[16], dtype="float32", scope="global", data=buf_data)
                     buf[0] = 0.0
 
     @tvm.script.ir_module
     class expected:
         @T.prim_func
         def main():
-            with T.allocate([16], "float32", "global") as buf1:
+            with T.allocate([16], "float32", "global") as buf1_data:

Review Comment:
   Since the `T.decl_buffer` is defined as a scope handler, it could be used in this context too, correct?



##########
tests/python/unittest/test_tir_transform_flatten_buffer.py:
##########
@@ -33,7 +33,8 @@ def elementwise_func(a: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, (16, 16), "float32")
     C = T.match_buffer(c, (16, 16), "float32")
     for i in T.serial(0, 16):
-        B_new = T.allocate([1, 16], "float32", "global")
+        B_new_data = T.allocate([1, 16], "float32", "global")

Review Comment:
   Can this be done with `T.decl_buffer` without the `data` argument?



##########
src/printer/tvmscript_printer.cc:
##########
@@ -1028,59 +1039,52 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
   return Doc();
 }
 
-namespace {
-struct AllocUsage {
-  Buffer alloc_buffer;
-  Array<Buffer> aliasing_buffers;
-};
-
-template <typename AllocNode>
-AllocUsage FindAllocateUsage(AllocNode* op, Map<Var, Array<Buffer>>* cache_ptr) {
-  Map<Var, Array<Buffer>>& cache = *cache_ptr;
-  if (!cache.count(op->buffer_var)) {
-    cache = BufferUsageFinder::FindUsage(std::move(cache), op->body);
+bool IsAllocateDeclBufferPattern(const AllocateNode* allocate, const DeclBufferNode* decl_buffer) {
+  const Var& buffer_var = allocate->buffer_var;
+  const Buffer& buffer = decl_buffer->buffer;
+  if (!buffer_var.same_as(buffer->data)) {
+    return false;
   }
-  Array<Buffer> buffer_usage = cache.Get(op->buffer_var).value_or({});
-
-  auto is_exact_match = [](Buffer a, Buffer b) {
-    if (a->dtype != b->dtype) return false;
-    if (a->shape.size() != b->shape.size()) return false;
-
-    arith::Analyzer analyzer;
-    for (size_t i = 0; i < a->shape.size(); i++) {
-      if (!analyzer.CanProveEqual(a->shape[i], b->shape[i])) {
-        return false;
-      }
-    }
-    return true;
-  };
-
-  // If the buffer allocated via T.allocate is an exact match to the
-  // usage of the buffer later on, then that buffer is the return
-  // value of T.allocate, and no T.buffer_decl statement is needed.
-  Buffer alloc_buffer(op->buffer_var, op->dtype, op->extents, {}, 0, op->buffer_var->name_hint, 0,
-                      0, kDefault);
-  bool found_alloc_buf = false;
-  Array<Buffer> aliasing_buffers;
-  for (const auto& buf : buffer_usage) {
-    if (!found_alloc_buf && is_exact_match(buf, alloc_buffer)) {
-      alloc_buffer = buf;
-      found_alloc_buf = true;
-    } else {
-      aliasing_buffers.push_back(buf);
+  if (allocate->dtype != buffer->dtype) {
+    return false;
+  }
+  if (!is_one(allocate->condition)) {
+    return false;
+  }
+  if (allocate->annotations.size()) {
+    return false;
+  }
+  if (allocate->extents.size() != buffer->shape.size()) {
+    return false;
+  }
+  tir::ExprDeepEqual expr_equal;
+  for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
+    if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
+      return false;
     }
   }
-
-  return AllocUsage{alloc_buffer, aliasing_buffers};
+  return true;
 }
-}  // namespace
 
 Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
-  auto usage = FindAllocateUsage(op, &buffer_var_usage_);
-  Buffer& alloc_buffer = usage.alloc_buffer;
-  Array<Buffer>& aliasing_buffers = usage.aliasing_buffers;
-  buf_not_in_headers_.insert(alloc_buffer.get());
-  var_not_in_headers_.insert(alloc_buffer->data.get());
+  var_not_in_headers_.insert(op->buffer_var.get());
+
+  if (!buffer_var_usage_.count(op->buffer_var)) {
+    buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
+  }
+  Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
+
+  if (buffer_usage.empty()) {
+    if (const DeclBufferNode* decl_buffer = op->body.as<DeclBufferNode>()) {
+      if (IsAllocateDeclBufferPattern(op, decl_buffer)) {
+        // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
+        // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to
+        // delegate the printing of the current node to `DeclBufferNode` while maintaining the
+        // same value of `current_num_` and `num_child_`.
+        return Print(op->body);

Review Comment:
   This branch skips the call to `PrintNonHeaderBufferDeclarations` and the checks for the `with` syntax.  It looks like the `with` syntax is handled inside the printer for `DeclBufferNode`, but the call to `PrintNonHeaderBufferDeclarations` is not.  As a result, the `T.buffer_decl` statement for any buffers that alias `op->buffer_var` would erroneously show up in the function header, instead of being inside the `DeclBuffer` node.
   
   This won't be an issue once the `DeclBufferNode` is mandatory, but could cause bugs until then.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org