You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/28 15:01:10 UTC
[tvm] branch main updated: [USMP] Improve algorithm extensibility (#11880)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bd49b0846a [USMP] Improve algorithm extensibility (#11880)
bd49b0846a is described below
commit bd49b0846a6f435991a55dec11f5a01169b83b36
Author: Rafael Stahl <du...@web.de>
AuthorDate: Tue Jun 28 17:01:02 2022 +0200
[USMP] Improve algorithm extensibility (#11880)
* [USMP] Improve algorithm extensibility
* [USMP] add option for custom_algorithm to avoid PackedFunc on default path
* [USMP] add test for custom algorithm
* [lint] fix wrong line length
* [USMP][test] fix PoolInfo for latest tvm
---
include/tvm/tir/usmp/utils.h | 5 +++
src/tir/usmp/unified_static_memory_planner.cc | 28 ++++++++++++-----
tests/python/unittest/test_tir_usmp_algo.py | 45 +++++++++++++++++++++++++++
3 files changed, 71 insertions(+), 7 deletions(-)
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 5b3b44ff7e..59430eee83 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -45,6 +45,11 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
* \brief PassContext option to enable placing I/O tensors in the workspace
*/
constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
+/*!
+ * \brief PassContext option to specify a custom memory planning algorithm in USMP.
+ * The algorithm should be provided as a registered PackedFunc with the name tir.usmp.algo.NAME
+ */
+constexpr const char* kUSMPCustomAlgorithmOption = "tir.usmp.custom_algorithm";
namespace tir {
namespace usmp {
diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc
index d7eb0f3a7e..60030c1595 100644
--- a/src/tir/usmp/unified_static_memory_planner.cc
+++ b/src/tir/usmp/unified_static_memory_planner.cc
@@ -41,6 +41,7 @@ namespace tvm {
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String);
TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPCustomAlgorithmOption, String);
namespace tir {
namespace usmp {
@@ -53,7 +54,8 @@ static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
{"greedy_by_conflicts", algo::GreedyByConflicts},
{"hill_climb", algo::HillClimb}};
-IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) {
+IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io,
+ Optional<String> opt_custom_algo) {
VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
IRModule module = mod->ShallowCopy();
if (use_workspace_io) {
@@ -64,10 +66,21 @@ IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) {
BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module);
Array<BufferInfo> buffer_info_arr =
ConvertToArrayOfBufferInfo(buffer_info_analysis->buffer_info_stmts);
- CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo
- << " is not defined. Please define it in the above algorithms map.";
+ decltype(algorithms)::mapped_type algorithm;
+ if (opt_custom_algo) {
+ String algo_func_name = "tir.usmp.algo." + opt_custom_algo.value();
+ const runtime::PackedFunc* pfAlgo = runtime::Registry::Get(algo_func_name);
+ CHECK(pfAlgo) << "The selected custom USMP algorithm : " << opt_custom_algo.value()
+ << " is not defined. Please register it as " << algo_func_name;
+ algorithm = *pfAlgo;
+ } else {
+ CHECK(algorithms.count(algo))
+ << "The selected USMP algorithm : " << algo
+ << " is not defined. Please define it in the above algorithms map.";
+ algorithm = algorithms[algo];
+ }
Map<BufferInfo, PoolAllocation> buffer_info_pool_allocations =
- algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure);
+ algorithm(buffer_info_arr, buffer_info_analysis->memory_pressure);
Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations(
buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
@@ -98,6 +111,7 @@ tvm::transform::Pass UnifiedStaticMemoryPlanner() {
auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo));
auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false));
+ auto custom_algorithm_str = ctx->GetConfig<String>(kUSMPCustomAlgorithmOption);
tvm::relay::Executor executor_config =
m->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor).value();
String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
@@ -109,9 +123,9 @@ tvm::transform::Pass UnifiedStaticMemoryPlanner() {
<< "Please use interface_api c to be able to enable "
<< kUSMPUseWorkspaceIO << "\n";
}
- return Downcast<IRModule>(usmp::PlanMemory(m,
- algorithm_str.value_or(String(usmp::kDefaultAlgo)),
- use_workspace_io.value_or(Bool(false))));
+ return Downcast<IRModule>(
+ usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)),
+ use_workspace_io.value_or(Bool(false)), custom_algorithm_str));
};
return tvm::transform::CreateModulePass(usmp_main_pass_func, 0,
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 9d30a0d195..140f6d1b14 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -683,3 +683,48 @@ def test_resnet_subgraph(algorithm, workspace_size):
)
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
+
+
+def test_custom_algo():
+ target = Target("c")
+ global_workspace_pool = WorkspacePoolInfo(
+ "global_workspace",
+ [target],
+ )
+ tir_mod = ResnetStructure
+ tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+ tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+ tir_mod = tir_mod.with_attr("executor", tvm.relay.backend.Executor("aot"))
+ tir_mod = tir_mod.with_attr("runtime", tvm.relay.backend.Runtime("crt"))
+ tir_mod["__tvm_main__"] = tir_mod[
+ "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast"
+ ]
+
+ algo_called = False
+
+ @tvm.register_func("tir.usmp.algo.trivial")
+ def _trivial_algo(buf_infos, mem_pressure):
+ nonlocal algo_called
+ algo_called = True
+ out_layout = {}
+ offset = 0
+ for buf_info in buf_infos:
+ pool_info = buf_info.pool_candidates[0]
+ out_layout[buf_info] = usmp_utils.PoolAllocation(pool_info, offset)
+ offset += buf_info.size_bytes
+ return out_layout
+
+ usmp_pass = tvm.get_global_func("tir.transform.UnifiedStaticMemoryPlanner")
+ usmp_pass()(tir_mod)
+ assert not algo_called
+
+ with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "trivial"}):
+ usmp_pass()(tir_mod)
+
+ assert algo_called
+
+ with pytest.raises(
+ tvm.TVMError, match="The selected custom USMP algorithm : invalid is not defined"
+ ):
+ with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "invalid"}):
+ usmp_pass()(tir_mod)