You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/06 16:45:52 UTC

[GitHub] [tvm] csullivan commented on a diff in pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

csullivan commented on code in PR #11225:
URL: https://github.com/apache/tvm/pull/11225#discussion_r866976226


##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }
+
+ protected:
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = buffer_var_map_.find(var);
+    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
+      ptr->buffer = it->second;
+      return PrimExpr(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
+      ptr->buffer = it->second;
+      return Stmt(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      return MatchBufferRegion(match_buffer->buffer,
+                               BufferRegion(it->second, match_buffer->source->region));
+    } else {
+      return match_buffer;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    // To reduce the number of blocks in block sref reuse map, we check whether the block is really
+    // mutated (i.e., the old buffer appears in the block). If so, we return the block after
+    // mutation. Otherwise we just return the original block.
+
+    auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
+      return this->VisitMatchBufferRegion(match_buffer);
+    };
+    auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
+      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer_region
+                                         : BufferRegion(it->second, buffer_region->region);
+    };
+    auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
+      auto it = buffer_var_map_.find(buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer : it->second;
+    };
+
+    // Step 1. Mutate `match_buffers`. If an old buffer appears as a source of MatchBufferRegion,
+    // the source region is rebuilt on the new buffer (see VisitMatchBufferRegion).
+    Array<MatchBufferRegion> match_buffers =
+        MutateArray(block->match_buffers, f_mutate_match_buffer);
+    // Step 2. Mutate the read/write region.
+    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
+    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
+    // Step 4. Recursively mutate the block.
+    Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+    if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
+        writes.same_as(mutated_block->writes) &&
+        alloc_buffers.same_as(mutated_block->alloc_buffers) &&
+        match_buffers.same_as(mutated_block->match_buffers)) {
+      return GetRef<Block>(block);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->match_buffers = std::move(match_buffers);
+
+      Block new_block(n);
+      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
+      return std::move(new_block);
+    }
+  }
+
+  /*! \brief The storage scope to be set. */
+  String storage_scope_;

Review Comment:
   Looks like `storage_scope_` is leftover from refactoring StorageScopeMutator onto ReplaceBufferMutator, please delete.



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }
+
+ protected:
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = buffer_var_map_.find(var);
+    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
+      ptr->buffer = it->second;
+      return PrimExpr(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
+      ptr->buffer = it->second;
+      return Stmt(ptr);
+    } else {
+      return std::move(res);
+    }
+  }

Review Comment:
   nit: A function template would enforce common code for both BufferLoadNode and BufferStoreNode visitors here. 



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   This is such a helpful visitor 💪, thank you for introducing it.



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   I can see this being helpful for some of the non-schedule TIR transformations. I don't see any src/tir/transforms including `../schedule/transform.h` but I do see some includes of `../schedule/utils.h`. Is this an appropriate place for the visitor given its potential use outside schedule transforms?



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }

Review Comment:
   I don't notice any other implementations defined in this header. Should we move the implementation to the source file?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org