Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/13 08:06:55 UTC

[tvm] branch main updated: [TIR] Add pass ManifestSharedMemoryLocalStage (#12355)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 036aa722ae [TIR] Add pass ManifestSharedMemoryLocalStage (#12355)
036aa722ae is described below

commit 036aa722aefd7126c17063cddb260b9347f17598
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sat Aug 13 01:06:50 2022 -0700

    [TIR] Add pass ManifestSharedMemoryLocalStage (#12355)
    
    Added a pass that inserts a local (cache) stage for shared memory access. It is similar to cache_read, but it bypasses the limitation of int set analysis when compacting buffer regions by inferring the buffer shape from the loop extents.
---
 include/tvm/tir/stmt.h                             |   3 +
 include/tvm/tir/transform.h                        |   6 +
 python/tvm/tir/transform/transform.py              |  11 +
 src/driver/driver_api.cc                           |   1 +
 .../manifest_shared_memory_local_stage.cc          | 287 +++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py      |   2 +-
 ...transform_manifest_shared_memory_local_stage.py | 134 ++++++++++
 7 files changed, 443 insertions(+), 1 deletion(-)

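The block annotation that triggers this pass is attached during scheduling. A
minimal sketch, assuming `sch` is an existing tir.Schedule over a GPU PrimFunc
and "A_shared" is its shared-memory cache-read block (both names are
assumptions for illustration, not part of this commit):

    # `sch`: a pre-existing tir.Schedule; "A_shared": hypothetical name of the
    # block that stores into shared memory.
    block = sch.get_block("A_shared")
    # Attach the annotation consumed by ManifestSharedMemoryLocalStage.
    sch.annotate(block, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
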
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5be1b9626d..bee9819a22 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1547,6 +1547,9 @@ constexpr const char* software_pipeline_async_stages = "software_pipeline_async_
 /*! \brief Mark the buffers which are accessed as constants and whose layout can be transformed. */
 constexpr const char* layout_free_buffers = "layout_free_buffers";
 
+/*! \brief Mark that a local stage should be added for the shared memory access. */
+constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage";
+
 /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
 
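In TVMScript, the attribute appears as a block annotation on the store block.
A fragment mirroring the usage in the unit test at the end of this commit
(buffer and index names are placeholders):

    # Inside a T.prim_func: mark the block that stores to shared memory so
    # that the pass manifests a local stage for it.
    with T.block("A_shared"):
        T.block_attr({"tir.manifest_shared_memory_local_stage": 1})
        A_shared[i, j] = A[i, j]
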
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index c758a00b3f..fd4261e4a4 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -674,6 +674,12 @@ TVM_DLL Pass InjectPTXAsyncCopy();
  */
 TVM_DLL Pass RemoveWeightLayoutRewriteBlock();
 
+/*!
+ * \brief Add the explicit local stage for the shared memory access on GPU.
+ * \return The pass.
+ */
+TVM_DLL Pass ManifestSharedMemoryLocalStage();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index eb2cff641c..324471c718 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -949,3 +949,14 @@ def RemoveWeightLayoutRewriteBlock():
         The result pass
     """
     return _ffi_api.RemoveWeightLayoutRewriteBlock()  # type: ignore
+
+
+def ManifestSharedMemoryLocalStage():
+    """Add the explicit local stage for the shared memory access on GPU.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.ManifestSharedMemoryLocalStage()  # type: ignore
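
A minimal sketch of invoking the pass on its own, mirroring the `_check`
helper in the unit test below (`mod` stands for any IRModule whose store
blocks carry the "tir.manifest_shared_memory_local_stage" annotation):

    import tvm

    # `mod`: an annotated IRModule (assumed to exist); blocks without the
    # annotation are left untouched by the pass.
    mod = tvm.tir.transform.ManifestSharedMemoryLocalStage()(mod)
    print(mod.script())
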
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 9bd2e8a812..e528686d96 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -199,6 +199,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
   pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
   pass_list.push_back(tir::transform::UnifyThreadBinding());
+  pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
   pass_list.push_back(tir::transform::CompactBufferAllocation());
   pass_list.push_back(tir::transform::LowerMatchBuffer());
   pass_list.push_back(tir::transform::InjectSoftwarePipeline());
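
In the default lowering pipeline the pass runs right before
CompactBufferAllocation, so the rewritten blocks still go through the usual
buffer compaction. A sketch of reproducing this ordering standalone (assuming
`mod` is a scheduled, annotated IRModule):

    import tvm

    # Ordering mirrors CreatePassList above: manifest the local stage first,
    # then compact buffer allocations.
    seq = tvm.transform.Sequential(
        [
            tvm.tir.transform.ManifestSharedMemoryLocalStage(),
            tvm.tir.transform.CompactBufferAllocation(),
        ]
    )
    mod = seq(mod)
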
diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc
new file mode 100644
index 0000000000..3a3abf0b80
--- /dev/null
+++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file manifest_shared_memory_local_stage.cc
+ * \brief Add the explicit local stage for the shared memory access on GPU.
+ *
+ * This pass finds the cache_read stages on shared memory and creates an intermediate stage that
+ * stores the data into local memory first, then copies the data from local memory to shared
+ * memory. It is similar to the schedule primitive cache_read, but it bypasses the limitation
+ * of requiring buffer accesses to be contiguous in each dimension.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "../schedule/transform.h"
+#include "tvm/tir/stmt.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Rewriter for the block that stores to the target buffer. Creates an intermediate cache
+ * stage to store the result and rewrites the original block to load from the intermediate buffer.
+ */
+class IntermediateStageRewriter {
+ public:
+  explicit IntermediateStageRewriter(const std::vector<Stmt>& ancestor_loop_or_blocks)
+      : ancestor_loop_or_blocks_(ancestor_loop_or_blocks) {}
+
+  std::tuple<Buffer, Buffer, Block, Stmt> Rewrite(const BlockNode* block) {
+    const BufferStoreNode* store = block->body.as<BufferStoreNode>();
+    CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank ==
+                                  runtime::StorageRank::kShared)
+        << "ValueError: Expect the body of the block to be BufferStore to shared memory.";
+
+    const Buffer& target_buffer = store->buffer;
+
+    // Step 0: Collect relaxed loops
+    std::vector<const ForNode*> relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer);
+
+    // Step 1: Create buffer for the local stage
+    Buffer new_buffer{nullptr};
+    Array<PrimExpr> buffer_indices;
+    std::tie(new_buffer, buffer_indices) = CreateIntermediateBuffer(relaxed_loops, target_buffer);
+
+    // Step 2: Create the local stage block
+    Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store);
+
+    // Step 3: Create BufferLoad from the intermediate buffer
+    BufferLoad new_buffer_load = BufferLoad(new_buffer, buffer_indices);
+    BufferStore new_buffer_store = Downcast<BufferStore>(block->body);
+    new_buffer_store.CopyOnWrite()->value = new_buffer_load;
+    Block new_block = GetRef<Block>(block);
+    new_block.CopyOnWrite()->body = std::move(new_buffer_store);
+
+    return {target_buffer, new_buffer, new_block, local_stage};
+  }
+
+ private:
+  /*! \brief Collect relaxed outer loops from innermost to outermost */
+  std::vector<const ForNode*> CollectRelaxedOuterLoops(const BlockNode* block,
+                                                       const Buffer& target_buffer) {
+    std::vector<const ForNode*> relaxed_loops;
+    for (int n = static_cast<int>(ancestor_loop_or_blocks_.size()) - 1, i = n - 1; i >= 0; --i) {
+      const Stmt& ancestor = ancestor_loop_or_blocks_[i];
+      if (const ForNode* ancestor_loop = ancestor.as<ForNode>()) {
+        CHECK(ancestor_loop->kind == ForKind::kSerial ||
+              ancestor_loop->kind == ForKind::kVectorized)
+            << "ValueError: Expect the ancestor loops to be serial or vectorized, got "
+            << ancestor_loop->kind;
+        relaxed_loops.push_back(ancestor.as<ForNode>());
+
+        if (i < n - 1) {
+          CHECK(ancestor_loop->body.same_as(ancestor_loop_or_blocks_[i + 1]))
+              << "ValueError: Expect the ancestor loops to have a single child.";
+        } else {
+          const BlockRealizeNode* block_realize = ancestor_loop->body.as<BlockRealizeNode>();
+          ICHECK(block_realize != nullptr);
+          CHECK(block_realize != nullptr && block_realize->block.get() == block)
+              << "ValueError: Expect the ancestor loops to have a single child.";
+        }
+      } else {
+        const BlockRealizeNode* ancestor_block_realize = ancestor.as<BlockRealizeNode>();
+        ICHECK(ancestor_block_realize != nullptr);
+        const BlockNode* ancestor_block = ancestor_block_realize->block.get();
+        auto it = std::find_if(
+            ancestor_block->alloc_buffers.begin(), ancestor_block->alloc_buffers.end(),
+            [&target_buffer](const Buffer& buffer) { return buffer.same_as(target_buffer); });
+        CHECK(it != ancestor_block->alloc_buffers.end())
+            << "ValueError: Expect the shared memory allocation to be in the parent block.";
+        break;
+      }
+    }
+    return relaxed_loops;
+  }
+
+  /*! \brief Create the intermediate stage. */
+  Stmt MakeLocalStage(const BlockNode* block, const Buffer& new_buffer,
+                      Array<PrimExpr> local_stage_indices,
+                      std::vector<const ForNode*> relaxed_loops, const BufferStoreNode* store) {
+    // Step 0: Create the body of the local stage, which is a BufferStore to the intermediate buffer.
+    Stmt local_stage = BufferStore(new_buffer, store->value, local_stage_indices);
+
+    // Step 1: Make block and block realize
+    BufferRegion write_buffer_region = BufferRegion::FromPoint(new_buffer, local_stage_indices);
+    local_stage =
+        Block(/*iter_vars=*/{}, /*reads=*/block->reads, /*writes=*/{write_buffer_region}, "",
+              /*body=*/std::move(local_stage));
+    local_stage = BlockRealize(
+        /*iter_values=*/{},
+        /*predicate=*/ancestor_loop_or_blocks_.back().as<BlockRealizeNode>()->predicate,
+        Downcast<Block>(local_stage));
+
+    // Step 2: Add outer loops
+    Map<Var, PrimExpr> subst_map;
+    for (const ForNode* relaxed_loop : relaxed_loops) {
+      ObjectPtr<ForNode> for_node = make_object<ForNode>(*relaxed_loop);
+      for_node->loop_var = for_node->loop_var.copy_with_suffix("");
+      for_node->body = std::move(local_stage);
+      local_stage = For(for_node);
+      subst_map.Set(relaxed_loop->loop_var, for_node->loop_var);
+    }
+    local_stage = Substitute(local_stage, subst_map);
+    return local_stage;
+  }
+
+  /*! \brief Create the intermediate buffer with the extents of the relaxed outer loops. */
+  std::pair<Buffer, Array<PrimExpr>> CreateIntermediateBuffer(
+      const std::vector<const ForNode*> relaxed_loops, const Buffer& buffer) const {
+    Array<PrimExpr> buffer_indices;
+    Array<PrimExpr> new_buffer_shape;
+
+    // Create the intermediate buffer for the local stage. The shape of the new buffer is the
+    // extents of the relaxed outer loops.
+
+    for (auto it = relaxed_loops.rbegin(); it != relaxed_loops.rend(); ++it) {
+      const ForNode* relaxed_loop = *it;
+      buffer_indices.push_back(relaxed_loop->min + relaxed_loop->loop_var);
+      new_buffer_shape.push_back(relaxed_loop->extent);
+    }
+    Buffer new_buffer = WithScope(buffer, "local");
+    new_buffer.CopyOnWrite()->shape = new_buffer_shape;
+    return {new_buffer, buffer_indices};
+  }
+
+  const std::vector<Stmt>& ancestor_loop_or_blocks_;
+};
+
+class SharedMemoryLocalStageInserter : public StmtMutator {
+ public:
+  Stmt VisitStmt_(const ForNode* op) final {
+    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
+    Stmt new_stmt = StmtMutator::VisitStmt_(op);
+    ancestor_loop_or_blocks_.pop_back();
+    return new_stmt;
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* op) final {
+    ancestor_loop_or_blocks_.push_back(GetRef<Stmt>(op));
+    Stmt new_stmt = StmtMutator::VisitStmt_(op);
+    ancestor_loop_or_blocks_.pop_back();
+    return new_stmt;
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    if (op->annotations.count(attr::manifest_shared_memory_local_stage)) {
+      // Rewrite the shared memory access to load from the intermediate buffer.
+      // The annotated block must be a leaf block (will be checked during rewriting). No need to
+      // visit its body recursively.
+
+      Buffer target_buffer{nullptr};
+      Buffer new_buffer{nullptr};
+      Block new_block{nullptr};
+      Stmt local_stage{nullptr};
+      IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_);
+      std::tie(target_buffer, new_buffer, new_block, local_stage) = rewriter.Rewrite(op);
+      buffer_remap_.Set(target_buffer, new_buffer);
+
+      new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage);
+      buffer_local_stage_.Set(target_buffer, local_stage);
+      target_buffers_.push_back(target_buffer);
+
+      return std::move(new_block);
+    }
+
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> allocated_buffers(
+        op->alloc_buffers.begin(), op->alloc_buffers.end());
+
+    // Visit children and insert local stages (if any) at the proper location.
+    Array<Buffer> new_alloc_buffers;
+    Array<Stmt> new_seq;
+
+    // Helper function to check if the subtree (body of the block) contains any target buffers.
+    // If so, the allocated intermediate buffer and the local stage should be lifted to the current
+    // block.
+    auto f_check_subtree = [&](int start, int end) {
+      for (int i = start; i < end; ++i) {
+        const Buffer& buffer = target_buffers_[i];
+        if (allocated_buffers.count(buffer)) {
+          new_seq.push_back(buffer_local_stage_.at(buffer));
+          new_alloc_buffers.push_back(buffer_remap_.at(buffer));
+        }
+      }
+    };
+
+    if (const SeqStmtNode* seq = op->body.as<SeqStmtNode>()) {
+      // Visit each element of the SeqStmt. Create a new SeqStmt if any of the children is modified.
+      bool changed = false;  // whether the SeqStmt has been changed
+      for (int i = 0, n = seq->seq.size(); i < n; ++i) {
+        int subtree_start = target_buffers_.size();
+        Stmt new_seq_elem = VisitStmt(seq->seq[i]);
+        int subtree_end = target_buffers_.size();
+        f_check_subtree(subtree_start, subtree_end);
+        new_seq.push_back(new_seq_elem);
+        if (!new_seq_elem.same_as(seq->seq[i])) {
+          changed = true;
+        }
+      }
+      if (!changed) {
+        return GetRef<Stmt>(op);
+      }
+    } else {
+      int subtree_start = target_buffers_.size();
+      Stmt body = VisitStmt(op->body);
+      int subtree_end = target_buffers_.size();
+      f_check_subtree(subtree_start, subtree_end);
+      if (body.same_as(op->body)) {
+        return GetRef<Stmt>(op);
+      }
+      new_seq.push_back(body);
+    }
+
+    Block new_block = GetRef<Block>(op);
+    BlockNode* new_block_node = new_block.CopyOnWrite();
+    // Add new buffer allocations if any.
+    if (new_alloc_buffers.size() > 0) {
+      new_block_node->alloc_buffers = Concat(new_block_node->alloc_buffers, new_alloc_buffers);
+    }
+    new_block_node->body = new_seq.size() == 1 ? new_seq[0] : SeqStmt(new_seq);
+    return std::move(new_block);
+  }
+
+  std::vector<Stmt> ancestor_loop_or_blocks_;  // ancestor loops or block realize
+  Map<Buffer, Buffer> buffer_remap_;  // mapping from the target buffer to the intermediate buffer
+  Map<Buffer, Stmt> buffer_local_stage_;  // mapping from the target buffer to the local stage
+  Array<Buffer> target_buffers_;          // the target buffers for rewriting
+};
+
+namespace transform {
+
+Pass ManifestSharedMemoryLocalStage() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = SharedMemoryLocalStageInserter()(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.ManifestSharedMemoryLocalStage", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.ManifestSharedMemoryLocalStage")
+    .set_body_typed(ManifestSharedMemoryLocalStage);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index cb6cb0f93f..3828412de0 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4623,7 +4623,7 @@ def test_all_any():
         return lambda x: f(x, dim=dim, keepdim=keepdim)
 
     def test_fn_no_arg(f):
-        return lambda x: f(x)
+        return lambda x: f(x)  # pylint: disable=unnecessary-lambda
 
     for f in [torch.all, torch.any]:
         verify_model(test_fn(f, 0), [torch.rand(1, 2).bool()])
diff --git a/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py
new file mode 100644
index 0000000000..111b91d5fd
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_manifest_shared_memory_local_stage.py
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
+
+@tvm.script.ir_module
+class MatmulBefore:
+    @T.prim_func
+    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
+            for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
+                for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
+                        for k_0 in T.serial(32):
+                            with T.block():
+                                T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32])
+                                T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32])
+                                A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block("A_shared"):
+                                            T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.block_attr({"tir.manifest_shared_memory_local_stage":1})
+                                            A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block("B_shared"):
+                                            T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.block_attr({"tir.manifest_shared_memory_local_stage":1})
+                                            B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
+                                for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16):
+                                    with T.block("C"):
+                                        T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
+                                        T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
+                                        if k_0 * 32 + k_1 * 16 + k_2 == 0:
+                                            C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0)
+                                        C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]
+
+
+@tvm.script.ir_module
+class MatmulAfter:
+    @T.prim_func
+    def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
+            for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
+                for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
+                    for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
+                        for k_0 in T.serial(32):
+                            with T.block():
+                                T.reads(A[blockIdx_y * 32 : blockIdx_y * 32 + 32, k_0 * 32 : k_0 * 32 + 32], B[k_0 * 32 : k_0 * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32])
+                                T.writes(C[blockIdx_y * 32 : blockIdx_y * 32 + 32, blockIdx_x * 32 : blockIdx_x * 32 + 32])
+                                A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
+                                B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
+                                A_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local")
+                                B_shared_local = T.alloc_buffer([64, 4], dtype="float32", scope="local")
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block():
+                                            T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3])
+                                            A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block("A_shared"):
+                                            T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block():
+                                            T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3])
+                                            B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
+                                for ax0_ax1_fused_0 in T.serial(64):
+                                    for ax0_ax1_fused_3 in T.vectorized(4):
+                                        with T.block("B_shared"):
+                                            T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
+                                            B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B_shared_local[ax0_ax1_fused_0, ax0_ax1_fused_3]
+                                for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16):
+                                    with T.block("C"):
+                                        T.reads(A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2], B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
+                                        T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
+                                        if k_0 * 32 + k_1 * 16 + k_2 == 0:
+                                            C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0)
+                                        C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2] * B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2]
+
+
+# fmt: on
+# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
+
+
+def _check(before, expected):
+    after = tvm.tir.transform.ManifestSharedMemoryLocalStage()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_transform_matmul():
+    _check(MatmulBefore, MatmulAfter)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()