You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by cj...@apache.org on 2018/03/23 18:10:28 UTC
[incubator-mxnet] 01/03: no bulk exec for aggregate
This is an automated email from the ASF dual-hosted git repository.
cjolivier01 pushed a commit to branch cython
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 306626c3fd6e9124f436ae85fc1fc192912302ad
Author: Olivier <co...@amazon.com>
AuthorDate: Mon Mar 12 13:41:47 2018 -0700
no bulk exec for aggregate
---
src/executor/graph_executor.cc | 13 +++++++++----
src/imperative/imperative.cc | 7 ++++++-
2 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 7d31a31..8c0ac52 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1346,11 +1346,12 @@ void GraphExecutor::InitOpSegs() {
if (monitor_callback_) return;
// Generate segments based on the graph structure
- bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+ const bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+
// Whether to perform bulk exec for training
- bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1);
+ const bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1) != 0;
- bool is_training = num_forward_nodes_ != total_num_nodes;
+ const bool is_training = num_forward_nodes_ != total_num_nodes;
if (prefer_bulk_exec && is_training) {
this->BulkTrainingOpSegs(total_num_nodes);
@@ -1365,7 +1366,11 @@ void GraphExecutor::InitOpSegs() {
void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
// The maximum number of node in a segment executed in bulk
- size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
+
+ const size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15U);
+
+// const size_t num_nodes_threshold = profiler::Profiler::Get()->AggregateEnabled() ? 1 :
+// dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15U);
// create forward segments for training
size_t topo_start = 0;
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fbbaf82..a915255 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -18,6 +18,7 @@
*/
#include <unordered_set>
#include <iostream>
+#include "../profiler/profiler.h"
#include "./imperative_utils.h"
namespace mxnet {
@@ -579,7 +580,11 @@ std::vector<NDArray*> Imperative::Backward(
bool prev_recording = set_is_recording(create_graph);
bool prev_training = set_is_training(is_train);
- int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
+
+// const int prev_bulk_size = Engine::Get()->set_bulk_size(
+// profiler::Profiler::Get()->AggregateEnabled() ? 1 : backward_bulk_size_);
+
+ const int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
--
To stop receiving notification emails like this one, please contact
cjolivier01@apache.org.