Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/05/05 22:39:40 UTC

[GitHub] [tvm] vinx13 opened a new pull request, #11225: [TIR] Add schedule primitive SetAxisSeparator

vinx13 opened a new pull request, #11225:
URL: https://github.com/apache/tvm/pull/11225

   This PR adds a schedule primitive, `set_axis_separator`. It modifies the `axis_separators` attribute of the target buffer, which affects the buffer's physical dimensions after flattening.
   
   cc @Lunderberg @csullivan @junrushao1994 
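
The effect of `axis_separators` on the flattened shape can be sketched with a small standalone function. This is only an illustration, not TVM code: `FlattenShape` is a hypothetical helper, and it assumes the separators are strictly increasing and lie strictly between 0 and the number of logical axes. Each separator starts a new physical dimension, and the logical extents within one group are multiplied together.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of TVM): groups the logical extents of a buffer
// into physical extents. A separator value s means "a new physical dimension
// starts at logical axis s". Assumes separators are strictly increasing and in
// the open range (0, logical_shape.size()).
std::vector<int64_t> FlattenShape(const std::vector<int64_t>& logical_shape,
                                  const std::vector<size_t>& axis_separators) {
  std::vector<int64_t> physical_shape;
  int64_t extent = 1;
  size_t next_sep = 0;
  for (size_t axis = 0; axis < logical_shape.size(); ++axis) {
    // Close the current group when a separator sits in front of this axis.
    if (next_sep < axis_separators.size() && axis_separators[next_sep] == axis) {
      physical_shape.push_back(extent);
      extent = 1;
      ++next_sep;
    }
    extent *= logical_shape[axis];
  }
  // The last group is always emitted, so no separators yields a 1-d shape.
  physical_shape.push_back(extent);
  return physical_shape;
}
```

Under these assumptions, a logical shape of `{2, 3, 4, 5}` with a single separator at axis 2 is grouped as `(2*3, 4*5)`, while an empty separator list collapses everything into one flat dimension.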


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] csullivan merged pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
csullivan merged PR #11225:
URL: https://github.com/apache/tvm/pull/11225




[GitHub] [tvm] quic-sanirudh commented on pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
quic-sanirudh commented on PR #11225:
URL: https://github.com/apache/tvm/pull/11225#issuecomment-1119853001

   > we can have a user convenience API backed by these two schedule primitives
   
   Ah okay, thanks for the reply @vinx13 




[GitHub] [tvm] csullivan commented on a diff in pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11225:
URL: https://github.com/apache/tvm/pull/11225#discussion_r866996238


##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   I can see this being helpful for some of the non-schedule TIR transformations. I don't see any src/tir/transforms including `../schedule/transform.h` but I do see some includes of `../schedule/utils.h`. Is `schedule/transform.h` an appropriate place for ReplaceBufferMutator given its potential use outside schedule transforms?





[GitHub] [tvm] junrushao1994 commented on a diff in pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on code in PR #11225:
URL: https://github.com/apache/tvm/pull/11225#discussion_r866406365


##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -17,6 +17,8 @@
  * under the License.
  */
 #include "../utils.h"
+#include "tvm/tir/schedule/block_scope.h"
+#include "tvm/tir/stmt.h"

Review Comment:
   nit:
   
   ```suggestion
   #include <tvm/tir/schedule/block_scope.h>
   #include <tvm/tir/stmt.h>
   ```





[GitHub] [tvm] tmoreau89 commented on pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on PR #11225:
URL: https://github.com/apache/tvm/pull/11225#issuecomment-1119634333

   cc @quic-sanirudh 




[GitHub] [tvm] csullivan commented on a diff in pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11225:
URL: https://github.com/apache/tvm/pull/11225#discussion_r866976226


##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }
+
+ protected:
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = buffer_var_map_.find(var);
+    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
+      ptr->buffer = it->second;
+      return PrimExpr(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
+      ptr->buffer = it->second;
+      return Stmt(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) {
+    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      return MatchBufferRegion(match_buffer->buffer,
+                               BufferRegion(it->second, match_buffer->source->region));
+    } else {
+      return match_buffer;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    // To reduce the number of blocks in block sref reuse map, we check whether the block is really
+    // mutated (i.e., the old buffer appears in the block). If so, we return the block after
+    // mutation. Otherwise we just return the original block.
+
+    auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
+      return this->VisitMatchBufferRegion(match_buffer);
+    };
+    auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
+      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer_region
+                                         : BufferRegion(it->second, buffer_region->region);
+    };
+    auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
+      auto it = buffer_var_map_.find(buffer->data.get());
+      return it == buffer_var_map_.end() ? buffer : it->second;
+    };
+
+    // Step 1. Mutate `match_buffers`. If the old buffer appears as the source of a
+    // MatchBufferRegion, the source is replaced with the new buffer.
+    Array<MatchBufferRegion> match_buffers =
+        MutateArray(block->match_buffers, f_mutate_match_buffer);
+    // Step 2. Mutate the read/write region.
+    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
+    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
+    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, f_mutate_alloc_buffers);
+    // Step 4. Recursively mutate the block.
+    Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
+
+    if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
+        writes.same_as(mutated_block->writes) &&
+        alloc_buffers.same_as(mutated_block->alloc_buffers) &&
+        match_buffers.same_as(mutated_block->match_buffers)) {
+      return GetRef<Block>(block);
+    } else {
+      ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
+      n->reads = std::move(reads);
+      n->writes = std::move(writes);
+      n->alloc_buffers = std::move(alloc_buffers);
+      n->match_buffers = std::move(match_buffers);
+
+      Block new_block(n);
+      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
+      return std::move(new_block);
+    }
+  }
+
+  /*! \brief The storage scope to be set. */
+  String storage_scope_;

Review Comment:
   Looks like `storage_scope_` is left over from refactoring StorageScopeMutator onto ReplaceBufferMutator; please delete it.



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }
+
+ protected:
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = buffer_var_map_.find(var);
+    return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* load) final {
+    BufferLoad res = Downcast<BufferLoad>(ExprMutator::VisitExpr_(load));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferLoadNode> ptr = make_object<BufferLoadNode>(*res.get());
+      ptr->buffer = it->second;
+      return PrimExpr(ptr);
+    } else {
+      return std::move(res);
+    }
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) final {
+    BufferStore res = Downcast<BufferStore>(StmtMutator::VisitStmt_(store));
+
+    auto it = buffer_var_map_.find(res->buffer->data.get());
+    if (it != buffer_var_map_.end()) {
+      ObjectPtr<BufferStoreNode> ptr = make_object<BufferStoreNode>(*res.get());
+      ptr->buffer = it->second;
+      return Stmt(ptr);
+    } else {
+      return std::move(res);
+    }
+  }

Review Comment:
   nit: A function template would enforce common code for both BufferLoadNode and BufferStoreNode visitors here. 
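
One way to read this suggestion is sketched below with stand-in types. Everything here is hypothetical and for illustration only: `Buffer`, `BufferLoadNode`, `BufferStoreNode`, `BufferMap`, and `ReplaceBufferIn` are simplified placeholders, not the TVM classes, and plain copies replace TVM's `make_object` and `ObjectPtr` machinery. The point is only that a single template can hold the lookup-and-replace step for any node type with a `buffer` field.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in types for illustration only; they are NOT the TVM classes.
struct Buffer {
  std::string name;
};
struct BufferLoadNode {
  std::shared_ptr<Buffer> buffer;
  int index;
};
struct BufferStoreNode {
  std::shared_ptr<Buffer> buffer;
  int index;
  int value;
};

using BufferMap = std::unordered_map<const Buffer*, std::shared_ptr<Buffer>>;

// Shared rewrite step: works for any node type that has a `buffer` field.
// Returns the node unchanged when its buffer is not in the map; otherwise
// returns a copy of the node that points at the replacement buffer.
template <typename Node>
Node ReplaceBufferIn(const Node& node, const BufferMap& buffer_map) {
  auto it = buffer_map.find(node.buffer.get());
  if (it == buffer_map.end()) {
    return node;
  }
  Node copy = node;
  copy.buffer = it->second;
  return copy;
}
```

In a real mutator, each of the two visitors would presumably first call its base-class visit and then forward the result to the shared template, so the replacement logic lives in one place.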



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   This is such a helpful visitor 💪 , thank you for introducing it. 



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   I can see this being helpful for some of the non-schedule TIR transformations. I don't see any src/tir/transforms including `../schedule/transform.h` but I do see some includes of `../schedule/utils.h`. Is this an appropriate place for the visitor given its potential use outside schedule transforms?



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }

Review Comment:
   I don't notice any other implementations defined in this header. Should we move the implementation to the source file?





[GitHub] [tvm] vinx13 commented on a diff in pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
vinx13 commented on code in PR #11225:
URL: https://github.com/apache/tvm/pull/11225#discussion_r867038052


##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {

Review Comment:
   `transform.h` is for transformational utilities; it is included in `../schedule/utils.h`.



##########
src/tir/schedule/transform.h:
##########
@@ -66,6 +72,122 @@ Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& sou
 Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                        const Buffer& target);
 
+/*!
+ * \brief A helper mutator which recursively replaces the old buffer with the new buffer and
+ * collects the block sref reuse information for the following replacement.
+ *
+ * If the buffer to be replaced is used as the source in `match_buffers`, depending on the specific
+ * use cases, the target buffers in `match_buffers` may also need to be mutated. In this
+ * case, this class should be subclassed to explicitly handle `match_buffers`.
+ */
+class ReplaceBufferMutator : public StmtExprMutator {
+ public:
+  ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
+                       Map<Block, Block>* block_sref_reuse)
+      : block_sref_reuse_(block_sref_reuse) {
+    buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
+  }

Review Comment:
   sounds good, updated





[GitHub] [tvm] vinx13 commented on pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
vinx13 commented on PR #11225:
URL: https://github.com/apache/tvm/pull/11225#issuecomment-1119827206

   @quic-sanirudh Thanks for the comments. The motivation for a separate schedule primitive is to decouple the logical-to-physical mapping from generic layout transformation, which doesn't have to deal with physical layout. @Lunderberg suggested that we can have a user convenience API backed by these two schedule primitives.




[GitHub] [tvm] csullivan commented on pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
csullivan commented on PR #11225:
URL: https://github.com/apache/tvm/pull/11225#issuecomment-1120051709

   Many thanks @vinx13 @quic-sanirudh @junrushao1994 @Lunderberg, this is merged!




[GitHub] [tvm] quic-sanirudh commented on pull request #11225: [TIR] Add schedule primitive SetAxisSeparator

Posted by GitBox <gi...@apache.org>.
quic-sanirudh commented on PR #11225:
URL: https://github.com/apache/tvm/pull/11225#issuecomment-1119688463

   @Lunderberg @vinx13 This is great, thanks a lot, I would love to start using this.
   
   I do have a small question, however. Is there a reason why we're introducing a new schedule primitive instead of allowing the user to pass something similar to `te.AXIS_SEPARATOR` to layout_transform and internally calling both `ScheduleTransformLayout` and `SetAxisSeparator`, as is done for TE?
   
   I ask because, although I discussed this once with Eric, the list passed to the `axis_separators` argument still seems like an internal detail that might confuse users. Also, once this detail is exposed in a user-facing API, it might be difficult for us to change how axis_separators are handled internally later if needed.
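
The marker-based alternative raised in this question can be sketched as a small standalone function. This is a hedged illustration under stated assumptions, not TVM code: `kAxisSeparator` and `SplitAxisSeparators` are hypothetical names, index expressions are kept as plain strings, and a separator is assumed to be recorded as the number of index expressions that precede it. A convenience API could then pass the two results on to the two underlying primitives.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sentinel that plays the role of a separator marker in the
// user-supplied list of output indices.
const std::string kAxisSeparator = "<axis_separator>";

// Splits a mixed list into (index expressions, axis_separators). Each separator
// is recorded as the count of index expressions seen before it.
std::pair<std::vector<std::string>, std::vector<size_t>> SplitAxisSeparators(
    const std::vector<std::string>& entries) {
  std::vector<std::string> indices;
  std::vector<size_t> axis_separators;
  for (const std::string& entry : entries) {
    if (entry == kAxisSeparator) {
      axis_separators.push_back(indices.size());
    } else {
      indices.push_back(entry);
    }
  }
  return std::make_pair(indices, axis_separators);
}
```

With this split, the user only writes the marker inline with the indices, and the integer list stays an internal detail of the wrapper.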

