You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/28 15:01:10 UTC

[tvm] branch main updated: [USMP] Improve algorithm extensibility (#11880)

This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bd49b0846a [USMP] Improve algorithm extensibility (#11880)
bd49b0846a is described below

commit bd49b0846a6f435991a55dec11f5a01169b83b36
Author: Rafael Stahl <du...@web.de>
AuthorDate: Tue Jun 28 17:01:02 2022 +0200

    [USMP] Improve algorithm extensibility (#11880)
    
    * [USMP] Improve algorithm extensibility
    
    * [USMP] add option for custom_algorithm to avoid PackedFunc on default path
    
    * [USMP] add test for custom algorithm
    
    * [lint] fix wrong line length
    
    * [USMP][test] fix PoolInfo for latest tvm
---
 include/tvm/tir/usmp/utils.h                  |  5 +++
 src/tir/usmp/unified_static_memory_planner.cc | 28 ++++++++++++-----
 tests/python/unittest/test_tir_usmp_algo.py   | 45 +++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 7 deletions(-)

diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 5b3b44ff7e..59430eee83 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -45,6 +45,11 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
  * \brief PassContext option to enable placing I/O tensors in the workspace
  */
 constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
+/*!
+ * \brief PassContext option giving the NAME of a custom memory planning algorithm in USMP.
+ * It must be a PackedFunc registered as tir.usmp.algo.NAME; it overrides tir.usmp.algorithm.
+ */
+constexpr const char* kUSMPCustomAlgorithmOption = "tir.usmp.custom_algorithm";
 
 namespace tir {
 namespace usmp {
diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc
index d7eb0f3a7e..60030c1595 100644
--- a/src/tir/usmp/unified_static_memory_planner.cc
+++ b/src/tir/usmp/unified_static_memory_planner.cc
@@ -41,6 +41,7 @@ namespace tvm {
 TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String);
 TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPCustomAlgorithmOption, String);
 
 namespace tir {
 namespace usmp {
@@ -53,7 +54,8 @@ static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
                {"greedy_by_conflicts", algo::GreedyByConflicts},
                {"hill_climb", algo::HillClimb}};
 
-IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) {
+IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io,
+                    Optional<String> opt_custom_algo) {
   VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
   IRModule module = mod->ShallowCopy();
   if (use_workspace_io) {
@@ -64,10 +66,21 @@ IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) {
   BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module);
   Array<BufferInfo> buffer_info_arr =
       ConvertToArrayOfBufferInfo(buffer_info_analysis->buffer_info_stmts);
-  CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo
-                                << " is not defined. Please define it in the above algorithms map.";
+  decltype(algorithms)::mapped_type algorithm;
+  if (opt_custom_algo) {
+    String algo_func_name = "tir.usmp.algo." + opt_custom_algo.value();
+    const runtime::PackedFunc* pfAlgo = runtime::Registry::Get(algo_func_name);
+    CHECK(pfAlgo) << "The selected custom USMP algorithm : " << opt_custom_algo.value()
+                  << " is not defined. Please register it as " << algo_func_name;
+    algorithm = *pfAlgo;
+  } else {
+    CHECK(algorithms.count(algo))
+        << "The selected USMP algorithm : " << algo
+        << " is not defined. Please define it in the above algorithms map.";
+    algorithm = algorithms[algo];
+  }
   Map<BufferInfo, PoolAllocation> buffer_info_pool_allocations =
-      algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure);
+      algorithm(buffer_info_arr, buffer_info_analysis->memory_pressure);
 
   Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations(
       buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
@@ -98,6 +111,7 @@ tvm::transform::Pass UnifiedStaticMemoryPlanner() {
   auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
     auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo));
     auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false));
+    auto custom_algorithm_str = ctx->GetConfig<String>(kUSMPCustomAlgorithmOption);
     tvm::relay::Executor executor_config =
         m->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor).value();
     String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
@@ -109,9 +123,9 @@ tvm::transform::Pass UnifiedStaticMemoryPlanner() {
                                   << "Please use interface_api c to be able to enable "
                                   << kUSMPUseWorkspaceIO << "\n";
     }
-    return Downcast<IRModule>(usmp::PlanMemory(m,
-                                               algorithm_str.value_or(String(usmp::kDefaultAlgo)),
-                                               use_workspace_io.value_or(Bool(false))));
+    return Downcast<IRModule>(
+        usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo)),
+                         use_workspace_io.value_or(Bool(false)), custom_algorithm_str));
   };
 
   return tvm::transform::CreateModulePass(usmp_main_pass_func, 0,
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 9d30a0d195..140f6d1b14 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -683,3 +683,48 @@ def test_resnet_subgraph(algorithm, workspace_size):
     )
 
     _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
+
+
+def test_custom_algo():  # custom USMP algorithm selected through tir.usmp.custom_algorithm
+    target = Target("c")
+    global_workspace_pool = WorkspacePoolInfo(
+        "global_workspace",
+        [target],
+    )
+    tir_mod = ResnetStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+    tir_mod = tir_mod.with_attr("executor", tvm.relay.backend.Executor("aot"))
+    tir_mod = tir_mod.with_attr("runtime", tvm.relay.backend.Runtime("crt"))
+    tir_mod["__tvm_main__"] = tir_mod[
+        "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast"
+    ]
+
+    algo_called = False
+
+    @tvm.register_func("tir.usmp.algo.trivial", override=True)  # allow re-registration on re-runs
+    def _trivial_algo(buf_infos, mem_pressure):
+        nonlocal algo_called
+        algo_called = True
+        out_layout = {}
+        offset = 0
+        for buf_info in buf_infos:
+            pool_info = buf_info.pool_candidates[0]  # always use the first candidate
+            out_layout[buf_info] = usmp_utils.PoolAllocation(pool_info, offset)
+            offset += buf_info.size_bytes  # pack back-to-back, ignoring liveness
+        return out_layout
+
+    usmp_pass = tvm.get_global_func("tir.transform.UnifiedStaticMemoryPlanner")
+    usmp_pass()(tir_mod)  # default path must not call the custom algorithm
+    assert not algo_called
+
+    with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "trivial"}):
+        usmp_pass()(tir_mod)
+
+    assert algo_called
+
+    with pytest.raises(
+        tvm.TVMError, match="The selected custom USMP algorithm : invalid is not defined"
+    ):
+        with tvm.transform.PassContext(config={"tir.usmp.custom_algorithm": "invalid"}):
+            usmp_pass()(tir_mod)