You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/14 13:54:27 UTC

[GitHub] [tvm] Hzfengsy commented on a change in pull request #9738: [TensorIR] Primitive "SetScope"

Hzfengsy commented on a change in pull request #9738:
URL: https://github.com/apache/tvm/pull/9738#discussion_r768682378



##########
File path: src/tir/schedule/primitive/block_annotate.cc
##########
@@ -16,6 +16,9 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <tvm/runtime/container/array.h>

Review comment:
       Do we need to include this explicitly?

##########
File path: src/tir/schedule/primitive/block_annotate.cc
##########
@@ -233,14 +238,139 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
   Block block_;
 };
 
+/*!
+ * \brief A helper mutator which recursively mutates the old buffer's storage scope and collects
+ * the block sref reuse information for the following replacement.
+ */
+class StorageScopeMutator : StmtExprMutator {
+ public:
+  /*!
+   * \param allocate_site The block where `old_buffer` was allocated.
+   * \param old_buffer The old buffer
+   * \param storage_scope The storage scope to be set
+   * \param block_sref_reuse The block sref reuse map to be updated
+   * \return The new block after the mutation
+   */
+  static Block Mutate(const Block& allocate_site, const Buffer& old_buffer,
+                      const String& storage_scope, Map<Block, Block>* block_sref_reuse) {
+    Buffer new_buffer = WithScope(old_buffer, storage_scope);
+    StorageScopeMutator mutator(old_buffer, new_buffer, storage_scope, block_sref_reuse);
+    Stmt new_block = mutator.VisitStmt(allocate_site);
+    return Downcast<Block>(new_block);
+  }
+
+ private:
+  StorageScopeMutator(const Buffer& old_buffer, Buffer new_buffer, String storage_scope,
+                      Map<Block, Block>* block_sref_reuse)
+      : storage_scope_(std::move(storage_scope)), block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = buffer_var_map_.find(var);
+    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
+      ptr->buffer = it->second;
+      return PrimExpr(ptr);
+    } else {
+      return res;
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
+      ptr->buffer = it->second;
+      return Stmt(ptr);
+    } else {
+      return res;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    // To reduce the number of blocks in block sref reuse map, we check whether the block is really
+    // mutated (i.e., the old buffer appears in the block). If so, we return the block after
+    // mutation. Otherwise we just return the original block.
+
+    // Define the mutation functions.
+    auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
+      auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+      if (it != buffer_var_map_.end()) {
+        Buffer new_target_buffer = WithScope(match_buffer->buffer, storage_scope_);
+        buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
+        return MatchBufferRegion(new_target_buffer,
+                                 BufferRegion(it->second, match_buffer->source->region));
+      } else {
+        return match_buffer;
+      }
+    };
+    auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
+      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer_region
+                                         : BufferRegion(it->second, buffer_region->region);
+    };
+    auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
+      auto it = buffer_var_map_.find(buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer : it->second;
+    };
+
+    // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion,
+    // the storage scope of the target buffer also needs to be set.
+    Array<MatchBufferRegion> match_buffers =
+        MutateArray(block->match_buffers, f_mutate_match_buffers);
+    // Step 2. Mutate the read/write region.
+    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
+    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
+    // Step 4. Recursively mutate the block.
+    Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+    if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
+        writes.same_as(mutated_block->writes) &&
+        alloc_buffers.same_as(mutated_block->alloc_buffers) &&
+        match_buffers.same_as(mutated_block->match_buffers)) {
+      return GetRef<Block>(block);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->match_buffers = std::move(match_buffers);
+
+      Block new_block(n);
+      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
+      return new_block;
+    }
+  }
+
+  /*! \brief The storage scope to be set. */
+  String storage_scope_;
+  /*! \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
+   *         MatchBufferRegion.*/

Review comment:
       ```suggestion
     /*!
      * \brief A mapping which maps old buffer vars to new buffers, including the buffers defined in
      *         MatchBufferRegion.
      */
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org