You are viewing a plain-text version of this content. The canonical link for it was provided as a hyperlink ("here"), which is not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/09 17:45:14 UTC
[incubator-mxnet] branch master updated: Add Util Function for Memory Plan Inspection (#10859)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c5958f0 Add Util Function for Memory Plan Inspection (#10859)
c5958f0 is described below
commit c5958f0ade579487ecdbd6c1bcb67538f531c62a
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Wed May 9 10:45:02 2018 -0700
Add Util Function for Memory Plan Inspection (#10859)
* add inplace option for bn backward
* Add util code for debugging memory plan for the graph
* update var name
* fix bug
* fix lint
* add example
---
src/common/exec_utils.h | 124 ++++++++++++++++++++++++++++++++++
src/executor/graph_executor.cc | 7 ++
src/executor/infer_graph_attr_pass.cc | 31 +--------
3 files changed, 133 insertions(+), 29 deletions(-)
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index 3ac86fb..4881e0f 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -25,6 +25,8 @@
#define MXNET_COMMON_EXEC_UTILS_H_
#include <vector>
+#include <string>
+#include <utility>
#include "../common/utils.h"
namespace mxnet {
@@ -226,7 +228,129 @@ inline bool DefaultStorageType(const nnvm::NodeAttrs& attrs,
return true;
}
/*!
 * \brief Returns a human-readable label for a storage id taken from the
 *        "storage_id" graph attribute (used when logging a memory plan).
 *
 * Non-negative ids identify a memory group and are printed as "group <id>".
 * Negative ids are sentinels rather than groups, so they are never printed
 * as "group ...":
 *   -1                    -> "var (-1)"
 *   -2                    -> "external storage (-2)"
 *   -3                    -> "dynamic storage (-3)"
 *   any other negative id -> "unknown storage (<id>)"
 *
 * NOTE(review): -3 is assumed to be the memory planner's dynamic-storage
 * sentinel (kDynamicStorageID, storage allocated at run time); confirm
 * against the PlanMemory pass if that constant ever changes.
 *
 * \param storage_id storage id of a graph entry.
 * \return label describing the storage the entry is assigned to.
 */
inline std::string storage_str(int storage_id) {
  if (storage_id >= 0) {
    return "group " + std::to_string(storage_id);
  }
  switch (storage_id) {
    case -1:
      return "var (-1)";
    case -2:
      return "external storage (-2)";
    case -3:
      return "dynamic storage (-3)";
    default:
      return "unknown storage (" + std::to_string(storage_id) + ")";
  }
}
+
+/* log the static memory plan of the graph. Example:
+ node 0 var
+ node 1 _copy
+ input 0: [80,3,224,224] (47040 KB) -> var storage (-1)
+ output 1: [80,3,224,224] (47040 KB) -> group 0
+ node 2 var
+ node 3 var
+ node 4 var
+ node 5 var
+ node 6 BatchNorm
+ input 1: [80,3,224,224] (47040 KB) -> group 0
+ input 2: [3] (0 KB) -> var storage (-1)
+ input 3: [3] (0 KB) -> var storage (-1)
+ input 4: [3] (0 KB) -> var storage (-1)
+ input 5: [3] (0 KB) -> var storage (-1)
+ output 6: [80,3,224,224] (47040 KB) -> group 1
+ output 7: [3] (0 KB) -> group 3
+ output 8: [3] (0 KB) -> group 2
+ ...
+ */
+inline void LogMemoryPlan(const nnvm::Graph& g) {
+ const auto &idx = g.indexed_graph();
+ const auto& vshape = g.GetAttr<nnvm::ShapeVector>("shape");
+ const auto& vtype = g.GetAttr<nnvm::DTypeVector>("dtype");
+ const auto& vstorage = g.GetAttr<nnvm::StorageVector>("storage_id");
+ // find node range
+ uint32_t node_start = 0, node_end = idx.num_nodes();
+ if (g.attrs.count("node_range")) {
+ const auto& range = g.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+ node_start = range.first;
+ node_end = range.second;
+ }
+ for (uint32_t nid = node_start; nid < node_end; ++nid) {
+ const auto& inode = idx[nid];
+ if (inode.source->is_variable()) {
+ LOG(INFO) << "node " << nid << " var";
+ } else {
+ LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name;
+ for (const auto& e : inode.inputs) {
+ auto eid = idx.entry_id(e);
+ size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
+ LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " ("
+ << kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
+ }
+ for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+ uint32_t eid = idx.entry_id(nid, index);
+ size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
+ LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " ("
+ << kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
+ }
+ }
+ }
+}
+
+/* log the static memory plan of the graph. Example:
+ node 0 var
+ node 1 _copy: fcompute
+ input 0: default
+ output 1: default
+ node 2 var
+ node 3 Convolution: fcompute
+ input 1: default
+ input 2: default
+ output 3: default
+ node 4 var
+ node 5 var
+ node 6 var
+ node 7 var
+ node 8 BatchNorm: fcompute
+ input 3: default
+ input 4: default
+ input 5: default
+ input 6: default
+ input 7: default
+ output 8: default
+ output 9: default
+ output 10: default
+ ...
+ */
+inline void LogInferStorage(const nnvm::Graph& g) {
+ const auto &idx = g.indexed_graph();
+ const auto& vstorage_type = g.GetAttr<StorageTypeVector>("storage_type");
+ const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
+ uint32_t node_start = 0, node_end = idx.num_nodes();
+ if (g.attrs.count("node_range")) {
+ const auto& range = g.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+ node_start = range.first;
+ node_end = range.second;
+ }
+ for (uint32_t nid = node_start; nid < node_end; ++nid) {
+ const auto& inode = idx[nid];
+ if (inode.source->is_variable()) {
+ LOG(INFO) << "node " << nid << " var";
+ } else {
+ LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name
+ << ": " << dispatch_mode_string(dispatch_modes[nid]);
+ for (const auto& e : inode.inputs) {
+ auto eid = idx.entry_id(e);
+ LOG(INFO) << "\t\tinput " << eid << ": " << stype_string(vstorage_type[eid]);
+ }
+ for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
+ uint32_t eid = idx.entry_id(nid, index);
+ LOG(INFO) << "\t\toutput " << eid << ": " << stype_string(vstorage_type[eid]);
+ }
+ }
+ }
+}
+
} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_EXEC_UTILS_H_
+
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index d5dacf7..7a15f6c 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -32,6 +32,7 @@
#include "./graph_executor.h"
#include "../profiler/profiler.h"
#include "../common/utils.h"
+#include "../common/exec_utils.h"
namespace mxnet {
namespace exec {
@@ -904,6 +905,12 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
}
g = DetectInplaceAddTo(g);
+ // log the static memory plan of the graph
+ static bool mem_log_verbose = dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false);
+ if (mem_log_verbose) {
+ common::LogMemoryPlan(g);
+ }
+
g = AttachOpExecs(g);
g = AttachOpResources(g);
graph_ = std::move(g);
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 191fbe9..0abee04 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -391,36 +391,9 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph,
common::DefaultStorageType, false, "dispatch_mode", DispatchMode::kVariable);
// log the storage types and dispatch modes of the graph
- bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false);
+ static bool log_verbose = dmlc::GetEnv("MXNET_INFER_STORAGE_TYPE_VERBOSE_LOGGING", false);
if (log_verbose) {
- const auto &idx = ret.indexed_graph();
- const auto& vstorage_type = ret.GetAttr<StorageTypeVector>("storage_type");
- const auto& dispatch_modes = ret.GetAttr<DispatchModeVector>("dispatch_mode");
- uint32_t node_start = 0, node_end = idx.num_nodes();
- if (ret.attrs.count("node_range")) {
- const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
- node_start = range.first;
- node_end = range.second;
- }
- for (uint32_t nid = node_start; nid < node_end; ++nid) {
- const auto& inode = idx[nid];
- if (inode.source->is_variable()) {
- LOG(INFO) << "node " << nid << " var";
- } else {
- LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name
- << ": " << common::dispatch_mode_string(dispatch_modes[nid]);
- for (const auto& e : inode.inputs) {
- auto eid = idx.entry_id(e);
- LOG(INFO) << "\t\tinput " << eid << ": "
- << common::stype_string(vstorage_type[eid]);
- }
- for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
- uint32_t eid = idx.entry_id(nid, index);
- LOG(INFO) << "\t\toutput " << eid << ": "
- << common::stype_string(vstorage_type[eid]);
- }
- }
- }
+ common::LogInferStorage(ret);
}
return ret;
}
--
To stop receiving notification emails like this one, please contact
jxie@apache.org.