Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/15 20:25:55 UTC

[tvm] branch main updated: [MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c30bda738 [MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)
8c30bda738 is described below

commit 8c30bda738eb0b07c0457b6ee651f3f32857903b
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Nov 16 05:25:48 2022 +0900

    [MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)
    
    Currently, when `CacheRead` and `RewriteLayout` are used together, the index map is derived from the cache-read block, which leads to an unintended result. This is because the current implementation assumes that the "layout-free" buffer is directly consumed by an "anchor" op such as conv2d or dense.
    
    When `CacheRead` is involved, we instead need to find the index map for the cache-read buffer as it is consumed by the anchor op, and apply the same transformation to the layout-free buffer. My solution supports the more general case where multiple cache reads form a "chain" of blocks, starting from the one that directly consumes the layout-free buffer passed as a parameter; the layout transformation is then back-propagated over that chain.
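    
    To illustrate, below is a minimal Python sketch (not part of this commit) of the
    schedule transformations the postproc now performs in the cache-read case. It is
    written against the Conv2dCacheRead test module added in this PR; the block names,
    buffer indices, and index map are taken from that test, whereas a real run derives
    the index map via SuggestIndexMap and handles chains of arbitrary length in C++.
    
        from tvm import tir
    
        def rewrite_layout_with_cache_read(sch: tir.Schedule) -> None:
            # Index map suggested for the (3, 3, 64, 64) weight p1 when consumed by
            # the conv2d anchor op (copied from the expected output of the new test).
            index_map = lambda v0, v1, v2, v3: (
                v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2,
            )
    
            # 1. Transform the layout of the cache-read buffer read by the anchor op
            #    ("p1_global" is the third read buffer of "conv2d_nhwc_update").
            anchor = sch.get_block("conv2d_nhwc_update")
            sch.transform_layout(anchor, ("read", 2), index_map)
    
            # 2. The chain here is a single cache-read block, "p1_global". Insert a
            #    layout-rewrite block in front of it and mark it as a preproc block.
            cache_read = sch.get_block("p1_global")
            rewrite_block = sch.cache_read(cache_read, 0, "global")
            sch.annotate(rewrite_block, "meta_schedule.layout_rewrite_preproc", True)
    
            # 3. Back-propagate the same transformation to the layout-free parameter p1.
            sch.transform_layout(cache_read, ("read", 0), index_map)
    
    Running this on tir.Schedule(Conv2dCacheRead) should yield a module close to the
    Conv2dCacheReadRewritten expectation in the test file below.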
---
 src/meta_schedule/postproc/rewrite_layout.cc       | 159 ++++++++----
 src/te/operation/create_primfunc.cc                |   9 +-
 .../test_meta_schedule_postproc_rewrite_layout.py  | 276 +++++++++++++++++++++
 tests/python/unittest/test_te_create_primfunc.py   |   1 +
 4 files changed, 390 insertions(+), 55 deletions(-)

diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 3aed6680e3..71ae433871 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <optional>
 #include <unordered_set>
 
 #include "../utils.h"
@@ -25,23 +26,15 @@ namespace tir {
 
 /*!
  * \brief Collect the block and index where the buffer is read.
- * \note The buffers are expected to be read by only one BufferLoad
+ * \note The buffer is expected to be read by only one BufferLoad
  */
 class BufferReadPosCollector : public StmtExprVisitor {
  public:
-  explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
-    for (const Buffer& buf : buffers) {
-      buffers_.insert(buf.get());
-    }
-  }
+  explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {}
 
-  const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
-    return buffer_locs_;
-  }
+  const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; }
 
-  const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
-    return buffer_index_maps_;
-  }
+  const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; }
 
  private:
   void VisitStmt_(const ForNode* op) final {
@@ -61,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor {
     CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";
 
     const Buffer& buffer = op->buffer;
-    if (buffers_.count(buffer.get())) {
+    if (buffer_ == buffer.get()) {
       Map<Var, PrimExpr> subst_map;
       for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
         const Var& var = cur_realize_->block->iter_vars[i]->var;
@@ -72,14 +65,14 @@ class BufferReadPosCollector : public StmtExprVisitor {
       for (const PrimExpr& e : op->indices) {
         subst_indices.push_back(Substitute(e, subst_map));
       }
-      buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer,                      //
-                                                         /*indices=*/subst_indices,              //
-                                                         /*loops=*/loop_stack_,                  //
-                                                         /*predicate=*/cur_realize_->predicate,  //
-                                                         /*analyzer=*/&analyzer_);
+      buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer,                      //
+                                          /*indices=*/subst_indices,              //
+                                          /*loops=*/loop_stack_,                  //
+                                          /*predicate=*/cur_realize_->predicate,  //
+                                          /*analyzer=*/&analyzer_);
       int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
       ICHECK(buffer_index != -1);
-      buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
+      buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index);
     }
   }
 
@@ -93,12 +86,12 @@ class BufferReadPosCollector : public StmtExprVisitor {
   }
 
  private:
-  /*! \brief All interested buffer. */
-  std::unordered_set<const BufferNode*> buffers_;
-  /*! \brief The result mapping from buffer to its inner-most block and read index. */
-  std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
-  /*! \brief The result mapping from buffer to its IndexMap. */
-  std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
+  /*! \brief The buffer of interest. */
+  const BufferNode* buffer_;
+  /*! \brief The block that consumes the buffer and the corresponding read index. */
+  std::pair<Block, int> buffer_loc_;
+  /*! \brief The proposed IndexMap. */
+  Optional<IndexMap> buffer_index_map_;
 
   /*! \brief Loop stack for calculating IndexMap. */
   Array<For> loop_stack_;
@@ -143,8 +136,56 @@ Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
   return layout_free_buffers;
 }
 
+std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap(
+    Buffer buffer, const PrimFuncNode* prim_func) {
+  BufferReadPosCollector collector(buffer);
+  collector(prim_func->body);
+
+  const auto& index_map = collector.GetBufferIndexMap();
+
+  if (!index_map.defined() || !index_map) {
+    return std::nullopt;
+  }
+
+  const auto& [anchor_block, buffer_index] = collector.GetBufferLocation();
+
+  return std::make_tuple(anchor_block, buffer_index, index_map.value());
+}
+
+/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */
+std::vector<std::string> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) {
+  class BufferReadChainCollector : public StmtVisitor {
+   public:
+    explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {}
+
+    void VisitStmt_(const BlockNode* op) final {
+      // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_.
+      if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 &&
+          op->reads[0]->buffer.get() == cur_buffer_) {
+        cache_read_chain.push_back(op->name_hint);
+        cur_buffer_ = op->writes[0]->buffer.get();
+      }
+      StmtVisitor::VisitStmt_(op);
+    }
+
+    std::vector<std::string> cache_read_chain;
+
+   private:
+    const BufferNode* cur_buffer_;
+  };
+
+  BufferReadChainCollector collector(buf);
+  collector(prim_func->body);
+  return collector.cache_read_chain;
+}
+
 bool RewriteLayout(const Schedule& sch) {
   std::vector<std::pair<StmtSRef, String>> results;
+  auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) {
+    BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global");
+    sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
+  };
+
   for (const auto& [g_var, base_func] : sch->mod()->functions) {
     const String& func_name = g_var->name_hint;
     const auto* prim_func = base_func.as<PrimFuncNode>();
@@ -153,36 +194,46 @@ bool RewriteLayout(const Schedule& sch) {
       continue;
     }
 
-    Array<Buffer> layout_free_buffers = CollectLayoutFreeBuffers(prim_func);
-
-    // Collect Buffer read positions
-    BufferReadPosCollector collector(layout_free_buffers);
-    collector(prim_func->body);
-    const auto& locations = collector.GetBufferLocations();
-    const auto& index_maps = collector.GetBufferIndexMap();
-    // Check all buffers are collected
-    if (locations.size() != layout_free_buffers.size() ||
-        index_maps.size() != layout_free_buffers.size()) {
-      return false;
-    }
-
-    for (const auto& kv : locations) {
-      const Buffer& buffer = GetRef<Buffer>(kv.first);
-      const Block& block = kv.second.first;
-      int buffer_index = kv.second.second;
-
-      // Get IndexMap
-      const Optional<IndexMap> index_map = index_maps.at(buffer.get());
-      if (!index_map.defined()) {
-        continue;
+    for (auto buffer : CollectLayoutFreeBuffers(prim_func)) {
+      const auto cache_read_chain = GetCacheReadChain(buffer, prim_func);
+      if (cache_read_chain.empty()) {
+        // The common case, where the layout-free buffer is directly consumed by an anchor op such
+        // as conv2d or dense.
+        auto tup_opt = GetSuggestedIndexMap(buffer, prim_func);
+        if (tup_opt == std::nullopt) continue;
+
+        auto [anchor_block, buffer_index, index_map] = *tup_opt;
+        auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name);
+        add_layout_rewrite_block(anchor_block_rv, buffer_index);
+        sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map,
+                             NullOpt);
+      } else {
+        // When the layout-free buffer is consumed by cache_read, we need to find the index map
+        // for a cache-read buffer that is directly consumed by an anchor op. The last buffer
+        // in cache_read_chain corresponds to that buffer.
+        Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name));
+        ICHECK_EQ(cache_read_block->writes.size(), 1);
+        auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func);
+        if (tup_opt == std::nullopt) continue;
+
+        auto [anchor_block, buffer_index, index_map] = *tup_opt;
+        // Transform the layout of the last cache-read buffer.
+        sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index,
+                             BufferIndexType::kRead, index_map, NullOpt);
+
+        // Propagate the layout transformation over cache_read_chain, starting from
+        // the next-to-last cache-read buffer.
+        for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) {
+          BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
+          if (i == 0) {
+            // Before the first cache_read that consumes the layout-free buffer, insert
+            // a layout-rewrite block. Another cache-read buffer is added, and its layout is
+            // transformed by TransformLayout below.
+            add_layout_rewrite_block(cache_read_block_rv, 0);
+          }
+          sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt);
+        }
       }
-
-      // Apply schedule
-      BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
-      BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
-      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
-                           NullOpt);
-      sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
     }
   }
   return true;
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 80da5a7279..0581ad60e8 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
     BlockNode* n = block.CopyOnWrite();
     if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
+      Array<Buffer> new_buffers;
       for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
         auto it = buffer2index_.find(buffer);
         if (it != buffer2index_.end()) {
           layout_free_buffer_indices_.insert(it->second);
+        } else {
+          new_buffers.push_back(buffer);
         }
       }
-      n->annotations.erase(topi_attr);
+      if (new_buffers.empty()) {
+        n->annotations.erase(topi_attr);
+      } else {
+        n->annotations.Set(topi_attr, new_buffers);
+      }
     }
     for (const String& attr : this->blocklist) {
       auto it = n->annotations.find(attr);
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
index 91a51c8e90..98c1f73685 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -204,5 +204,281 @@ def test_layout_rewrite():
     tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul)
 
 
+# fmt: off
+@tvm.script.ir_module
+class Conv2dCacheRead:
+    @T.prim_func
+    def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+        pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+        pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        p1_global = T.alloc_buffer([3, 3, 64, 64], dtype="float32")
+        for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+            for ax0, ax1, ax2 in T.grid(1, 30, 30):
+                for ax3_fused in T.vectorized(64):
+                    with T.block("pad_temp"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+                        i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+                        i3 = T.axis.spatial(64, ax3_fused)
+                        T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+                        T.writes(pad_temp[i0, i1, i2, i3])
+                        pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+            for i3_0 in T.serial(16):
+                for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+                    with T.block("pad_temp_global"):
+                        v0 = T.axis.spatial(1, 0)
+                        v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+                        v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                        T.reads(pad_temp[v0, v1, v2, v3])
+                        T.writes(pad_temp_global[v0, v1, v2, v3])
+                        pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+                for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+                    with T.block("p1_global"):
+                        v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+                        v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+                        v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+                        v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+                        T.reads(p1[v0, v1, v2, v3])
+                        T.writes(p1_global[v0, v1, v2, v3])
+                        p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3]
+                for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+                    for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+                        for i3_3_fused_init in T.vectorized(2):
+                            with T.block("conv2d_nhwc_init"):
+                                nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+                                xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+                                T.reads()
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+                    for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+                        for i3_3_fused in T.vectorized(2):
+                            with T.block("conv2d_nhwc_update"):
+                                nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+                                xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+                                ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                                rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+                                rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+                                T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ry, rx, rc, ff])
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ry, rx, rc, ff]
+                    for ax0, ax1, ax2 in T.grid(1, 4, 14):
+                        for ax3_fused in T.vectorized(4):
+                            with T.block("conv2d_nhwc_global"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+                                v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+                                v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+                                T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+                                T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+
+@tvm.script.ir_module
+class Conv2dCacheReadRewritten:
+    @T.prim_func
+    def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+        pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+        pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+        p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+        for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+            with T.block("p1_global"):
+                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                T.reads(p1[v0, v1, v2, v3])
+                T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                T.block_attr({"meta_schedule.layout_rewrite_preproc":True})
+                p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3]
+        for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+            for ax0, ax1, ax2 in T.grid(1, 30, 30):
+                for ax3_fused in T.vectorized(64):
+                    with T.block("pad_temp"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+                        i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+                        i3 = T.axis.spatial(64, ax3_fused)
+                        T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+                        T.writes(pad_temp[i0, i1, i2, i3])
+                        pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+            for i3_0 in T.serial(16):
+                for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+                    with T.block("pad_temp_global"):
+                        v0 = T.axis.spatial(1, 0)
+                        v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+                        v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                        T.reads(pad_temp[v0, v1, v2, v3])
+                        T.writes(pad_temp_global[v0, v1, v2, v3])
+                        pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+                for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+                    with T.block("p1_global"):
+                        v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+                        v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+                        v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+                        v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+                        T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                        T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                        p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+                for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+                    for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+                        for i3_3_fused_init in T.vectorized(2):
+                            with T.block("conv2d_nhwc_init"):
+                                nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+                                xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+                                T.reads()
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+                    for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+                        for i3_3_fused in T.vectorized(2):
+                            with T.block("conv2d_nhwc_update"):
+                                nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+                                xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+                                ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                                rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+                                rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+                                T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2])
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]
+                    for ax0, ax1, ax2 in T.grid(1, 4, 14):
+                        for ax3_fused in T.vectorized(4):
+                            with T.block("conv2d_nhwc_global"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+                                v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+                                v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+                                T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+                                T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+
+@tvm.script.ir_module
+class Conv2dCacheReadMultipleRewritten:
+    @T.prim_func
+    def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+        T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+        pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+        pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+        p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+        p1_global2 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32", scope="global2")
+        p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+        for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+            with T.block("p1_global"):
+                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                T.reads(p1[v0, v1, v2, v3])
+                T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                T.block_attr({"meta_schedule.layout_rewrite_preproc":True})
+                p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3]
+        for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+            with T.block("p1_global2"):
+                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                T.writes(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+        for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+            for ax0, ax1, ax2 in T.grid(1, 30, 30):
+                for ax3_fused in T.vectorized(64):
+                    with T.block("pad_temp"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+                        i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+                        i3 = T.axis.spatial(64, ax3_fused)
+                        T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+                        T.writes(pad_temp[i0, i1, i2, i3])
+                        pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+            for i3_0 in T.serial(16):
+                for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+                    with T.block("pad_temp_global"):
+                        v0 = T.axis.spatial(1, 0)
+                        v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+                        v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+                        v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+                        T.reads(pad_temp[v0, v1, v2, v3])
+                        T.writes(pad_temp_global[v0, v1, v2, v3])
+                        pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+                for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+                    with T.block("p1_global"):
+                        v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+                        v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+                        v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+                        v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+                        T.reads(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                        T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+                        p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+                for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+                    for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+                        for i3_3_fused_init in T.vectorized(2):
+                            with T.block("conv2d_nhwc_init"):
+                                nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+                                xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+                                T.reads()
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+                    for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+                        for i3_3_fused in T.vectorized(2):
+                            with T.block("conv2d_nhwc_update"):
+                                nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+                                yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+                                xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+                                ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+                                ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                                rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+                                rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+                                T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2])
+                                T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+                                T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                                conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]
+                    for ax0, ax1, ax2 in T.grid(1, 4, 14):
+                        for ax3_fused in T.vectorized(4):
+                            with T.block("conv2d_nhwc_global"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+                                v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+                                v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+                                T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+                                T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+# fmt: on
+
+
+def test_layout_rewrite_cache_read():
+    target = Target("llvm")
+    ctx = _create_context(Conv2dCacheRead, target)
+    sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all")
+    sch.enter_postproc()
+    assert ctx.space_generator.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadRewritten)
+
+
+def test_layout_rewrite_cache_read_multiple():
+    target = Target("llvm")
+    ctx = _create_context(Conv2dCacheRead, target)
+    sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all")
+    sch.cache_read(sch.get_block("p1_global"), 0, "global2")
+    sch.enter_postproc()
+    assert ctx.space_generator.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index b59880758e..7b8173d0b2 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -390,6 +390,7 @@ def expected_layout_attr(
             C[x, y] = C[x, y] + A[x, k] * B[y, k]
     for i0, i1 in T.grid(128, 128):
         with T.block("D"):
+            T.block_attr({"layout_free_placeholders": [C]})
             x, y = T.axis.remap("SS", [i0, i1])
             D[x, y] = C[x, y] + T.float32(1)