You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/07 09:37:30 UTC

[tvm] branch main updated: [TIR][USMP] Augmenting the algo interface with memory pressure (#9649)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb132e2  [TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
cb132e2 is described below

commit cb132e28084dda92f27c73f2aad2ac8252956062
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Tue Dec 7 09:37:09 2021 +0000

    [TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
    
    This commit adds memory pressure as an argument to
    the USMP algorithm interface, as certain iterative algorithms
    could use this as a guide to determine the termination
    criteria.
    
    Change-Id: I3fb5eea3fe5ba43e68c23625d411e557f6dd89a3
---
 include/tvm/tir/usmp/utils.h                       | 39 ++++++++++++++++++++++
 src/tir/usmp/algo/greedy.cc                        | 14 ++++----
 src/tir/usmp/analysis/extract_buffer_info.cc       | 31 +++++++++--------
 src/tir/usmp/utils.cc                              | 22 ++++++++++++
 tests/python/unittest/test_tir_usmp_algo.py        | 22 ++++++------
 .../test_tir_usmp_analysis_extract_bufferinfo.py   | 25 ++++++++------
 tests/python/unittest/test_tir_usmp_utils.py       |  6 ++--
 7 files changed, 114 insertions(+), 45 deletions(-)

diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 32a2bc6..145c61d 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -154,6 +154,45 @@ class BufferInfo : public ObjectRef {
 };
 
 /*!
+ * \brief This is a composite node produced by the extract_buffer_info
+ * analysis pass; it contains global information that could be useful
+ * for memory planning algorithms.
+ */
+struct BufferInfoAnalysisNode : public Object {
+  /*! \brief The BufferInfo object and its associated TIR statement */
+  Map<BufferInfo, tir::Stmt> buffer_info_stmts;
+  /*! \brief This represents the maximum amount of memory being used at
+   * any point of time in the inference. This value is largely the
+   * best allocation an algorithm could achieve. Due to
+   * the complexities of conflict graphs, it would not be feasible
+   * to achieve this value in practice. However, it can be useful
+   * for iterative algorithms to know this value to define termination
+   * criteria. */
+  Integer memory_pressure;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("buffer_info_stmts", &buffer_info_stmts);
+    v->Visit("memory_pressure", &memory_pressure);
+  }
+
+  bool SEqualReduce(const BufferInfoAnalysisNode* other, SEqualReducer equal) const {
+    return equal(buffer_info_stmts, other->buffer_info_stmts) &&
+           equal(memory_pressure, other->memory_pressure);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(buffer_info_stmts);
+    hash_reduce(memory_pressure);
+  }
+};
+
+class BufferInfoAnalysis : public ObjectRef {
+ public:
+  TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode);
+};
+
+/*!
  * \brief The pool allocation produced after the USMP algorithm
  */
 struct PoolAllocationNode : public Object {
diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index b98d828..5e1ce5f 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -209,22 +209,24 @@ class GreedyConflicts : public GreedyBase {
   }
 };
 
-Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
+Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr,
+                                             const Integer& memory_pressure) {
   return GreedySize().PlanMemory(buffer_info_arr);
 }
 
-Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) {
+Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
+                                                  const Integer& memory_pressure) {
   return GreedyConflicts().PlanMemory(buffer_info_arr);
 }
 
 TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
-    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
-      return GreedyBySize(buffer_info_arr);
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
+      return GreedyBySize(buffer_info_arr, memory_pressure);
     });
 
 TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
-    .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
-      return GreedyByConflicts(buffer_info_arr);
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
+      return GreedyByConflicts(buffer_info_arr, memory_pressure);
     });
 
 }  // namespace algo
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index 3fea721..ea53f27 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -63,7 +63,7 @@ class BufferInfoExtractor : public StmtExprVisitor {
     // Pushing a scope info for the initial body of the main function
     scope_stack_.push(ScopeInfo());
   }
-  Map<BufferInfo, tir::Stmt> operator()(const PrimFunc& func);
+  BufferInfoAnalysis operator()(const PrimFunc& func);
 
  private:
   void VisitStmt(const Stmt& n) override;
@@ -400,7 +400,7 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
   StmtExprVisitor::VisitExpr_(op);
 }
 
-Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_func) {
+BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
   VisitPrimFunc(main_func, Call());
 
   // Create a vector of liveness events
@@ -454,33 +454,32 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
 
   // Traverse the liveness events using a open set to track what
   // is live while updating the conflicts through out the linear traversal
-  std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;
+
+  int open_set_size = 0;
+  int max_open_set_size = 0;
+  std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
   for (const auto& le_event : le_events_timeline) {
     if (le_event.le_type == START) {
-      for (const auto& kv : open_set) {
-        BufferInfo open_buffer_info = kv.first;
+      for (const BufferInfo& open_buffer_info : open_set) {
         open_buffer_info->conflicts.push_back(le_event.buffer_info);
         if (le_event.buffer_info != open_buffer_info) {
           le_event.buffer_info->conflicts.push_back(open_buffer_info);
         }
       }
-      if (open_set.find(le_event.buffer_info) == open_set.end()) {
-        open_set[le_event.buffer_info] = 1;
-      } else {
-        open_set[le_event.buffer_info] += 1;
+      open_set_size += le_event.buffer_info->size_bytes;
+      if (open_set_size > max_open_set_size) {
+        max_open_set_size = open_set_size;
       }
+      open_set.insert(le_event.buffer_info);
     } else {
-      if (open_set[le_event.buffer_info] == 1) {
-        open_set.erase(le_event.buffer_info);
-      } else {
-        open_set[le_event.buffer_info] -= 1;
-      }
+      open_set_size -= le_event.buffer_info->size_bytes;
+      open_set.erase(le_event.buffer_info);
     }
   }
-  return this->buffer_info_map_;
+  return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size);
 }
 
-Map<BufferInfo, tir::Stmt> ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
+BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
   return BufferInfoExtractor(mod)(main_func);
 }
 
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index b7177cc..7a6a683 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -66,6 +66,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ",\n  alignment=" << node->alignment << ")";
     });
 
+BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts,
+                                       Integer memory_pressure) {
+  auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>();
+  bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts;
+  bufinfo_analysis_node->memory_pressure = memory_pressure;
+  data_ = std::move(bufinfo_analysis_node);
+}
+
+TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode);
+TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis")
+    .set_body_typed([](Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure) {
+      return BufferInfoAnalysis(buffer_info_stmts, memory_pressure);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<BufferInfoAnalysisNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const BufferInfoAnalysisNode*>(ref.get());
+      p->stream << "BufferInfoAnalysisNode(\n"
+                << "buffer_info_stmts=" << node->buffer_info_stmts
+                << ",\n  memory_pressure=" << node->memory_pressure << ")";
+    });
+
 PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes) {
   auto poolinfo_node = make_object<PoolInfoNode>();
   poolinfo_node->pool_name = pool_name;
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 61a70ad..1a763d0 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -120,7 +120,7 @@ def test_no_pool_error():
     with pytest.raises(
         tvm.TVMError, match="TVM USMP Error: the space available in the provided pools exceeded"
     ):
-        buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+        buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
 
 
 @pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
@@ -148,7 +148,7 @@ def test_name_based_ordering(algorithm):
 
         buffer_info_arr = [bi_a, bi_b, bi_c]
         fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-        buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+        buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
         assert buffer_pool_allocations[bi_a].byte_offset == 20
         assert buffer_pool_allocations[bi_b].byte_offset == 10
         assert buffer_pool_allocations[bi_c].byte_offset == 0
@@ -216,7 +216,7 @@ def test_linear(algorithm, workspace_size):
 
     buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f]
     fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
     _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
 
 
@@ -287,7 +287,7 @@ def test_fanout(algorithm, workspace_size):
 
     buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g]
     fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
     _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
 
 
@@ -382,12 +382,13 @@ def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
         tir_mod, [fast_memory_pool, slow_memory_pool]
     )
     main_func = tir_mod["run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    assert buffer_info_analysis.memory_pressure == 1117718
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
-    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
     fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)
 
     buffer_info_map_names = dict()
     for buf_info in buffer_info_arr:
@@ -540,12 +541,13 @@ def test_resnet_subgraph(algorithm, workspace_size):
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
     main_func = tir_mod["tvmgen_default_run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    assert buffer_info_analysis.memory_pressure == 7200256
 
     fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
-    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
     fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
-    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+    buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)
 
     buffer_info_map_names = dict()
     for buf_info in buffer_info_arr:
diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
index abaa0cd..ed8ff32 100644
--- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
+++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
@@ -176,8 +176,9 @@ def test_linear():
     tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
         tir_mod, [fast_memory_pool, slow_memory_pool]
     )
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
-    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
+    assert buffer_info_analysis.memory_pressure == 1117718
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
 
     # check conflicts
     _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map)
@@ -293,8 +294,9 @@ def test_parallel_serial_mixed_for_loops():
         all_serial_tir_mod, [global_ws_pool]
     )
     main_func = all_serial_tir_mod["run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
-    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
+    assert buffer_info_analysis.memory_pressure == 430848
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
 
     # When all loops are serial all allocates are touched by USMP
     assert len(buffer_info_map) == 3
@@ -309,10 +311,11 @@ def test_parallel_serial_mixed_for_loops():
         parallel_serial_mixed_tir_mod, [global_ws_pool]
     )
     main_func = parallel_serial_mixed_tir_mod["run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(
         main_func, parallel_serial_mixed_tir_mod
     )
-    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+    assert buffer_info_analysis.memory_pressure == 430848
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
 
     # USMP will not touch (yet) the allocates inside parallel for loops
     assert len(buffer_info_map) == 2
@@ -656,8 +659,9 @@ def test_inception_structure():
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
     main_func = tir_mod["run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
-    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    assert buffer_info_analysis.memory_pressure == 1117718
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
 
     # check conflicts
     _verify_conflicts(
@@ -1369,8 +1373,9 @@ def test_multiple_calls_to_same_primfunc():
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
     main_func = tir_mod["run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
-    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    assert buffer_info_analysis.memory_pressure == 11424
+    buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
 
     # check conflicts
     _verify_conflicts(
diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py
index 232bf6a..34e526a 100644
--- a/tests/python/unittest/test_tir_usmp_utils.py
+++ b/tests/python/unittest/test_tir_usmp_utils.py
@@ -193,10 +193,10 @@ def test_create_array_buffer_info():
     tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
     tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
     main_func = tir_mod["tvmgen_default_run_model"]
-    buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
-    buffer_info_array = fcreate_array_bi(buffer_info_map)
+    buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_array = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
     for buffer_info in buffer_info_array:
-        assert buffer_info in buffer_info_map.keys()
+        assert buffer_info in buffer_info_analysis.buffer_info_stmts.keys()
 
 
 if __name__ == "__main__":