You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/24 11:02:12 UTC

[tvm] branch main updated: [TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c3d922b7e [TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)
8c3d922b7e is described below

commit 8c3d922b7ed019cd9c00cb763a5b76fa5a7af664
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Jun 24 04:02:06 2022 -0700

    [TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)
---
 include/tvm/meta_schedule/apply_history_best.h     |   8 +-
 include/tvm/tir/stmt.h                             |  10 +-
 include/tvm/tir/transform.h                        |   6 +
 python/tvm/meta_schedule/apply_history_best.py     |   6 +-
 python/tvm/meta_schedule/builder/local_builder.py  |   8 +-
 python/tvm/relay/backend/te_compiler.py            |  15 ++-
 python/tvm/relay/build_module.py                   |  12 +-
 python/tvm/tir/transform/transform.py              |  12 +-
 src/meta_schedule/apply_history_best.cc            |   7 +-
 src/meta_schedule/arg_info.cc                      |   9 +-
 .../feature_extractor/per_store_feature.cc         |   1 +
 src/meta_schedule/utils.h                          |   1 +
 src/relay/backend/te_compiler_cache.cc             |  27 ++++-
 src/te/operation/create_primfunc.cc                |   8 +-
 .../remove_weight_layout_rewrite_block.cc          | 121 +++++++++++++++++++++
 ...e_postproc_rewrite_parallel_vectorize_unroll.py |   6 -
 tests/python/unittest/test_te_create_primfunc.py   |   2 +-
 ...transform_remove_weight_layout_rewrite_block.py |  91 ++++++++++++++++
 18 files changed, 308 insertions(+), 42 deletions(-)

diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h
index 3a8983012b..8405ebbacf 100644
--- a/include/tvm/meta_schedule/apply_history_best.h
+++ b/include/tvm/meta_schedule/apply_history_best.h
@@ -39,8 +39,11 @@ namespace meta_schedule {
  */
 class ApplyHistoryBestNode : public runtime::Object {
  public:
+  /*! \brief A callback that maps TE tensors to an optional PrimFunc (filters TE compute) */
   using FTEFilterFunc =
       runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor, void>&)>;
+  /*! \brief A callback invoked by Query with the top-1 tuning record found in the database */
+  using FTakeTuningRecord = runtime::TypedPackedFunc<void(const TuningRecord&)>;
 
   /*! \brief The database to be queried from */
   Database database{nullptr};
@@ -60,9 +63,12 @@ class ApplyHistoryBestNode : public runtime::Object {
    * \param mod The module to be queried
    * \param target The target to be queried
    * \param dispatched The IRs after dispatch
+   * \param f_take_tuning_record A callback function that takes a tuning record and does something
+   * with it
    */
   Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
-                           Optional<Array<IRModule>> dispatched);
+                           Optional<Array<IRModule>> dispatched,
+                           FTakeTuningRecord f_take_tuning_record);
 
   static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
   TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object);
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index fc02550c7e..4c8a3076a2 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1483,6 +1483,9 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage";
 /*! \brief Mark the order of a statement in the software pipeline */
 constexpr const char* software_pipeline_order = "software_pipeline_order";
 
+/*! \brief Mark the buffers that are only read (const access) and whose layout may be rewritten. */
+constexpr const char* layout_free_buffers = "layout_free_buffers";
+
 /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
 
@@ -1516,11 +1519,12 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
 /*! \brief Mark auto-unroll setting on the block. */
 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
 
-/*!
- * \brief Mark that a block should be further rewritten using tensorization.
- */
+/*! \brief Mark that a block should be further rewritten using tensorization. */
 constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
 
+/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
+constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 39a6459048..74f13420a2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -650,6 +650,12 @@ TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
  */
 TVM_DLL Pass InjectPTXAsyncCopy();
 
+/*!
+ * \brief Make weight layout rewrite blocks no-ops and bind params to the rewritten buffers
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveWeightLayoutRewriteBlock();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py
index d618c3a04f..1a8ab2d358 100644
--- a/python/tvm/meta_schedule/apply_history_best.py
+++ b/python/tvm/meta_schedule/apply_history_best.py
@@ -26,7 +26,7 @@ from tvm.te import Tensor
 from tvm.tir import PrimFunc
 
 from . import _ffi_api
-from .database import Database
+from .database import Database, TuningRecord
 from .utils import make_logging_func
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -71,6 +71,7 @@ class ApplyHistoryBest(Object):
         mod: IRModule,
         target: Target,
         dispatched: Optional[List[IRModule]],
+        f_take_tuning_record: Callable[[TuningRecord], None] = None,
     ) -> Union[IRModule, None]:
         """The entry point of the integration
 
@@ -84,6 +85,8 @@ class ApplyHistoryBest(Object):
             Target Info
         dispatched : Optional[List[IRModule]]
             A list of low-level IRs that the high-level IR could potentially dispatch to
+        f_take_tuning_record : Callable[[TuningRecord], None] = None
+            A callback function that takes a tuning record and does something with it
 
         Returns
         -------
@@ -97,6 +100,7 @@ class ApplyHistoryBest(Object):
             mod,
             target,
             dispatched,
+            f_take_tuning_record,
         )
 
     @staticmethod
diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py
index 6f0f523b47..69e7b0ca60 100644
--- a/python/tvm/meta_schedule/builder/local_builder.py
+++ b/python/tvm/meta_schedule/builder/local_builder.py
@@ -26,11 +26,7 @@ from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict
 from tvm.target import Target
 
 from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
-from ..utils import (
-    cpu_count,
-    derived_object,
-    get_global_func_with_default_on_worker,
-)
+from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
 from .builder import BuilderInput, BuilderResult, PyBuilder
 
 logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
@@ -258,8 +254,10 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
     """
     # pylint: disable=import-outside-toplevel
     from tvm.driver import build as tvm_build
+    from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
 
     # pylint: enable=import-outside-toplevel
+    mod = RemoveWeightLayoutRewriteBlock()(mod)
     return tvm_build(mod, target=target)
 
 
diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 654c8f66ac..9b2907ccdb 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -19,15 +19,16 @@
 from __future__ import absolute_import
 
 import logging
+
 import tvm
-from tvm import te, autotvm
-from tvm.ir.transform import PassContext
+from tvm import autotvm, te
 from tvm.runtime import Object
 from tvm.support import libinfo
 from tvm.target import Target
-from ..backend.utils import mangle_module_name
+
 from .. import function as _function
 from .. import ty as _ty
+from ..backend.utils import mangle_module_name
 from . import _backend
 
 logger = logging.getLogger("te_compiler")
@@ -170,6 +171,12 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
     ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
         The best op implementation and the corresponding output tensors.
     """
+    # pylint: disable=import-outside-toplevel
+    from tvm.auto_scheduler import is_auto_scheduler_enabled
+    from tvm.meta_schedule import is_meta_schedule_enabled
+
+    # pylint: enable=import-outside-toplevel
+
     all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
     if len(all_impls) == 0:
         raise RuntimeError(f"No valid {op} implementations for {target}")
@@ -177,7 +184,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
 
     # Disable autotvm if auto_scheduler is enabled.
     # (i.e., always return the implementation with the highest priority for auto-scheduler).
-    if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
+    if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
         use_autotvm = False
 
     # If not use autotvm, always return the implementation with the highest priority
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 23892554cf..1353d8c5f5 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -22,7 +22,6 @@ import warnings
 
 import numpy as np
 from tvm.ir import IRModule
-from tvm.ir.transform import PassContext
 from tvm.target import Target
 
 from .. import autotvm
@@ -139,20 +138,23 @@ class BuildModule(object):
         params : dict
             The parameters of the final graph.
         """
+        # pylint: disable=import-outside-toplevel
+        from tvm.auto_scheduler import is_auto_scheduler_enabled
+        from tvm.meta_schedule import is_meta_schedule_enabled
 
+        # pylint: enable=import-outside-toplevel
         # Setup the params.
         if params:
             self._set_params(params)
 
         # Build the IR module. If auto_scheduler is not enabled,
         # then use the TOPI-defined schedule.
-        use_auto_scheduler = PassContext.current().config.get(
-            "relay.backend.use_auto_scheduler", False
-        )
 
         # Turn off AutoTVM config not found warnings if auto_scheduler is enabled.
         old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
-        autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler or old_autotvm_silent
+        autotvm.GLOBAL_SCOPE.silent = (
+            is_auto_scheduler_enabled() or is_meta_schedule_enabled() or old_autotvm_silent
+        )
 
         mod_name = mangle_module_name(mod_name)
 
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index e1ddfe439a..9a20f9a777 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -16,7 +16,7 @@
 # under the License.
 """Wrapping existing transformations."""
 # pylint: disable=invalid-name
-from typing import Optional, Callable
+from typing import Callable, Optional
 
 from . import _ffi_api
 from . import function_pass as _fpass
@@ -836,3 +836,13 @@ def InjectPTXAsyncCopy():
         The result pass
     """
     return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
+
+
+def RemoveWeightLayoutRewriteBlock():
+    """Remove weight layout rewrite blocks, e.g. before benchmarking during tuning.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass"""
+    return _ffi_api.RemoveWeightLayoutRewriteBlock()  # type: ignore
diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc
index e5cc929fd0..22445a9cf7 100644
--- a/src/meta_schedule/apply_history_best.cc
+++ b/src/meta_schedule/apply_history_best.cc
@@ -103,8 +103,8 @@ ApplyHistoryBest::ApplyHistoryBest(Database database,
 }
 
 Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
-                                               Target target,
-                                               Optional<Array<IRModule>> dispatched) {
+                                               Target target, Optional<Array<IRModule>> dispatched,
+                                               FTakeTuningRecord f_take_tuning_record) {
   ICHECK(dispatched.defined());
   ICHECK_EQ(dispatched.value().size(), 1);
   ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
@@ -122,6 +122,9 @@ Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModu
   if (database->HasWorkload(prim_mod)) {
     Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
     if (records.size() == 1) {
+      if (f_take_tuning_record != nullptr) {
+        f_take_tuning_record(records[0]);
+      }
       tir::Schedule sch =
           tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
                                 /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 37897a5ac6..672df86deb 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -61,11 +61,10 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
 }
 
 Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
-  // TODO(@jinhongyii): add pass for layout rewrite
-  // if (remove_preproc) {
-  //   IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
-  //   return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
-  // }
+  if (remove_preproc) {
+    IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
+    return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
+  }
   return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
 }
 
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc
index d3d63e7824..93f6767b11 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -300,6 +300,7 @@ Pass SimplifyForFeatureExtraction() {
  */
 Sequential PassListForPerStoreFeature() {
   return Sequential({
+      tir::transform::RemoveWeightLayoutRewriteBlock(),
       tir::transform::SimplifyForFeatureExtraction(),
       tir::transform::LowerCrossThreadReduction(),
       tir::transform::LowerInitBlock(),
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 76deb62f23..ca696da71e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -41,6 +41,7 @@
 #include <tvm/runtime/container/optional.h>
 #include <tvm/support/parallel_for.h>
 #include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/transform.h>
 
 #include <algorithm>
 #include <string>
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 8715900c0c..0f519721b0 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -36,6 +36,8 @@
 #include <tvm/te/schedule.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/tir/function.h>
+#include <tvm/tir/index_map.h>
+#include <tvm/tir/transform.h>
 #include <tvm/topi/tags.h>
 
 #include <functional>
@@ -47,6 +49,7 @@
 
 #include "../../te/operation/create_primfunc.h"
 #include "../op/memory/memory.h"
+#include "../transforms/meta_schedule_layout_rewrite.h"
 #include "../transforms/pass_utils.h"
 #include "utils.h"
 
@@ -59,6 +62,16 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode);
 TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
 TVM_REGISTER_NODE_TYPE(CCacheValueNode);
 
+void ExtractTransformLayout(const meta_schedule::TuningRecord& record) {
+  static const auto kTransformLayout = tir::InstructionKind::Get("TransformLayout");
+  for (const tir::Instruction& inst : record->trace->insts) {
+    if (!inst->kind.same_as(kTransformLayout)) continue;
+    // Forward the third attribute (an IndexMap) of each TransformLayout to the layout queue.
+    ICHECK_EQ(inst->attrs.size(), 3);
+    relay::MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<tir::IndexMap>(inst->attrs[2]));
+  }
+}
+
 LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
   auto n = make_object<LoweredOutputNode>();
   n->outputs = std::move(outputs);
@@ -353,10 +366,16 @@ class ScheduleBuilder : public ExprVisitor {
                 meta_schedule_ctx_.value()->te_filter_func(te_args)) {
           IRModule relay_mod({{prim_fn_var, relay_func}});
           IRModule tir_mod({{prim_fn_var, tir_func.value()}});
-          if (Optional<IRModule> scheduled_mod = meta_schedule_ctx_.value()->Query(
-                  prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod})) {
-            ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
-            prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
+          if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
+                  /*task_name=*/prim_fn_var->name_hint,     //
+                  /*mod=*/relay_mod,                        //
+                  /*target=*/target_,                       //
+                  /*dispatched=*/Array<IRModule>{tir_mod},  //
+                  /*f_take_tuning_record=*/ExtractTransformLayout)) {
+            IRModule scheduled_mod =
+                tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
+            ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
+            prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
           }
         }
       }
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 2aeb799a04..e361e8e344 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -103,12 +103,12 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
     for (int i : this->layout_free_buffer_indices_) {
       indices.push_back(Integer(i));
     }
-    return WithAttr(std::move(func), attr, indices);
+    return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices);
   }
 
   Stmt VisitStmt_(const BlockNode* _block) final {
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
-    if (Optional<ObjectRef> ann = block->annotations.Get(attr)) {
+    if (Optional<ObjectRef> ann = block->annotations.Get(topi_attr)) {
       Array<Buffer> buffers = Downcast<Array<Buffer>>(ann);
       for (Buffer buffer : buffers) {
         auto it = buffer2index_.find(buffer);
@@ -116,14 +116,14 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
           layout_free_buffer_indices_.insert(it->second);
         }
       }
-      block.CopyOnWrite()->annotations.erase(attr);
+      block.CopyOnWrite()->annotations.erase(topi_attr);
     }
     return block;
   }
 
   std::unordered_map<tir::Buffer, int, ObjectPtrHash, ObjectPtrEqual> buffer2index_;
   std::set<int> layout_free_buffer_indices_;
-  String attr = "layout_free_placeholders";
+  String topi_attr = "layout_free_placeholders";
 };
 
 BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
new file mode 100644
index 0000000000..5f47e670c6
--- /dev/null
+++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_weight_layout_rewrite_block.cc
+ * \brief Remove weight layout rewrite block before benchmark
+ */
+
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+class WeightLayoutRewriteBlockRemover : public StmtMutator {
+ public:
+  /*!
+   * \brief Turn the layout-rewrite preproc blocks of \p f into no-ops and rebind params.
+   * \param f The function to transform.
+   * \return The transformed function.
+   */
+  static PrimFunc Remove(PrimFunc f) {
+    WeightLayoutRewriteBlockRemover remover;
+    PrimFuncNode* n = f.CopyOnWrite();
+    n->body = remover(std::move(n->body));
+    // A param whose buffer was read by a preproc block is rebound to the rewritten buffer,
+    // so callers must now pass that argument already in the rewritten layout.
+    Map<tir::Var, Buffer> buffer_map;
+    for (const auto& kv : f->buffer_map) {
+      auto it = remover.buf_map_.find(kv.second);
+      buffer_map.Set(kv.first, it != remover.buf_map_.end() ? (*it).second : kv.second);
+    }
+    n->buffer_map = std::move(buffer_map);
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+
+    auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+    if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) {
+      // Not a preproc block. Children have already been visited, so drop allocations of
+      // buffers that became function parameters and keep all other allocations.
+      Array<Buffer> alloc_buffers;
+      for (const Buffer& buffer : block->alloc_buffers) {
+        if (!rewritten_buffers_.count(buffer)) {
+          alloc_buffers.push_back(buffer);
+        }
+      }
+      if (alloc_buffers.size() < block->alloc_buffers.size()) {
+        auto n = CopyOnWrite(block.get());
+        n->alloc_buffers = std::move(alloc_buffers);
+        return Stmt(n);
+      } else {
+        return block;
+      }
+    }
+
+    // Step 0. A preproc block must neither allocate nor match buffers.
+    ICHECK(block->alloc_buffers.empty());
+    ICHECK(block->match_buffers.empty());
+
+    // Step 1. Its body must be a single BufferStore ...
+    const auto* store = block->body.as<BufferStoreNode>();
+    ICHECK(store) << "Layout rewrite block body must be a BufferStore, got: " << block->body;
+
+    // Step 2. ... whose value is a plain BufferLoad, i.e. an element-wise copy.
+    const auto* load = store->value.as<BufferLoadNode>();
+    ICHECK(load) << "Layout rewrite block must store a BufferLoad, got: " << store->value;
+
+    // Step 3. Map the loaded (original) buffer to the stored (rewritten) buffer.
+    buf_map_.Set(load->buffer, store->buffer);
+    rewritten_buffers_.insert(store->buffer);
+
+    // Step 4. Keep the block but make it a no-op with empty read/write regions.
+    auto n = CopyOnWrite(block.get());
+    n->body = Evaluate(0);
+    n->reads = {};
+    n->writes = {};
+    return Stmt(n);
+  }
+
+  /*! \brief Maps each original weight buffer (read by a preproc block) to its rewritten buffer */
+  Map<Buffer, Buffer> buf_map_;
+  /*! \brief Rewritten buffers; their allocations are dropped because they become parameters */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> rewritten_buffers_;
+};
+namespace transform {
+
+Pass RemoveWeightLayoutRewriteBlock() {
+  auto fpass = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) -> PrimFunc {
+    return WeightLayoutRewriteBlockRemover::Remove(std::move(func));
+  };
+  return CreatePrimFuncPass(fpass, /*opt_level=*/0, "tir.RemoveWeightLayoutRewriteBlock", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock")
+    .set_body_typed(RemoveWeightLayoutRewriteBlock);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index f9b71bfdb6..44b0e79f0c 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -74,10 +74,6 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None:
 class Fused_NN_Dense:
     @T.prim_func
     def main(placeholder: T.Buffer[(64, 768), "float32"], placeholder_1: T.Buffer[(768, 768), "float32"], T_matmul_NT: T.Buffer[(64, 768), "float32"]) -> None:
-        # function attr dict
-        T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
-        # body
-        # with T.block("root")
         for i0, i1, i2 in T.grid(64, 768, 768):
             with T.block("T_matmul_NT"):
                 i, j, k = T.axis.remap("SSR", [i0, i1, i2])
@@ -93,7 +89,6 @@ def before_matmul_vectorize(
     placeholder_1: T.Buffer[(768, 768), "float32"],
     T_matmul_NT: T.Buffer[(64, 768), "float32"],
 ) -> None:
-    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
     with T.block("root"):
         T.reads()
         T.writes()
@@ -124,7 +119,6 @@ def after_matmul_vectorize(
     placeholder_1: T.Buffer[(768, 768), "float32"],
     T_matmul_NT: T.Buffer[(64, 768), "float32"],
 ) -> None:
-    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
     T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
     for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
         for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 5d9ad003b4..d3f444ec08 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -377,7 +377,7 @@ def expected_layout_attr(
     B: T.Buffer[(128, 128), "float32"],
     D: T.Buffer[(128, 128), "float32"],
 ) -> None:
-    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
+    T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]})
     C = T.alloc_buffer([128, 128], dtype="float32")
     for i0, i1, i2 in T.grid(128, 128, 128):
         with T.block("C"):
diff --git a/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py
new file mode 100644
index 0000000000..7a01428381
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+
+import tvm
+from tvm.ir.module import IRModule
+from tvm.script import tir as T
+from tvm.tir.function import PrimFunc
+
+
+def _check(before, expect):
+    """Apply RemoveWeightLayoutRewriteBlock to `before` and assert it equals `expect`."""
+    # Bare PrimFuncs are wrapped into a single-function module named "main".
+    before_mod = IRModule({"main": before}) if isinstance(before, PrimFunc) else before
+    expected_mod = IRModule({"main": expect}) if isinstance(expect, PrimFunc) else expect
+
+    actual = tvm.tir.transform.RemoveWeightLayoutRewriteBlock()(before_mod)
+    tvm.ir.assert_structural_equal(actual, expected_mod)
+
+
+def test_matmul():
+    @T.prim_func
+    def before(
+        A: T.Buffer[(16, 16), "float32"],
+        B: T.Buffer[(16, 16), "float32"],
+        C: T.Buffer[(16, 16), "float32"],
+    ) -> None:
+        T.func_attr({"layout_free_buffers": [1]})
+        B_ = T.alloc_buffer([16, 4, 4], dtype="float32")
+        for i0_o, i1_o in T.grid(16, 16):
+            with T.block("layout_rewrite"):
+                i0, i1 = T.axis.remap("SS", [i0_o, i1_o])
+                T.reads(B[i0, i1])
+                T.writes(B_[i1, i0 // 4, i0 % 4])
+                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
+                B_[i1, i0 // 4, i0 % 4] = B[i0, i1]
+        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+            with T.block("matmul"):
+                vi = T.axis.spatial(16, i0 * 4 + i1)
+                vj = T.axis.spatial(16, j)
+                vk = T.axis.reduce(16, k0 * 4 + k1)
+                T.reads(A[vi, vk], B_[vj, vk // 4, vk % 4])
+                T.writes(C[vi, vj])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B_[vj, vk // 4, vk % 4]
+
+    @T.prim_func
+    def after(
+        A: T.Buffer[(16, 16), "float32"],
+        B: T.Buffer[(16, 4, 4), "float32"],
+        C: T.Buffer[(16, 16), "float32"],
+    ) -> None:
+        T.func_attr({"layout_free_buffers": [1]})
+        for i0_o, i1_o in T.grid(16, 16):
+            with T.block("layout_rewrite"):
+                i0, i1 = T.axis.remap("SS", [i0_o, i1_o])
+                T.reads()
+                T.writes()
+                T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
+                T.evaluate(0)
+        for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+            with T.block("matmul"):
+                vi = T.axis.spatial(16, i0 * 4 + i1)
+                vj = T.axis.spatial(16, j)
+                vk = T.axis.reduce(16, k0 * 4 + k1)
+                T.reads(A[vi, vk], B[vj, vk // 4, vk % 4])
+                T.writes(C[vi, vj])
+                with T.init():
+                    C[vi, vj] = T.float32(0)
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk // 4, vk % 4]
+
+    _check(before, after)
+
+
+if __name__ == "__main__":
+    test_matmul()