You are viewing a plain-text version of this content. The link to the canonical version (originally the word "here") was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/06/24 11:02:12 UTC
[tvm] branch main updated: [TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8c3d922b7e [TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)
8c3d922b7e is described below
commit 8c3d922b7ed019cd9c00cb763a5b76fa5a7af664
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Jun 24 04:02:06 2022 -0700
[TIR][Pass] Remove-Weight-Layout-Rewrite-Block (#11870)
---
include/tvm/meta_schedule/apply_history_best.h | 8 +-
include/tvm/tir/stmt.h | 10 +-
include/tvm/tir/transform.h | 6 +
python/tvm/meta_schedule/apply_history_best.py | 6 +-
python/tvm/meta_schedule/builder/local_builder.py | 8 +-
python/tvm/relay/backend/te_compiler.py | 15 ++-
python/tvm/relay/build_module.py | 12 +-
python/tvm/tir/transform/transform.py | 12 +-
src/meta_schedule/apply_history_best.cc | 7 +-
src/meta_schedule/arg_info.cc | 9 +-
.../feature_extractor/per_store_feature.cc | 1 +
src/meta_schedule/utils.h | 1 +
src/relay/backend/te_compiler_cache.cc | 27 ++++-
src/te/operation/create_primfunc.cc | 8 +-
.../remove_weight_layout_rewrite_block.cc | 121 +++++++++++++++++++++
...e_postproc_rewrite_parallel_vectorize_unroll.py | 6 -
tests/python/unittest/test_te_create_primfunc.py | 2 +-
...transform_remove_weight_layout_rewrite_block.py | 91 ++++++++++++++++
18 files changed, 308 insertions(+), 42 deletions(-)
diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h
index 3a8983012b..8405ebbacf 100644
--- a/include/tvm/meta_schedule/apply_history_best.h
+++ b/include/tvm/meta_schedule/apply_history_best.h
@@ -39,8 +39,11 @@ namespace meta_schedule {
*/
class ApplyHistoryBestNode : public runtime::Object {
public:
+ /*! \brief A callback function that filters TE compute */
using FTEFilterFunc =
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor, void>&)>;
+ /*! \brief A callback that Query invokes with the top-1 tuning record it finds in the database */
+ using FTakeTuningRecord = runtime::TypedPackedFunc<void(const TuningRecord&)>;
/*! \brief The database to be queried from */
Database database{nullptr};
@@ -60,9 +63,12 @@ class ApplyHistoryBestNode : public runtime::Object {
* \param mod The module to be queried
* \param target The target to be queried
* \param dispatched The IRs after dispatch
+ * \param f_take_tuning_record A callback function that takes a tuning record and does something
+ * with it
*/
Optional<IRModule> Query(runtime::String task_name, IRModule mod, Target target,
- Optional<Array<IRModule>> dispatched);
+ Optional<Array<IRModule>> dispatched,
+ FTakeTuningRecord f_take_tuning_record);
static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest";
TVM_DECLARE_FINAL_OBJECT_INFO(ApplyHistoryBestNode, runtime::Object);
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index fc02550c7e..4c8a3076a2 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1483,6 +1483,9 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage";
/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";
+/*! \brief Mark the buffers that are accessed read-only and whose layout can be transformed. */
+constexpr const char* layout_free_buffers = "layout_free_buffers";
+
/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
@@ -1516,11 +1519,12 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
/*! \brief Mark auto-unroll setting on the block. */
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
-/*!
- * \brief Mark that a block should be further rewritten using tensorization.
- */
+/*! \brief Mark that a block should be further rewritten using tensorization. */
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
+/*! \brief Mark a preprocessing block that only copies a weight into its rewritten layout. */
+constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
+
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 39a6459048..74f13420a2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -650,6 +650,12 @@ TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
*/
TVM_DLL Pass InjectPTXAsyncCopy();
+/*!
+ * \brief Remove the weight layout rewrite blocks.
+ *
+ *  Every block annotated with `meta_schedule.layout_rewrite_preproc` is turned into a no-op,
+ *  and the function parameter it reads from is rebound to the buffer it writes, so the
+ *  function takes the weight already in the rewritten layout.
+ * \return The pass.
+ */
+TVM_DLL Pass RemoveWeightLayoutRewriteBlock();
+
} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py
index d618c3a04f..1a8ab2d358 100644
--- a/python/tvm/meta_schedule/apply_history_best.py
+++ b/python/tvm/meta_schedule/apply_history_best.py
@@ -26,7 +26,7 @@ from tvm.te import Tensor
from tvm.tir import PrimFunc
from . import _ffi_api
-from .database import Database
+from .database import Database, TuningRecord
from .utils import make_logging_func
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -71,6 +71,7 @@ class ApplyHistoryBest(Object):
mod: IRModule,
target: Target,
dispatched: Optional[List[IRModule]],
+ f_take_tuning_record: Callable[[TuningRecord], None] = None,
) -> Union[IRModule, None]:
"""The entry point of the integration
@@ -84,6 +85,8 @@ class ApplyHistoryBest(Object):
Target Info
dispatched : Optional[List[IRModule]]
A list of low-level IRs that the high-level IR could potentially dispatch to
+ f_take_tuning_record : Callable[[TuningRecord], None] = None
+ A callback function that takes a tuning record and does something with it
Returns
-------
@@ -97,6 +100,7 @@ class ApplyHistoryBest(Object):
mod,
target,
dispatched,
+ f_take_tuning_record,
)
@staticmethod
diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py
index 6f0f523b47..69e7b0ca60 100644
--- a/python/tvm/meta_schedule/builder/local_builder.py
+++ b/python/tvm/meta_schedule/builder/local_builder.py
@@ -26,11 +26,7 @@ from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict
from tvm.target import Target
from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind
-from ..utils import (
- cpu_count,
- derived_object,
- get_global_func_with_default_on_worker,
-)
+from ..utils import cpu_count, derived_object, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@@ -258,8 +254,10 @@ def default_build(mod: IRModule, target: Target, _params: Optional[Dict[str, NDA
"""
# pylint: disable=import-outside-toplevel
from tvm.driver import build as tvm_build
+ from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
# pylint: enable=import-outside-toplevel
+ mod = RemoveWeightLayoutRewriteBlock()(mod)
return tvm_build(mod, target=target)
diff --git a/python/tvm/relay/backend/te_compiler.py b/python/tvm/relay/backend/te_compiler.py
index 654c8f66ac..9b2907ccdb 100644
--- a/python/tvm/relay/backend/te_compiler.py
+++ b/python/tvm/relay/backend/te_compiler.py
@@ -19,15 +19,16 @@
from __future__ import absolute_import
import logging
+
import tvm
-from tvm import te, autotvm
-from tvm.ir.transform import PassContext
+from tvm import autotvm, te
from tvm.runtime import Object
from tvm.support import libinfo
from tvm.target import Target
-from ..backend.utils import mangle_module_name
+
from .. import function as _function
from .. import ty as _ty
+from ..backend.utils import mangle_module_name
from . import _backend
logger = logging.getLogger("te_compiler")
@@ -170,6 +171,12 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
ret : tuple(relay.op.OpImplementation, List[tvm.te.Tensor])
The best op implementation and the corresponding output tensors.
"""
+ # pylint: disable=import-outside-toplevel
+ from tvm.auto_scheduler import is_auto_scheduler_enabled
+ from tvm.meta_schedule import is_meta_schedule_enabled
+
+ # pylint: enable=import-outside-toplevel
+
all_impls = get_valid_implementations(op, attrs, inputs, out_type, target)
if len(all_impls) == 0:
raise RuntimeError(f"No valid {op} implementations for {target}")
@@ -177,7 +184,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
# Disable autotvm if auto_scheduler is enabled.
# (i.e., always return the implementation with the highest priority for auto-scheduler).
- if PassContext.current().config.get("relay.backend.use_auto_scheduler", False):
+ if is_auto_scheduler_enabled() or is_meta_schedule_enabled():
use_autotvm = False
# If not use autotvm, always return the implementation with the highest priority
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 23892554cf..1353d8c5f5 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -22,7 +22,6 @@ import warnings
import numpy as np
from tvm.ir import IRModule
-from tvm.ir.transform import PassContext
from tvm.target import Target
from .. import autotvm
@@ -139,20 +138,23 @@ class BuildModule(object):
params : dict
The parameters of the final graph.
"""
+ # pylint: disable=import-outside-toplevel
+ from tvm.auto_scheduler import is_auto_scheduler_enabled
+ from tvm.meta_schedule import is_meta_schedule_enabled
+ # pylint: enable=import-outside-toplevel
# Setup the params.
if params:
self._set_params(params)
# Build the IR module. If auto_scheduler is not enabled,
# then use the TOPI-defined schedule.
- use_auto_scheduler = PassContext.current().config.get(
- "relay.backend.use_auto_scheduler", False
- )
# Turn off AutoTVM config not found warnings if auto_scheduler is enabled.
old_autotvm_silent = autotvm.GLOBAL_SCOPE.silent
- autotvm.GLOBAL_SCOPE.silent = use_auto_scheduler or old_autotvm_silent
+ autotvm.GLOBAL_SCOPE.silent = (
+ is_auto_scheduler_enabled() or is_meta_schedule_enabled() or old_autotvm_silent
+ )
mod_name = mangle_module_name(mod_name)
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index e1ddfe439a..9a20f9a777 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -16,7 +16,7 @@
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
-from typing import Optional, Callable
+from typing import Callable, Optional
from . import _ffi_api
from . import function_pass as _fpass
@@ -836,3 +836,13 @@ def InjectPTXAsyncCopy():
The result pass
"""
return _ffi_api.InjectPTXAsyncCopy() # type: ignore
+
+
+def RemoveWeightLayoutRewriteBlock():
+ """Remove weight layout rewrite block before benchmarking during tuning stage.
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.RemoveWeightLayoutRewriteBlock() # type: ignore
diff --git a/src/meta_schedule/apply_history_best.cc b/src/meta_schedule/apply_history_best.cc
index e5cc929fd0..22445a9cf7 100644
--- a/src/meta_schedule/apply_history_best.cc
+++ b/src/meta_schedule/apply_history_best.cc
@@ -103,8 +103,8 @@ ApplyHistoryBest::ApplyHistoryBest(Database database,
}
Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod,
- Target target,
- Optional<Array<IRModule>> dispatched) {
+ Target target, Optional<Array<IRModule>> dispatched,
+ FTakeTuningRecord f_take_tuning_record) {
ICHECK(dispatched.defined());
ICHECK_EQ(dispatched.value().size(), 1);
ICHECK(HasOnlyOneFunction<relay::Function>(mod)) << mod;
@@ -122,6 +122,9 @@ Optional<IRModule> ApplyHistoryBestNode::Query(runtime::String task_name, IRModu
if (database->HasWorkload(prim_mod)) {
Array<TuningRecord> records = database->GetTopK(database->CommitWorkload(prim_mod), 1);
if (records.size() == 1) {
+ if (f_take_tuning_record != nullptr) {
+ f_take_tuning_record(records[0]);
+ }
tir::Schedule sch =
tir::Schedule::Traced(records[0]->workload->mod, /*seed=*/-1, /*debug_mask=*/0,
/*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone);
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 37897a5ac6..672df86deb 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -61,11 +61,10 @@ Array<ArgInfo> ArgInfo::FromPrimFunc(const tir::PrimFunc& func) {
}
Array<ArgInfo> ArgInfo::FromEntryFunc(const IRModule& mod, bool remove_preproc) {
- // TODO(@jinhongyii): add pass for layout rewrite
- // if (remove_preproc) {
- // IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
- // return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
- // }
+ if (remove_preproc) {
+ IRModule new_mod = tir::transform::RemoveWeightLayoutRewriteBlock()(mod);
+ return ArgInfo::FromPrimFunc(FindEntryFunc(new_mod));
+ }
return ArgInfo::FromPrimFunc(FindEntryFunc(mod));
}
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc
index d3d63e7824..93f6767b11 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -300,6 +300,7 @@ Pass SimplifyForFeatureExtraction() {
*/
Sequential PassListForPerStoreFeature() {
return Sequential({
+ tir::transform::RemoveWeightLayoutRewriteBlock(),
tir::transform::SimplifyForFeatureExtraction(),
tir::transform::LowerCrossThreadReduction(),
tir::transform::LowerInitBlock(),
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index 76deb62f23..ca696da71e 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -41,6 +41,7 @@
#include <tvm/runtime/container/optional.h>
#include <tvm/support/parallel_for.h>
#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 8715900c0c..0f519721b0 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -36,6 +36,8 @@
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
+#include <tvm/tir/index_map.h>
+#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>
#include <functional>
@@ -47,6 +49,7 @@
#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
+#include "../transforms/meta_schedule_layout_rewrite.h"
#include "../transforms/pass_utils.h"
#include "utils.h"
@@ -59,6 +62,16 @@ TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+/*!
+ * \brief Forward the layout transformations recorded in a tuning record to the layout rewriter.
+ *
+ * Scans the record's trace for `TransformLayout` instructions and pushes the IndexMap stored as
+ * each such instruction's third attribute onto MetaScheduleLayoutRewriter's layout queue.
+ * \param record The tuning record handed back by ApplyHistoryBest::Query.
+ */
+void ExtractTransformLayout(const meta_schedule::TuningRecord& record) {
+  static tir::InstructionKind transform_layout_kind = tir::InstructionKind::Get("TransformLayout");
+  for (const tir::Instruction& inst : record->trace->insts) {
+    if (!inst->kind.same_as(transform_layout_kind)) {
+      continue;  // only TransformLayout instructions carry a layout to forward
+    }
+    ICHECK_EQ(inst->attrs.size(), 3);
+    relay::MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<tir::IndexMap>(inst->attrs[2]));
+  }
+}
+
LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
auto n = make_object<LoweredOutputNode>();
n->outputs = std::move(outputs);
@@ -353,10 +366,16 @@ class ScheduleBuilder : public ExprVisitor {
meta_schedule_ctx_.value()->te_filter_func(te_args)) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir_func.value()}});
- if (Optional<IRModule> scheduled_mod = meta_schedule_ctx_.value()->Query(
- prim_fn_var->name_hint, relay_mod, target_, Array<IRModule>{tir_mod})) {
- ICHECK_EQ(scheduled_mod.value()->functions.count(prim_fn_var), 1);
- prim_func = Downcast<tir::PrimFunc>(scheduled_mod.value()->functions[prim_fn_var]);
+ if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
+ /*task_name=*/prim_fn_var->name_hint, //
+ /*mod=*/relay_mod, //
+ /*target=*/target_, //
+ /*dispatched=*/Array<IRModule>{tir_mod}, //
+ /*f_take_tuning_record=*/ExtractTransformLayout)) {
+ IRModule scheduled_mod =
+ tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
+ ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
+ prim_func = Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
}
}
}
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 2aeb799a04..e361e8e344 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -103,12 +103,12 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
for (int i : this->layout_free_buffer_indices_) {
indices.push_back(Integer(i));
}
- return WithAttr(std::move(func), attr, indices);
+ return WithAttr(std::move(func), tir::attr::layout_free_buffers, indices);
}
Stmt VisitStmt_(const BlockNode* _block) final {
Block block = Downcast<Block>(StmtMutator::VisitStmt_(_block));
- if (Optional<ObjectRef> ann = block->annotations.Get(attr)) {
+ if (Optional<ObjectRef> ann = block->annotations.Get(topi_attr)) {
Array<Buffer> buffers = Downcast<Array<Buffer>>(ann);
for (Buffer buffer : buffers) {
auto it = buffer2index_.find(buffer);
@@ -116,14 +116,14 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
layout_free_buffer_indices_.insert(it->second);
}
}
- block.CopyOnWrite()->annotations.erase(attr);
+ block.CopyOnWrite()->annotations.erase(topi_attr);
}
return block;
}
std::unordered_map<tir::Buffer, int, ObjectPtrHash, ObjectPtrEqual> buffer2index_;
std::set<int> layout_free_buffer_indices_;
- String attr = "layout_free_placeholders";
+ String topi_attr = "layout_free_placeholders";
};
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
diff --git a/src/tir/transforms/remove_weight_layout_rewrite_block.cc b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
new file mode 100644
index 0000000000..5f47e670c6
--- /dev/null
+++ b/src/tir/transforms/remove_weight_layout_rewrite_block.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file remove_weight_layout_rewrite_block.cc
+ * \brief Remove weight layout rewrite block before benchmark
+ */
+
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Turns blocks whose `meta_schedule.layout_rewrite_preproc` annotation is true into
+ * no-ops, and rebinds the function parameter each of them reads to the buffer it writes, so
+ * the function receives the weight already in its rewritten layout.
+ */
+class WeightLayoutRewriteBlockRemover : public StmtMutator {
+ public:
+  /*! \brief Apply the removal to `f` and return the updated function. */
+  static PrimFunc Remove(PrimFunc f) {
+    WeightLayoutRewriteBlockRemover remover;
+    PrimFuncNode* n = f.CopyOnWrite();
+    n->body = remover(std::move(n->body));
+    Map<tir::Var, Buffer> buffer_map;
+    for (const auto& kv : f->buffer_map) {
+      Var param = kv.first;
+      Buffer buffer = kv.second;
+      // Rebind params whose buffer was the source of a removed layout rewrite block.
+      auto it = remover.buf_map_.find(buffer);
+      if (it != remover.buf_map_.end()) {
+        buffer_map.Set(param, (*it).second);
+      } else {
+        buffer_map.Set(param, buffer);
+      }
+    }
+    n->buffer_map = std::move(buffer_map);
+    return f;
+  }
+
+ private:
+  Stmt VisitStmt_(const BlockNode* op) final {
+    // Children are mutated first (post-order), so a nested layout rewrite block has already
+    // recorded its output in `rewritten_buffers_` before its allocating ancestor is handled.
+    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+
+    auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
+    if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) {
+      // Not a layout rewrite block: drop allocations of rewritten buffers, whose storage is
+      // now supplied by the rebound function parameters.
+      Array<Buffer> alloc_buffers;
+      for (const Buffer& buffer : block->alloc_buffers) {
+        if (!rewritten_buffers_.count(buffer)) {
+          alloc_buffers.push_back(buffer);
+        }
+      }
+      if (alloc_buffers.size() < block->alloc_buffers.size()) {
+        auto n = CopyOnWrite(block.get());
+        n->alloc_buffers = std::move(alloc_buffers);
+        return Stmt(n);
+      }
+      return block;
+    }
+
+    // Step 0. A layout rewrite block must neither allocate nor match buffers.
+    ICHECK(block->alloc_buffers.empty())
+        << "Layout rewrite block \"" << block->name_hint << "\" must not allocate buffers";
+    ICHECK(block->match_buffers.empty())
+        << "Layout rewrite block \"" << block->name_hint << "\" must not have match_buffers";
+
+    // Step 1. Its body must be a single BufferStore ...
+    const auto* store = block->body.as<BufferStoreNode>();
+    ICHECK(store) << "Layout rewrite block \"" << block->name_hint
+                  << "\" must have a BufferStore as its body, but got:\n" << block->body;
+
+    // Step 2. ... whose value is a plain BufferLoad, i.e. a pure copy between two layouts.
+    const auto* load = store->value.as<BufferLoadNode>();
+    ICHECK(load) << "Layout rewrite block \"" << block->name_hint
+                 << "\" must store a BufferLoad, but got: " << store->value;
+
+    // Step 3. Remember source -> rewritten buffer; the rewritten buffer's allocation is dropped
+    // when the block allocating it is visited.
+    buf_map_.Set(load->buffer, store->buffer);
+    rewritten_buffers_.insert(store->buffer);
+
+    // Step 4. Keep the block and its iteration domain, but make it a no-op.
+    auto n = CopyOnWrite(block.get());
+    n->body = Evaluate(0);
+    n->reads = {};
+    n->writes = {};
+    return Stmt(n);
+  }
+
+  /*! \brief Maps each original-layout buffer read by a removed block to its rewritten buffer */
+  Map<Buffer, Buffer> buf_map_;
+  /*! \brief Rewritten buffers; their allocations are dropped since params now provide them */
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> rewritten_buffers_;
+};
+namespace transform {
+
+// Expose WeightLayoutRewriteBlockRemover as a PrimFunc pass (opt_level 0, no required passes).
+Pass RemoveWeightLayoutRewriteBlock() {
+  auto remove_fn = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) -> PrimFunc {
+    return WeightLayoutRewriteBlockRemover::Remove(std::move(func));
+  };
+  return CreatePrimFuncPass(remove_fn, /*opt_level=*/0, "tir.RemoveWeightLayoutRewriteBlock", {});
+}
+
+// Registered for the FFI; the Python wrapper is tvm.tir.transform.RemoveWeightLayoutRewriteBlock.
+TVM_REGISTER_GLOBAL("tir.transform.RemoveWeightLayoutRewriteBlock")
+    .set_body_typed(RemoveWeightLayoutRewriteBlock);
+
+}  // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index f9b71bfdb6..44b0e79f0c 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -74,10 +74,6 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None:
class Fused_NN_Dense:
@T.prim_func
def main(placeholder: T.Buffer[(64, 768), "float32"], placeholder_1: T.Buffer[(768, 768), "float32"], T_matmul_NT: T.Buffer[(64, 768), "float32"]) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
- # body
- # with T.block("root")
for i0, i1, i2 in T.grid(64, 768, 768):
with T.block("T_matmul_NT"):
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
@@ -93,7 +89,6 @@ def before_matmul_vectorize(
placeholder_1: T.Buffer[(768, 768), "float32"],
T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
- T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
with T.block("root"):
T.reads()
T.writes()
@@ -124,7 +119,6 @@ def after_matmul_vectorize(
placeholder_1: T.Buffer[(768, 768), "float32"],
T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
- T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 5d9ad003b4..d3f444ec08 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -377,7 +377,7 @@ def expected_layout_attr(
B: T.Buffer[(128, 128), "float32"],
D: T.Buffer[(128, 128), "float32"],
) -> None:
- T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_placeholders": [1]})
+ T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]})
C = T.alloc_buffer([128, 128], dtype="float32")
for i0, i1, i2 in T.grid(128, 128, 128):
with T.block("C"):
diff --git a/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py
new file mode 100644
index 0000000000..7a01428381
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_remove_weight_layout_rewrite_block.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+
+import tvm
+from tvm.ir.module import IRModule
+from tvm.script import tir as T
+from tvm.tir.function import PrimFunc
+
+
+def _check(before, expect):
+ if isinstance(before, PrimFunc):
+ before = IRModule({"main": before})
+ if isinstance(expect, PrimFunc):
+ expect = IRModule({"main": expect})
+
+ mod = tvm.tir.transform.RemoveWeightLayoutRewriteBlock()(before)
+ tvm.ir.assert_structural_equal(mod, expect)
+
+
+def test_matmul():
+ """The layout rewrite block copying B into B_ must become a no-op, B_'s allocation must be
+ dropped, and parameter B must be rebound to B_'s (16, 4, 4) layout."""
+ # Input: B arrives as (16, 16) and is copied into the rewritten buffer B_ before the matmul.
+ @T.prim_func
+ def before(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 16), "float32"],
+ C: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ T.func_attr({"layout_free_buffers": [1]})
+ B_ = T.alloc_buffer([16, 4, 4], dtype="float32")
+ for i0_o, i1_o in T.grid(16, 16):
+ with T.block("layout_rewrite"):
+ i0, i1 = T.axis.remap("SS", [i0_o, i1_o])
+ T.reads(B[i0, i1])
+ T.writes(B_[i1, i0 // 4, i0 % 4])
+ T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
+ B_[i1, i0 // 4, i0 % 4] = B[i0, i1]
+ for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+ with T.block("matmul"):
+ vi = T.axis.spatial(16, i0 * 4 + i1)
+ vj = T.axis.spatial(16, j)
+ vk = T.axis.reduce(16, k0 * 4 + k1)
+ T.reads(A[vi, vk], B_[vj, vk // 4, vk % 4])
+ T.writes(C[vi, vj])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B_[vj, vk // 4, vk % 4]
+
+ # Expected: the copy block is kept but empty, and the matmul reads B directly in (16, 4, 4).
+ @T.prim_func
+ def after(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 4, 4), "float32"],
+ C: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ T.func_attr({"layout_free_buffers": [1]})
+ for i0_o, i1_o in T.grid(16, 16):
+ with T.block("layout_rewrite"):
+ i0, i1 = T.axis.remap("SS", [i0_o, i1_o])
+ T.reads()
+ T.writes()
+ T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
+ T.evaluate(0)
+ for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+ with T.block("matmul"):
+ vi = T.axis.spatial(16, i0 * 4 + i1)
+ vj = T.axis.spatial(16, j)
+ vk = T.axis.reduce(16, k0 * 4 + k1)
+ T.reads(A[vi, vk], B[vj, vk // 4, vk % 4])
+ T.writes(C[vi, vj])
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk // 4, vk % 4]
+
+ _check(before, after)
+
+
+# Allow running this file directly without pytest.
+# NOTE(review): `sys` is imported at the top but unused; presumably intended for a pytest
+# entry point such as `sys.exit(pytest.main([__file__] + sys.argv[1:]))` — confirm.
+if __name__ == "__main__":
+ test_matmul()