Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/29 09:47:29 UTC

[GitHub] [tvm] multiverstack opened a new pull request, #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer.

multiverstack opened a new pull request, #12939:
URL: https://github.com/apache/tvm/pull/12939

   If a block both reads and writes a buffer, cache read and cache write cannot be applied to it separately; otherwise the dependency breaks. This primitive performs the cache read and the cache write together to keep the IR correct.
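
The dependency problem can be seen in a toy C++ model (plain `std::vector` code, not TVM IR; every name here is hypothetical). The opaque "block" is an in-place prefix sum over buffer A, so each iteration reads a value written by the previous one. Redirecting only the reads to a cache (a lone cache read) makes the block read stale data, while copying A into the cache, running the block entirely on the cache, and copying the cache back (cache read and cache write together) preserves the result.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy opaque block: an in-place prefix sum. Iteration i reads src[i - 1],
// which (when src and dst alias) was written by iteration i - 1.
void PrefixSumBlock(const std::vector<int>& src, std::vector<int>& dst) {
  for (std::size_t i = 1; i < dst.size(); ++i) {
    dst[i] = src[i] + src[i - 1];
  }
}

// Reference behavior: the block reads and writes A directly.
std::vector<int> RunDirect(std::vector<int> a) {
  PrefixSumBlock(a, a);
  return a;
}

// Cache read only: reads are redirected to a copy, writes still go to A,
// so later iterations see stale values.
std::vector<int> RunCacheReadOnly(std::vector<int> a) {
  const std::vector<int> a_cache = a;  // cache read stage: A -> cache
  PrefixSumBlock(a_cache, a);          // reads the cache, writes A
  return a;
}

// Cache read and cache write together: copy in, run on the cache, copy out.
std::vector<int> RunCacheInplace(std::vector<int> a) {
  std::vector<int> a_cache = a;      // cache read stage: A -> cache
  PrefixSumBlock(a_cache, a_cache);  // block reads and writes the cache
  a = a_cache;                       // cache write stage: cache -> A
  return a;
}
```

With input {1, 2, 3, 4}, RunDirect and RunCacheInplace both produce {1, 3, 6, 10}, while RunCacheReadOnly produces {1, 3, 5, 7}.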


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r990728454


##########
src/tir/schedule/primitive.h:
##########
@@ -267,6 +267,17 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
  */
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope);
+/*!
+ *!
+ * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
+ * \param self The state of the schedule
+ * \param block_sref The block operates on the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope
+ * \return The reindex stage block.

Review Comment:
   Thanks, fixed now.





[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
wrongtest-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r987854894


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -403,6 +403,15 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
+  /*!
+   * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
+   * \param block_rv The block operates on the target buffer.
+   * \param read_buffer_index The index of the buffer in block's read region.
+   * \param storage_scope The target storage scope
+   * \return The reindex stage block.

Review Comment:
   The return comment needs to be updated.



##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,

Review Comment:
   IIUC, the `write_buffer_index` parameter should be `read_buffer_index`?
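
To illustrate why the name matters, here is a hedged sketch with hypothetical types (not TVM's `BlockNode` or `GetNthAccessBuffer`): the index is resolved against the block's read region, so the same integer can name a different buffer, or be out of range, if it is interpreted against the write region instead.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified access regions of a block (buffer names only).
struct ToyBlock {
  std::vector<std::string> reads;   // read region, in declaration order
  std::vector<std::string> writes;  // write region, in declaration order
};

// Resolves the index against the *read* region, the role played by
// BufferIndexType::kRead in the GetNthAccessBuffer call under review.
const std::string& GetNthReadBuffer(const ToyBlock& block, int read_buffer_index) {
  if (read_buffer_index < 0 ||
      static_cast<std::size_t>(read_buffer_index) >= block.reads.size()) {
    throw std::out_of_range("read_buffer_index out of range");
  }
  return block.reads[static_cast<std::size_t>(read_buffer_index)];
}
```

For a block that reads {B, A} and writes {A}, index 1 selects A from the read region, whereas the write region has no index 1 at all.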



##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();
+  StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+  // Detect insert position
+  CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+  cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
+
+  // Cache read step 2. Making new cache stage block and rewrite readers.
+  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                          /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  // Cache read step 3. Replacing and updating flags for cache read.
+  self->Replace(scope_sref, new_scope, info.block_reuse);

Review Comment:
   Could we merge this with line 1324 to ensure an atomic state update?
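
One common way to get the atomicity asked for here is to apply every rewrite to a scratch copy and publish it in a single step. The sketch below models the schedule state as an ordered list of block names; `ToyState`, `CacheInplaceAtomic`, and the `fail_write_step` test hook are hypothetical and do not reflect how `self->Replace` works inside TVM.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical schedule state: the ordered block names of one scope.
struct ToyState {
  std::vector<std::string> blocks;
};

// Applies both rewrites to a scratch copy and publishes them with one swap.
// `fail_write_step` is a test hook that simulates a failure in step 2.
void CacheInplaceAtomic(ToyState& state, const std::string& consumer,
                        bool fail_write_step) {
  ToyState scratch = state;
  auto it = std::find(scratch.blocks.begin(), scratch.blocks.end(), consumer);
  if (it == scratch.blocks.end()) {
    throw std::invalid_argument("consumer block not found");
  }
  // Step 1: insert the cache read stage right before the consumer.
  it = scratch.blocks.insert(it, consumer + "_cache_read");
  // Step 2: insert the cache write stage right after the consumer.
  if (fail_write_step) {
    throw std::runtime_error("cache write step failed");
  }
  scratch.blocks.insert(it + 2, consumer + "_cache_write");
  // Commit: `state` changes only here, after both steps succeeded.
  std::swap(state, scratch);
}
```

If step 2 throws, the caller's state still holds the original block list, which is the property a single merged replacement would provide.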



##########
src/tir/schedule/primitive.h:
##########
@@ -267,6 +267,17 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
  */
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope);
+/*!
+ *!
+ * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
+ * \param self The state of the schedule
+ * \param block_sref The block operates on the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope
+ * \return The reindex stage block.

Review Comment:
   The return comment needs to be updated.



##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();

Review Comment:
   Could we check that the write region exists, as the API documentation describes?
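
A sketch of the requested check, using `std::optional` in place of TVM's `Optional` and a hypothetical `ToyRegion` type: rather than calling `.value()` unconditionally, the code tests for presence and reports a descriptive error. The same pattern would apply to `_write_block_sref.value()` a few lines earlier.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical buffer region: buffer name plus a 1-D [begin, end) range.
struct ToyRegion {
  std::string buffer;
  int begin;
  int end;
};

// Returns the region for `buffer`, or std::nullopt if it is not accessed.
std::optional<ToyRegion> FindRegion(const std::vector<ToyRegion>& regions,
                                    const std::string& buffer) {
  for (const ToyRegion& region : regions) {
    if (region.buffer == buffer) return region;
  }
  return std::nullopt;
}

// Checks presence explicitly instead of calling .value() unconditionally,
// so a missing write region produces a descriptive error.
ToyRegion GetWriteRegionOrThrow(const std::vector<ToyRegion>& writes,
                                const std::string& buffer) {
  std::optional<ToyRegion> region = FindRegion(writes, buffer);
  if (!region.has_value()) {
    throw std::runtime_error("the write block does not write buffer " + buffer);
  }
  return *region;
}
```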





[GitHub] [tvm] wrongtest-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
wrongtest-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r984430980


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1250,102 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();
+  StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+  // Detect insert position
+  CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+  cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
+
+  // Cache read step 2. Making new cache stage block and rewrite readers.
+  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                          /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  // Cache read step 3. Replacing and updating flags for cache read.
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
+  BlockInfo& block_info_read = self->block_info[result_block_sref];
+  block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
+  block_info_read.region_cover = true;
+  block_info_read.scope->stage_pipeline = true;

Review Comment:
   Why is `stage_pipeline` set to true here?
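
For context, a simplified model of why the flag matters. Assuming, as we read TVM's block-scope documentation, that a stage pipeline may not contain write-after-read dependencies between blocks, the three blocks produced by this primitive (cache read of A, the in-place consumer, cache write back to A) violate that condition: the cache read stage reads A and the later cache write stage writes A. The types and the check below are hypothetical and cover only this one condition.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical block summary: buffers read and written, no regions.
struct ToyBlock {
  std::string name;
  std::set<std::string> reads;
  std::set<std::string> writes;
};

// True if a later block writes a buffer that an earlier block reads, i.e.
// the sequence (in program order) has a write-after-read dependency.
bool HasWriteAfterRead(const std::vector<ToyBlock>& blocks) {
  for (std::size_t i = 0; i < blocks.size(); ++i) {
    for (std::size_t j = i + 1; j < blocks.size(); ++j) {
      for (const std::string& buffer : blocks[i].reads) {
        if (blocks[j].writes.count(buffer) > 0) return true;
      }
    }
  }
  return false;
}
```

A plain producer/consumer pair has no such dependency, while the cache-read / consumer / cache-write sequence does.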



##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -563,8 +649,17 @@ class CacheReadRewriter : public StmtExprMutator {
     if (block == scope_sref_->stmt) {
       // If so, put buffer allocation on the parent scope
       ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
-      n->alloc_buffers.push_back(info_->alloc);
-      stmt = Block(n);
+      bool alloc_buffer_exists = false;
+      for (const Buffer& it : n->alloc_buffers) {
+        if (it.same_as(info_->alloc)) {
+          alloc_buffer_exists = true;
+        }
+      }
+      // In cache_buffer case, alloc_buffer may be already exits.

Review Comment:
   If alloc_buffer must already exist in the `cache_buffer` case, maybe we can change `info_->alloc` to `Optional<Buffer>` and leave it undefined?
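
A sketch of this suggestion, with `std::optional<std::string>` standing in for `Optional<Buffer>` and strings standing in for buffers (`ToyStageInfo` and `AddAllocation` are hypothetical): an empty `alloc` means "already allocated, add nothing", which removes the need for the duplicate scan in the hunk above.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the cache stage info; only the alloc field.
struct ToyStageInfo {
  // Left empty when the cache buffer is already allocated in the scope root,
  // which the diff comment says can happen in the cache_buffer case.
  std::optional<std::string> alloc;
};

// Adds an allocation only when one is requested, so no duplicate scan over
// alloc_buffers is needed.
void AddAllocation(std::vector<std::string>& alloc_buffers, const ToyStageInfo& info) {
  if (info.alloc.has_value()) {
    alloc_buffers.push_back(*info.alloc);
  }
}
```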





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r985033912


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -563,8 +649,17 @@ class CacheReadRewriter : public StmtExprMutator {
     if (block == scope_sref_->stmt) {
       // If so, put buffer allocation on the parent scope
       ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
-      n->alloc_buffers.push_back(info_->alloc);
-      stmt = Block(n);
+      bool alloc_buffer_exists = false;
+      for (const Buffer& it : n->alloc_buffers) {
+        if (it.same_as(info_->alloc)) {
+          alloc_buffer_exists = true;
+        }
+      }
+      // In cache_buffer case, alloc_buffer may be already exits.

Review Comment:
   Good idea!





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r990728730


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();
+  StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+  // Detect insert position
+  CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+  cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
+
+  // Cache read step 2. Making new cache stage block and rewrite readers.
+  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                          /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  // Cache read step 3. Replacing and updating flags for cache read.
+  self->Replace(scope_sref, new_scope, info.block_reuse);

Review Comment:
   This seems to require a big change. I'll add more checks at the beginning to prevent a failure midway.
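
A sketch of the validate-first approach described in this reply: gather every precondition failure before touching the state, so the primitive either rejects the request up front or runs both rewrites. The fields of the hypothetical `ToyRequest` are illustrative guesses at the relevant checks, not the list that was actually committed.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical summary of what the primitive needs to know up front.
struct ToyRequest {
  int read_buffer_index;
  int num_reads;             // size of the block's read region
  bool block_writes_buffer;  // the block also writes the target buffer
  int num_writer_blocks;     // blocks in the scope that write the buffer
  std::string storage_scope;
};

// Collects every violated precondition before any state is mutated, so the
// primitive can reject a request instead of failing between the cache-read
// replacement and the cache-write replacement.
std::vector<std::string> CheckPreconditions(const ToyRequest& req) {
  std::vector<std::string> errors;
  if (req.read_buffer_index < 0 || req.read_buffer_index >= req.num_reads) {
    errors.push_back("read_buffer_index out of range");
  }
  if (!req.block_writes_buffer) {
    errors.push_back("block does not write the target buffer");
  }
  if (req.num_writer_blocks != 1) {
    errors.push_back("target buffer must have exactly one writer block");
  }
  if (req.storage_scope.empty()) {
    errors.push_back("storage scope is empty");
  }
  return errors;
}
```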





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r990728471


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,

Review Comment:
   Thanks, fixed now.





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r990728750


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1238,100 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();

Review Comment:
   Thanks, more checks were added in the new commit.





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r990728418


##########
include/tvm/tir/schedule/schedule.h:
##########
@@ -403,6 +403,15 @@ class ScheduleNode : public runtime::Object {
    */
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope) = 0;
+  /*!
+   * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
+   * \param block_rv The block operates on the target buffer.
+   * \param read_buffer_index The index of the buffer in block's read region.
+   * \param storage_scope The target storage scope
+   * \return The reindex stage block.

Review Comment:
   Thanks, fixed now.





[GitHub] [tvm] multiverstack-intellif commented on a diff in pull request #12939: [TIR][Schedule] Add cache_buffer primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
multiverstack-intellif commented on code in PR #12939:
URL: https://github.com/apache/tvm/pull/12939#discussion_r985033884


##########
src/tir/schedule/primitive/cache_read_write.cc:
##########
@@ -1146,6 +1250,102 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
   return result_block_sref;
 }
 
+Array<StmtSRef> CacheBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
+                            const String& storage_scope) {
+  /*!
+   * Do cache read then cache write
+   */
+
+  // Check 0. Check the input storage scope.
+  CheckStorageScope(self, storage_scope);
+
+  // Check 1. Check index, get the target buffer and the parent scope
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
+  Buffer buffer =
+      GetNthAccessBuffer(self, GetRef<Block>(block), write_buffer_index, BufferIndexType::kRead);
+  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
+
+  // Check 3. Check required region cover for cache_read
+  CheckRegionCover(self, scope_sref);
+
+  Array<StmtSRef> results_block_sref;
+  Buffer new_buffer = WithScope(buffer, storage_scope);
+
+  // Do cache read
+  // Cache read step 0. Create CacheStageInfo
+  CacheStageInfo info;
+  info.read_buffer = buffer;
+  // Create the corresponding buffer to be written for cache_read
+  info.write_buffer = new_buffer;
+  // Create the corresponding buffer allocation
+  info.alloc = info.write_buffer;
+  // Indicate which buffers should consume the cache.
+  info.consumer_blocks.push_back(block_sref);
+
+  // Cache read step 1. Update cache stage info for cache_read.
+  BufferRegion cache_region{nullptr};
+  Optional<StmtSRef> _write_block_sref = GetOnlyWriteBlock(self, scope_sref, buffer);
+
+  StmtSRef write_block_sref = _write_block_sref.value();
+  const BlockNode* write_block = TVM_SREF_TO_BLOCK(write_block_sref);
+  // Find the producing region
+  BufferRegion region = GetBufferRegionFromBuffer(write_block->writes, buffer).value();
+  StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);
+
+  // Detect insert position
+  CacheBufferLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+  cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
+
+  // Cache read step 2. Making new cache stage block and rewrite readers.
+  Block cache_read_stage = MakeCacheStage(/*cache_region=*/cache_region, /*info=*/&info,
+                                          /*storage_scope=*/storage_scope);
+  Stmt new_scope = CacheReadRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);
+
+  // Cache read step 3. Replacing and updating flags for cache read.
+  self->Replace(scope_sref, new_scope, info.block_reuse);
+  StmtSRef result_block_sref = self->stmt2ref.at(cache_read_stage.get());
+  BlockInfo& block_info_read = self->block_info[result_block_sref];
+  block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
+  block_info_read.region_cover = true;
+  block_info_read.scope->stage_pipeline = true;

Review Comment:
   My mistake, it should always be false here.





[GitHub] [tvm] Hzfengsy merged pull request #12939: [TIR][Schedule] Add cache_inplace primitive to cache opaque buffer

Posted by GitBox <gi...@apache.org>.
Hzfengsy merged PR #12939:
URL: https://github.com/apache/tvm/pull/12939

