You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/07 09:37:30 UTC
[tvm] branch main updated: [TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cb132e2 [TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
cb132e2 is described below
commit cb132e28084dda92f27c73f2aad2ac8252956062
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Tue Dec 7 09:37:09 2021 +0000
[TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
This commit adds memory pressure as an argument to
the USMP algorithm interface, as certain iterative algorithms
could use this as a guide to determine the termination
criteria.
Change-Id: I3fb5eea3fe5ba43e68c23625d411e557f6dd89a3
---
include/tvm/tir/usmp/utils.h | 39 ++++++++++++++++++++++
src/tir/usmp/algo/greedy.cc | 14 ++++----
src/tir/usmp/analysis/extract_buffer_info.cc | 31 +++++++++--------
src/tir/usmp/utils.cc | 22 ++++++++++++
tests/python/unittest/test_tir_usmp_algo.py | 22 ++++++------
.../test_tir_usmp_analysis_extract_bufferinfo.py | 25 ++++++++------
tests/python/unittest/test_tir_usmp_utils.py | 6 ++--
7 files changed, 114 insertions(+), 45 deletions(-)
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 32a2bc6..145c61d 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -154,6 +154,45 @@ class BufferInfo : public ObjectRef {
};
/*!
+ * \brief This is a composite node that is produced by extract_buffer_info
+ * analysis pass that contains useful global information that could be useful
+ * for memory planning algorithms.
+ */
+struct BufferInfoAnalysisNode : public Object {
+ /*! \brief The BufferInfo object and its associated TIR statement */
+ Map<BufferInfo, tir::Stmt> buffer_info_stmts;
+ /*! \brief This represents the maximum amount of memory being used at
+ * any point of time in the inference. This value is largely the
+ * best allocation an algorithm could achieve. Due to
+ * the complexities of conflict graphs, it would not be feasible
+ * to achieve this value, practically. However, it can be useful
+ * for iterative algorithms to know this value to define termination
+ * criteria.*/
+ Integer memory_pressure;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("buffer_info_stmts", &buffer_info_stmts);
+ v->Visit("memory_pressure", &memory_pressure);
+ }
+
+ bool SEqualReduce(const BufferInfoAnalysisNode* other, SEqualReducer equal) const {
+ return equal(buffer_info_stmts, other->buffer_info_stmts) &&
+ equal(memory_pressure, other->memory_pressure);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(buffer_info_stmts);
+ hash_reduce(memory_pressure);
+ }
+};
+
+class BufferInfoAnalysis : public ObjectRef {
+ public:
+ TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode);
+};
+
+/*!
* \brief The pool allocation produced after the USMP algorithm
*/
struct PoolAllocationNode : public Object {
diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index b98d828..5e1ce5f 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -209,22 +209,24 @@ class GreedyConflicts : public GreedyBase {
}
};
-Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
+Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr,
+ const Integer& memory_pressure) {
return GreedySize().PlanMemory(buffer_info_arr);
}
-Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) {
+Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
+ const Integer& memory_pressure) {
return GreedyConflicts().PlanMemory(buffer_info_arr);
}
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
- .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
- return GreedyBySize(buffer_info_arr);
+ .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
+ return GreedyBySize(buffer_info_arr, memory_pressure);
});
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
- .set_body_typed([](Array<BufferInfo> buffer_info_arr) {
- return GreedyByConflicts(buffer_info_arr);
+ .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
+ return GreedyByConflicts(buffer_info_arr, memory_pressure);
});
} // namespace algo
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index 3fea721..ea53f27 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -63,7 +63,7 @@ class BufferInfoExtractor : public StmtExprVisitor {
// Pushing a scope info for the initial body of the main function
scope_stack_.push(ScopeInfo());
}
- Map<BufferInfo, tir::Stmt> operator()(const PrimFunc& func);
+ BufferInfoAnalysis operator()(const PrimFunc& func);
private:
void VisitStmt(const Stmt& n) override;
@@ -400,7 +400,7 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
StmtExprVisitor::VisitExpr_(op);
}
-Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_func) {
+BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
VisitPrimFunc(main_func, Call());
// Create a vector of liveness events
@@ -454,33 +454,32 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
// Traverse the liveness events using a open set to track what
// is live while updating the conflicts through out the linear traversal
- std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;
+
+ int open_set_size = 0;
+ int max_open_set_size = 0;
+ std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
for (const auto& le_event : le_events_timeline) {
if (le_event.le_type == START) {
- for (const auto& kv : open_set) {
- BufferInfo open_buffer_info = kv.first;
+ for (const BufferInfo& open_buffer_info : open_set) {
open_buffer_info->conflicts.push_back(le_event.buffer_info);
if (le_event.buffer_info != open_buffer_info) {
le_event.buffer_info->conflicts.push_back(open_buffer_info);
}
}
- if (open_set.find(le_event.buffer_info) == open_set.end()) {
- open_set[le_event.buffer_info] = 1;
- } else {
- open_set[le_event.buffer_info] += 1;
+ open_set_size += le_event.buffer_info->size_bytes;
+ if (open_set_size > max_open_set_size) {
+ max_open_set_size = open_set_size;
}
+ open_set.insert(le_event.buffer_info);
} else {
- if (open_set[le_event.buffer_info] == 1) {
- open_set.erase(le_event.buffer_info);
- } else {
- open_set[le_event.buffer_info] -= 1;
- }
+ open_set_size -= le_event.buffer_info->size_bytes;
+ open_set.erase(le_event.buffer_info);
}
}
- return this->buffer_info_map_;
+ return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size);
}
-Map<BufferInfo, tir::Stmt> ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
+BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
return BufferInfoExtractor(mod)(main_func);
}
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index b7177cc..7a6a683 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -66,6 +66,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ",\n alignment=" << node->alignment << ")";
});
+BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts,
+ Integer memory_pressure) {
+ auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>();
+ bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts;
+ bufinfo_analysis_node->memory_pressure = memory_pressure;
+ data_ = std::move(bufinfo_analysis_node);
+}
+
+TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode);
+TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis")
+ .set_body_typed([](Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure) {
+ return BufferInfoAnalysis(buffer_info_stmts, memory_pressure);
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<BufferInfoAnalysisNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const BufferInfoAnalysisNode*>(ref.get());
+ p->stream << "BufferInfoAnalysisNode(\n"
+ << "buffer_info_stmts=" << node->buffer_info_stmts
+ << ",\n memory_pressure=" << node->memory_pressure << ")";
+ });
+
PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes) {
auto poolinfo_node = make_object<PoolInfoNode>();
poolinfo_node->pool_name = pool_name;
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 61a70ad..1a763d0 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -120,7 +120,7 @@ def test_no_pool_error():
with pytest.raises(
tvm.TVMError, match="TVM USMP Error: the space available in the provided pools exceeded"
):
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
@@ -148,7 +148,7 @@ def test_name_based_ordering(algorithm):
buffer_info_arr = [bi_a, bi_b, bi_c]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
assert buffer_pool_allocations[bi_a].byte_offset == 20
assert buffer_pool_allocations[bi_b].byte_offset == 10
assert buffer_pool_allocations[bi_c].byte_offset == 0
@@ -216,7 +216,7 @@ def test_linear(algorithm, workspace_size):
buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
@@ -287,7 +287,7 @@ def test_fanout(algorithm, workspace_size):
buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
@@ -382,12 +382,13 @@ def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ assert buffer_info_analysis.memory_pressure == 1117718
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
- buffer_info_arr = fcreate_array_bi(buffer_info_map)
+ buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)
buffer_info_map_names = dict()
for buf_info in buffer_info_arr:
@@ -540,12 +541,13 @@ def test_resnet_subgraph(algorithm, workspace_size):
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["tvmgen_default_run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ assert buffer_info_analysis.memory_pressure == 7200256
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
- buffer_info_arr = fcreate_array_bi(buffer_info_map)
+ buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
- buffer_pool_allocations = fusmp_algo(buffer_info_arr)
+ buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)
buffer_info_map_names = dict()
for buf_info in buffer_info_arr:
diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
index abaa0cd..ed8ff32 100644
--- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
+++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
@@ -176,8 +176,9 @@ def test_linear():
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
tir_mod, [fast_memory_pool, slow_memory_pool]
)
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
- buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
+ assert buffer_info_analysis.memory_pressure == 1117718
+ buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
# check conflicts
_verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map)
@@ -293,8 +294,9 @@ def test_parallel_serial_mixed_for_loops():
all_serial_tir_mod, [global_ws_pool]
)
main_func = all_serial_tir_mod["run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
- buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
+ assert buffer_info_analysis.memory_pressure == 430848
+ buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
# When all loops are serial all allocates are touched by USMP
assert len(buffer_info_map) == 3
@@ -309,10 +311,11 @@ def test_parallel_serial_mixed_for_loops():
parallel_serial_mixed_tir_mod, [global_ws_pool]
)
main_func = parallel_serial_mixed_tir_mod["run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(
main_func, parallel_serial_mixed_tir_mod
)
- buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+ assert buffer_info_analysis.memory_pressure == 430848
+ buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
# USMP will not touch (yet) the allocates inside parallel for loops
assert len(buffer_info_map) == 2
@@ -656,8 +659,9 @@ def test_inception_structure():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
- buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ assert buffer_info_analysis.memory_pressure == 1117718
+ buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
# check conflicts
_verify_conflicts(
@@ -1369,8 +1373,9 @@ def test_multiple_calls_to_same_primfunc():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
- buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ assert buffer_info_analysis.memory_pressure == 11424
+ buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)
# check conflicts
_verify_conflicts(
diff --git a/tests/python/unittest/test_tir_usmp_utils.py b/tests/python/unittest/test_tir_usmp_utils.py
index 232bf6a..34e526a 100644
--- a/tests/python/unittest/test_tir_usmp_utils.py
+++ b/tests/python/unittest/test_tir_usmp_utils.py
@@ -193,10 +193,10 @@ def test_create_array_buffer_info():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["tvmgen_default_run_model"]
- buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
- buffer_info_array = fcreate_array_bi(buffer_info_map)
+ buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ buffer_info_array = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
for buffer_info in buffer_info_array:
- assert buffer_info in buffer_info_map.keys()
+ assert buffer_info in buffer_info_analysis.buffer_info_stmts.keys()
if __name__ == "__main__":