You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/11/15 20:25:55 UTC
[tvm] branch main updated: [MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8c30bda738 [MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)
8c30bda738 is described below
commit 8c30bda738eb0b07c0457b6ee651f3f32857903b
Author: masahi <ma...@gmail.com>
AuthorDate: Wed Nov 16 05:25:48 2022 +0900
[MetaSchedule] Support schedules with cache read in RewriteLayout (#13384)
Currently, when `CacheRead` and `RewriteLayout` are used together, the index map is derived from the cache-read block rather than from the op that actually consumes the data, which leads to an unexpected layout transformation. This is because the current implementation assumes that the "layout-free" buffer is directly consumed by an "anchor" op such as conv2d or dense.
When `CacheRead` is involved, we need to find the index map for the cache-read buffer as it is consumed by the anchor op, and then apply the same transformation to the layout-free buffer. My solution also supports the more general case where multiple cache reads form a "chain" of blocks, starting from the one that directly consumes the layout-free buffer passed as a parameter. The layout transformation is then back-propagated along that chain.
---
src/meta_schedule/postproc/rewrite_layout.cc | 159 ++++++++----
src/te/operation/create_primfunc.cc | 9 +-
.../test_meta_schedule_postproc_rewrite_layout.py | 276 +++++++++++++++++++++
tests/python/unittest/test_te_create_primfunc.py | 1 +
4 files changed, 390 insertions(+), 55 deletions(-)
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 3aed6680e3..71ae433871 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <optional>
#include <unordered_set>
#include "../utils.h"
@@ -25,23 +26,15 @@ namespace tir {
/*!
* \brief Collect the block and index where the buffer is read.
- * \note The buffers are expected to be read by only one BufferLoad
+ * \note The buffer is expected to be read by only one BufferLoad
*/
class BufferReadPosCollector : public StmtExprVisitor {
public:
- explicit BufferReadPosCollector(const Array<Buffer>& buffers) {
- for (const Buffer& buf : buffers) {
- buffers_.insert(buf.get());
- }
- }
+ explicit BufferReadPosCollector(const Buffer& buffer) : buffer_(buffer.get()) {}
- const std::unordered_map<const BufferNode*, std::pair<Block, int>>& GetBufferLocations() const {
- return buffer_locs_;
- }
+ const std::pair<Block, int>& GetBufferLocation() const { return buffer_loc_; }
- const std::unordered_map<const BufferNode*, Optional<IndexMap>>& GetBufferIndexMap() const {
- return buffer_index_maps_;
- }
+ const Optional<IndexMap> GetBufferIndexMap() const { return buffer_index_map_; }
private:
void VisitStmt_(const ForNode* op) final {
@@ -61,7 +54,7 @@ class BufferReadPosCollector : public StmtExprVisitor {
CHECK(cur_realize_.defined()) << "BufferLoad occurred outside of any block";
const Buffer& buffer = op->buffer;
- if (buffers_.count(buffer.get())) {
+ if (buffer_ == buffer.get()) {
Map<Var, PrimExpr> subst_map;
for (size_t i = 0; i < cur_realize_->iter_values.size(); i++) {
const Var& var = cur_realize_->block->iter_vars[i]->var;
@@ -72,14 +65,14 @@ class BufferReadPosCollector : public StmtExprVisitor {
for (const PrimExpr& e : op->indices) {
subst_indices.push_back(Substitute(e, subst_map));
}
- buffer_index_maps_[buffer.get()] = SuggestIndexMap(/*buffer=*/buffer, //
- /*indices=*/subst_indices, //
- /*loops=*/loop_stack_, //
- /*predicate=*/cur_realize_->predicate, //
- /*analyzer=*/&analyzer_);
+ buffer_index_map_ = SuggestIndexMap(/*buffer=*/buffer, //
+ /*indices=*/subst_indices, //
+ /*loops=*/loop_stack_, //
+ /*predicate=*/cur_realize_->predicate, //
+ /*analyzer=*/&analyzer_);
int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer);
ICHECK(buffer_index != -1);
- buffer_locs_[buffer.get()] = std::make_pair(cur_realize_->block, buffer_index);
+ buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index);
}
}
@@ -93,12 +86,12 @@ class BufferReadPosCollector : public StmtExprVisitor {
}
private:
- /*! \brief All interested buffer. */
- std::unordered_set<const BufferNode*> buffers_;
- /*! \brief The result mapping from buffer to its inner-most block and read index. */
- std::unordered_map<const BufferNode*, std::pair<Block, int>> buffer_locs_;
- /*! \brief The result mapping from buffer to its IndexMap. */
- std::unordered_map<const BufferNode*, Optional<IndexMap>> buffer_index_maps_;
+ /*! \brief The buffer of interest. */
+ const BufferNode* buffer_;
+ /*! \brief The block that consumes the buffer and the corresponding read index. */
+ std::pair<Block, int> buffer_loc_;
+ /*! \brief The proposed IndexMap. */
+ Optional<IndexMap> buffer_index_map_;
/*! \brief Loop stack for calculating IndexMap. */
Array<For> loop_stack_;
@@ -143,8 +136,56 @@ Array<Buffer> CollectLayoutFreeBuffers(const PrimFuncNode* func) {
return layout_free_buffers;
}
+std::optional<std::tuple<Block, int, IndexMap>> GetSuggestedIndexMap(
+ Buffer buffer, const PrimFuncNode* prim_func) {
+ BufferReadPosCollector collector(buffer);
+ collector(prim_func->body);
+
+ const auto& index_map = collector.GetBufferIndexMap();
+
+ if (!index_map.defined() || !index_map) {
+ return std::nullopt;
+ }
+
+ const auto& [anchor_block, buffer_index] = collector.GetBufferLocation();
+
+ return std::make_tuple(anchor_block, buffer_index, index_map.value());
+}
+
+/*! \brief Get a chain of cache-read blocks, starting from the one consuming buf. */
+std::vector<std::string> GetCacheReadChain(const Buffer& buf, const PrimFuncNode* prim_func) {
+ class BufferReadChainCollector : public StmtVisitor {
+ public:
+ explicit BufferReadChainCollector(const Buffer& buffer) : cur_buffer_(buffer.get()) {}
+
+ void VisitStmt_(const BlockNode* op) final {
+ // Check if this block is doing cache_read or a similar operation that consumes cur_buffer_.
+ if (!op->init && op->reads.size() == 1 && op->writes.size() == 1 &&
+ op->reads[0]->buffer.get() == cur_buffer_) {
+ cache_read_chain.push_back(op->name_hint);
+ cur_buffer_ = op->writes[0]->buffer.get();
+ }
+ StmtVisitor::VisitStmt_(op);
+ }
+
+ std::vector<std::string> cache_read_chain;
+
+ private:
+ const BufferNode* cur_buffer_;
+ };
+
+ BufferReadChainCollector collector(buf);
+ collector(prim_func->body);
+ return collector.cache_read_chain;
+}
+
bool RewriteLayout(const Schedule& sch) {
std::vector<std::pair<StmtSRef, String>> results;
+ auto add_layout_rewrite_block = [&sch](BlockRV consumer_block_rv, int buffer_index) {
+ BlockRV rewrite_block_rv = sch->CacheRead(consumer_block_rv, buffer_index, "global");
+ sch->Annotate(rewrite_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
+ };
+
for (const auto& [g_var, base_func] : sch->mod()->functions) {
const String& func_name = g_var->name_hint;
const auto* prim_func = base_func.as<PrimFuncNode>();
@@ -153,36 +194,46 @@ bool RewriteLayout(const Schedule& sch) {
continue;
}
- Array<Buffer> layout_free_buffers = CollectLayoutFreeBuffers(prim_func);
-
- // Collect Buffer read positions
- BufferReadPosCollector collector(layout_free_buffers);
- collector(prim_func->body);
- const auto& locations = collector.GetBufferLocations();
- const auto& index_maps = collector.GetBufferIndexMap();
- // Check all buffers are collected
- if (locations.size() != layout_free_buffers.size() ||
- index_maps.size() != layout_free_buffers.size()) {
- return false;
- }
-
- for (const auto& kv : locations) {
- const Buffer& buffer = GetRef<Buffer>(kv.first);
- const Block& block = kv.second.first;
- int buffer_index = kv.second.second;
-
- // Get IndexMap
- const Optional<IndexMap> index_map = index_maps.at(buffer.get());
- if (!index_map.defined()) {
- continue;
+ for (auto buffer : CollectLayoutFreeBuffers(prim_func)) {
+ const auto cache_read_chain = GetCacheReadChain(buffer, prim_func);
+ if (cache_read_chain.empty()) {
+ // The common case, where the layout-free buffer is directly consumed by an anchor op such
+ // as conv2d or dense.
+ auto tup_opt = GetSuggestedIndexMap(buffer, prim_func);
+ if (tup_opt == std::nullopt) continue;
+
+ auto [anchor_block, buffer_index, index_map] = *tup_opt;
+ auto anchor_block_rv = sch->GetBlock(anchor_block->name_hint, func_name);
+ add_layout_rewrite_block(anchor_block_rv, buffer_index);
+ sch->TransformLayout(anchor_block_rv, buffer_index, BufferIndexType::kRead, index_map,
+ NullOpt);
+ } else {
+ // When the layout-free buffer is consumed by cache_read, we need to find the index map
+ // for a cache-read buffer that is directly consumed by an anchor op. The last buffer
+ // in cache_read_chain corresponds to that buffer.
+ Block cache_read_block = sch->Get(sch->GetBlock(cache_read_chain.back(), func_name));
+ ICHECK_EQ(cache_read_block->writes.size(), 1);
+ auto tup_opt = GetSuggestedIndexMap(cache_read_block->writes[0]->buffer, prim_func);
+ if (tup_opt == std::nullopt) continue;
+
+ auto [anchor_block, buffer_index, index_map] = *tup_opt;
+ // Transform the layout of the last cache-read buffer.
+ sch->TransformLayout(sch->GetBlock(anchor_block->name_hint, func_name), buffer_index,
+ BufferIndexType::kRead, index_map, NullOpt);
+
+ // Propagate the layout transformation over cache_read_chain, starting from
+ // the next-to-last cache-read buffer.
+ for (int i = static_cast<int>(cache_read_chain.size()) - 1; i >= 0; --i) {
+ BlockRV cache_read_block_rv = sch->GetBlock(cache_read_chain[i], func_name);
+ if (i == 0) {
+ // Before the first cache_read that consumes the layout-free buffer, insert
+ // a layout-rewrite block. Another cache-read buffer is added, and its layout is
+ // transformed by TransformLayout below.
+ add_layout_rewrite_block(cache_read_block_rv, 0);
+ }
+ sch->TransformLayout(cache_read_block_rv, 0, BufferIndexType::kRead, index_map, NullOpt);
+ }
}
-
- // Apply schedule
- BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
- BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
- sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
- NullOpt);
- sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
return true;
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 80da5a7279..0581ad60e8 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -110,13 +110,20 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
BlockNode* n = block.CopyOnWrite();
if (Optional<ObjectRef> ann = n->annotations.Get(topi_attr)) {
+ Array<Buffer> new_buffers;
for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
auto it = buffer2index_.find(buffer);
if (it != buffer2index_.end()) {
layout_free_buffer_indices_.insert(it->second);
+ } else {
+ new_buffers.push_back(buffer);
}
}
- n->annotations.erase(topi_attr);
+ if (new_buffers.empty()) {
+ n->annotations.erase(topi_attr);
+ } else {
+ n->annotations.Set(topi_attr, new_buffers);
+ }
}
for (const String& attr : this->blocklist) {
auto it = n->annotations.find(attr);
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
index 91a51c8e90..98c1f73685 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -204,5 +204,281 @@ def test_layout_rewrite():
tvm.ir.assert_structural_equal(sch.mod["main"], rewritten_tir_matmul)
+# fmt: off
+@tvm.script.ir_module
+class Conv2dCacheRead:
+ @T.prim_func
+ def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+ T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+ pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+ pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ p1_global = T.alloc_buffer([3, 3, 64, 64], dtype="float32")
+ for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+ for ax0, ax1, ax2 in T.grid(1, 30, 30):
+ for ax3_fused in T.vectorized(64):
+ with T.block("pad_temp"):
+ i0 = T.axis.spatial(1, ax0)
+ i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+ i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+ i3 = T.axis.spatial(64, ax3_fused)
+ T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+ T.writes(pad_temp[i0, i1, i2, i3])
+ pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+ for i3_0 in T.serial(16):
+ for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+ with T.block("pad_temp_global"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+ v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+ v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+ T.reads(pad_temp[v0, v1, v2, v3])
+ T.writes(pad_temp_global[v0, v1, v2, v3])
+ pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+ for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+ with T.block("p1_global"):
+ v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+ v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+ v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+ T.reads(p1[v0, v1, v2, v3])
+ T.writes(p1_global[v0, v1, v2, v3])
+ p1_global[v0, v1, v2, v3] = p1[v0, v1, v2, v3]
+ for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+ for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+ for i3_3_fused_init in T.vectorized(2):
+ with T.block("conv2d_nhwc_init"):
+ nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+ xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+ T.reads()
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+ for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+ for i3_3_fused in T.vectorized(2):
+ with T.block("conv2d_nhwc_update"):
+ nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+ xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+ ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+ rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+ rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+ T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ry, rx, rc, ff])
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ry, rx, rc, ff]
+ for ax0, ax1, ax2 in T.grid(1, 4, 14):
+ for ax3_fused in T.vectorized(4):
+ with T.block("conv2d_nhwc_global"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+ v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+ T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+ T.writes(conv2d_nhwc[v0, v1, v2, v3])
+ conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+
+@tvm.script.ir_module
+class Conv2dCacheReadRewritten:
+ @T.prim_func
+ def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+ T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+ pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+ pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+ p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+ for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+ with T.block("p1_global"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(p1[v0, v1, v2, v3])
+ T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc":True})
+ p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3]
+ for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+ for ax0, ax1, ax2 in T.grid(1, 30, 30):
+ for ax3_fused in T.vectorized(64):
+ with T.block("pad_temp"):
+ i0 = T.axis.spatial(1, ax0)
+ i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+ i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+ i3 = T.axis.spatial(64, ax3_fused)
+ T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+ T.writes(pad_temp[i0, i1, i2, i3])
+ pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+ for i3_0 in T.serial(16):
+ for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+ with T.block("pad_temp_global"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+ v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+ v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+ T.reads(pad_temp[v0, v1, v2, v3])
+ T.writes(pad_temp_global[v0, v1, v2, v3])
+ pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+ for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+ with T.block("p1_global"):
+ v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+ v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+ v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+ T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+ for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+ for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+ for i3_3_fused_init in T.vectorized(2):
+ with T.block("conv2d_nhwc_init"):
+ nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+ xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+ T.reads()
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+ for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+ for i3_3_fused in T.vectorized(2):
+ with T.block("conv2d_nhwc_update"):
+ nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+ xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+ ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+ rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+ rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+ T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2])
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]
+ for ax0, ax1, ax2 in T.grid(1, 4, 14):
+ for ax3_fused in T.vectorized(4):
+ with T.block("conv2d_nhwc_global"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+ v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+ T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+ T.writes(conv2d_nhwc[v0, v1, v2, v3])
+ conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+
+@tvm.script.ir_module
+class Conv2dCacheReadMultipleRewritten:
+ @T.prim_func
+ def main(p0: T.Buffer[(1, 56, 56, 64), "float32"], p1: T.Buffer[(3, 3, 64, 64), "float32"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float32"]):
+ T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"})
+ pad_temp = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ conv2d_nhwc_global = T.alloc_buffer([1, 56, 56, 64], dtype="float32")
+ pad_temp_global = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+ p1_global = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+ p1_global2 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32", scope="global2")
+ p1_global_1 = T.alloc_buffer([16, 2, 2, 3, 3, 32, 2], dtype="float32")
+ for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+ with T.block("p1_global"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(p1[v0, v1, v2, v3])
+ T.writes(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc":True})
+ p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1[v0, v1, v2, v3]
+ for ax0, ax1, ax2, ax3 in T.grid(3, 3, 64, 64):
+ with T.block("p1_global2"):
+ v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ T.writes(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global_1[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+ for i0_0_i1_0_i2_0_fused in T.parallel(4, annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
+ for ax0, ax1, ax2 in T.grid(1, 30, 30):
+ for ax3_fused in T.vectorized(64):
+ with T.block("pad_temp"):
+ i0 = T.axis.spatial(1, ax0)
+ i1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax1)
+ i2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax2)
+ i3 = T.axis.spatial(64, ax3_fused)
+ T.reads(p0[i0, i1 - 1, i2 - 1, i3])
+ T.writes(pad_temp[i0, i1, i2, i3])
+ pad_temp[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and i1 < 57 and 1 <= i2 and i2 < 57, p0[i0, i1 - 1, i2 - 1, i3], T.float32(0), dtype="float32")
+ for i3_0 in T.serial(16):
+ for ax0_ax1_ax2_ax3_fused in T.serial(57600):
+ with T.block("pad_temp_global"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused // 2 * 28 + ax0_ax1_ax2_ax3_fused // 1920)
+ v2 = T.axis.spatial(58, i0_0_i1_0_i2_0_fused % 2 * 28 + ax0_ax1_ax2_ax3_fused % 1920 // 64)
+ v3 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 64)
+ T.reads(pad_temp[v0, v1, v2, v3])
+ T.writes(pad_temp_global[v0, v1, v2, v3])
+ pad_temp_global[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
+ for ax0_ax1_ax2_ax3_fused in T.serial(2304):
+ with T.block("p1_global"):
+ v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 768)
+ v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 768 // 256)
+ v2 = T.axis.spatial(64, ax0_ax1_ax2_ax3_fused % 256 // 4)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax0_ax1_ax2_ax3_fused % 4)
+ T.reads(p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ T.writes(p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2])
+ p1_global[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2] = p1_global2[v3 // 4, v2 // 32, v3 % 4 // 2, v0, v1, v2 % 32, v3 % 2]
+ for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 7, 2, 1):
+ for i0_2_init, i1_2_init, i2_2_init, i3_2_init, i0_3_init, i1_3_init, i2_3_init in T.grid(1, 1, 14, 2, 1, 4, 1):
+ for i3_3_fused_init in T.vectorized(2):
+ with T.block("conv2d_nhwc_init"):
+ nn = T.axis.spatial(1, i0_2_init + i0_3_init + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2_init * 4 + i1_3_init)
+ xx = T.axis.spatial(56, i2_3_init + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2_init)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2_init * 2 + i3_3_fused_init)
+ T.reads()
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = T.float32(0)
+ for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 14, 2, 3, 3, 32, 1, 4, 1):
+ for i3_3_fused in T.vectorized(2):
+ with T.block("conv2d_nhwc_update"):
+ nn = T.axis.spatial(1, i0_2 + i0_3 + i0_1)
+ yy = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + i1_2 * 4 + i1_3)
+ xx = T.axis.spatial(56, i2_3 + i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + i2_2)
+ ff = T.axis.spatial(64, i3_0 * 4 + i3_1 * 4 + i3_2 * 2 + i3_3_fused)
+ ry = T.axis.reduce(3, i4_0 * 3 + i4_1)
+ rx = T.axis.reduce(3, i5_0 * 3 + i5_1)
+ rc = T.axis.reduce(64, i6_0 * 32 + i6_1)
+ T.reads(conv2d_nhwc_global[nn, yy, xx, ff], pad_temp_global[nn, yy + ry, xx + rx, rc], p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2])
+ T.writes(conv2d_nhwc_global[nn, yy, xx, ff])
+ T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+ conv2d_nhwc_global[nn, yy, xx, ff] = conv2d_nhwc_global[nn, yy, xx, ff] + pad_temp_global[nn, yy + ry, xx + rx, rc] * p1_global[ff // 4, rc // 32, ff % 4 // 2, ry, rx, rc % 32, ff % 2]
+ for ax0, ax1, ax2 in T.grid(1, 4, 14):
+ for ax3_fused in T.vectorized(4):
+ with T.block("conv2d_nhwc_global"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused // 2 * 28 + i1_1 * 4 + ax1)
+ v2 = T.axis.spatial(56, i0_0_i1_0_i2_0_fused % 2 * 28 + i2_1 * 14 + ax2)
+ v3 = T.axis.spatial(64, i3_0 * 4 + ax3_fused)
+ T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+ T.writes(conv2d_nhwc[v0, v1, v2, v3])
+ conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3]
+
+# fmt: on
+
+
+def test_layout_rewrite_cache_read():
+ target = Target("llvm")
+ ctx = _create_context(Conv2dCacheRead, target)
+ sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all")
+ sch.enter_postproc()
+ assert ctx.space_generator.postprocs[0].apply(sch)
+ tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadRewritten)
+
+
+def test_layout_rewrite_cache_read_multiple():
+ target = Target("llvm")
+ ctx = _create_context(Conv2dCacheRead, target)
+ sch = tvm.tir.Schedule(Conv2dCacheRead, debug_mask="all")
+ sch.cache_read(sch.get_block("p1_global"), 0, "global2")
+ sch.enter_postproc()
+ assert ctx.space_generator.postprocs[0].apply(sch)
+ tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index b59880758e..7b8173d0b2 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -390,6 +390,7 @@ def expected_layout_attr(
C[x, y] = C[x, y] + A[x, k] * B[y, k]
for i0, i1 in T.grid(128, 128):
with T.block("D"):
+ T.block_attr({"layout_free_placeholders": [C]})
x, y = T.axis.remap("SS", [i0, i1])
D[x, y] = C[x, y] + T.float32(1)