Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/15 01:14:57 UTC

[GitHub] [tvm] vinx13 commented on a change in pull request #9727: [TE][TIR] Implement layout transformations, non-flat memory buffers

vinx13 commented on a change in pull request #9727:
URL: https://github.com/apache/tvm/pull/9727#discussion_r805065309



##########
File path: tests/python/unittest/test_tir_transform_instrument_bound_checkers.py
##########
@@ -94,7 +94,7 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b):
 @tvm.testing.requires_llvm
 def test_in_bounds_vectorize_llvm():
     n = 512
-    lanes = 2
+    lanes = 1

Review comment:
       Is this change needed to pass the test?

##########
File path: include/tvm/tir/buffer.h
##########
@@ -98,10 +113,17 @@ class BufferNode : public Object {
   }
 
   bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
-    // Use DefEqual as buffer can define variables
-    // in its semantics, skip name as name is not important.
+    // Use DefEqual as buffer can define variables in its semantics,
+    // skip name as name is not important.
+
+    // The pre-flattened information is only used for type-checking,
+    // and doesn't represent a different computation.
+    //
+    // TODO(Lunderberg): Move the pre-flattened buffer information
+    // into the PrimFunc's buffer_map.

Review comment:
       Update this comment, since the pre-flattened information is no longer stored here.

##########
File path: src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
##########
@@ -294,22 +308,61 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
-    }
-    return StmtExprMutator::VisitExpr_(op);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
   }
 
   Stmt VisitStmt_(const StoreNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer_var)) {
-      PrimExpr offset = GetBufferOffset(op->buffer_var, op->value->dtype);
-      PrimExpr index = StmtExprMutator::VisitExpr(op->index);
-      PrimExpr value = StmtExprMutator::VisitExpr(op->value);
-      return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
+    LOG(FATAL) << "Unexpected use of deprecated StoreNode.  Please use BufferStoreNode instead.";
+    return Stmt();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+    return VisitBufferAccess(std::move(node));
+  }
+
+  template <typename Node>
+  Node VisitBufferAccess(Node node) {
+    if (IsDynamicSharedMemory(node->buffer->data)) {
+      ICHECK_EQ(node->indices.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      Array<PrimExpr> indices = {node->indices[0] +
+                                 this->GetBufferOffset(node->buffer->data, node->buffer->dtype)};
+
+      auto writer = node.CopyOnWrite();
+      writer->buffer = GetUpdatedBuffer(node->buffer);
+      writer->indices = indices;
     }
-    return StmtExprMutator::VisitStmt_(op);
+
+    return node;
+  }
+
+  Buffer GetUpdatedBuffer(Buffer buffer) {
+    auto key = buffer.get();
+    auto it = buffer_remap_.find(key);
+    if (it != buffer_remap_.end()) {
+      return it->second;
+    }
+
+    if (IsDynamicSharedMemory(buffer->data)) {
+      ICHECK_EQ(buffer->shape.size(), 1)
+          << "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
+          << "and is to be run after "
+          << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
+      auto writer = buffer.CopyOnWrite();
+      writer->data = merged_buf_var_;

Review comment:
       This pass consolidates multiple dynamic shared memory buffers into one; shall we also create a single buffer `merged_buf_` here, instead of creating an alias for each rewritten buffer?
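
       For reference, a minimal sketch of that suggestion is below; the helper name, byte-sized dtype, and size argument are assumptions for illustration, not the pass's actual API:

```cpp
#include <tvm/tir/buffer.h>

using namespace tvm;
using namespace tvm::tir;

// Hypothetical helper: declare one flat buffer that owns the merged
// dynamic shared memory allocation, rather than aliasing each
// original buffer onto merged_buf_var_ in GetUpdatedBuffer.
Buffer MakeMergedBuffer(Var merged_buf_var, PrimExpr merged_size_bytes) {
  // A single flat, byte-typed buffer in the "shared.dyn" scope.
  Buffer merged = decl_buffer({merged_size_bytes}, DataType::UInt(8),
                              "merged_dyn_shmem", "shared.dyn");
  // Back the buffer with the Var of the merged allocation so that
  // every rewritten access refers to the same underlying storage.
  merged.CopyOnWrite()->data = merged_buf_var;
  return merged;
}
```

       VisitBufferAccess could then rewrite loads and stores against this single buffer (adjusting the index for the per-buffer offset and dtype) instead of mutating a copy of each original buffer.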

##########
File path: python/tvm/script/tir/special_stmt.py
##########
@@ -132,6 +132,7 @@ def match_buffer(
             align=-1,
             offset_factor=0,
             buffer_type="default",
+            flatten_buffer=False,

Review comment:
       Update the docstring here https://github.com/apache/tvm/blob/e40414fc81d99c235f857ae9c741a4f25d072f79/python/tvm/script/tir/special_stmt.py#L103-L104 and mention the usage of the new `flatten_buffer` argument.

##########
File path: src/tir/transforms/split_host_device.cc
##########
@@ -155,10 +160,27 @@ class VarUseDefAnalysis : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const LoadNode* op) final {
-    this->HandleUse(op->buffer_var);
+    LOG(FATAL) << "Unexpected use of deprecated LoadNode.  Please use BufferLoadNode instead.";
+    return PrimExpr();
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    this->HandleUse(op->buffer->data);
     return StmtExprMutator::VisitExpr_(op);
   }
 
+  void VisitBuffer(Buffer buffer) {
+    this->HandleUse(buffer->data);
+    auto visit_arr = [&](Array<PrimExpr> arr) {
+      for (const auto& element : arr) {
+        this->VisitExpr(element);
+      }
+    };
+
+    visit_arr(buffer->shape);
+    visit_arr(buffer->strides);
+  }
+

Review comment:
       Is this method used anywhere?
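
       If it is intended to be used, a plausible call site would look something like the fragment below; this is a hypothetical sketch of how VisitBuffer might be wired into VarUseDefAnalysis, not code from this PR:

```cpp
// Hypothetical: mark a buffer's backing Var, and any Vars appearing
// in its shape or strides, as used whenever the buffer is stored to.
Stmt VisitStmt_(const BufferStoreNode* op) final {
  this->VisitBuffer(op->buffer);  // assumed wiring; not in the diff
  return StmtExprMutator::VisitStmt_(op);
}
```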

##########
File path: include/tvm/tir/function.h
##########
@@ -136,15 +157,31 @@ class PrimFunc : public BaseFunc {
  public:
   /*!
    * \brief Constructor
+   *
    * \param params The parameters of the function.
+   *
    * \param body The body of the function.
+   *
    * \param ret_type The return type of the function.
+   *
    * \param buffer_map The buffer map for parameter buffer unpacking.
+   * This contains buffer objects as they appear in the body of the
+   * PrimFunc.  (e.g. a buffer of shape ``[1024]`` originally
+   * generated as a tensor of shape ``[32, 32]``)
+   *
   * \param preflattened_buffer_map The buffer map recording the
   * pre-flattened shapes of parameter buffers.  This contains buffer
   * objects as they are expected to be passed in by the caller.
   * (e.g. a buffer of shape ``[32, 32]`` originally generated as a
   * tensor of shape ``[32, 32]``)
+   *
    * \param attrs Additional function attributes.
+   *
    * \param span The location of this object in the source code.
    */
   TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
                    Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
+                   Map<tir::Var, Buffer> preflattened_buffer_map = Map<tir::Var, Buffer>(),

Review comment:
       Use `Optional<Map<tir::Var, Buffer>>` to explicitly mark this parameter as optional.
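
       Concretely, the constructor under that suggestion would look roughly like this (a sketch of the suggested signature, not the merged code):

```cpp
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
                 Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
                 // NullOpt as the default makes "not provided"
                 // distinguishable from an explicitly empty map.
                 Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = NullOpt,
                 DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
```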




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org