Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/03/07 05:59:12 UTC

[incubator-mxnet] branch master updated: Bulked op segments to allow Variable nodes (#14200)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8beea18  Bulked op segments to allow Variable nodes (#14200)
8beea18 is described below

commit 8beea18e3d9835f90b59d3f9de8f9945ac819423
Author: Dick Carter <di...@comcast.net>
AuthorDate: Wed Mar 6 21:58:52 2019 -0800

    Bulked op segments to allow Variable nodes (#14200)
    
    * Bulked op seg size to ignore Variable nodes, limited by MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_{FWD,BWD}.
    
    * Document new env variables. Unify operation with imperative.
    
    * Add timing-based tests of symbol and gluon op bulking.
    
    * Rename test_in_separate_process -> run_in_spawned_process.
    
    * Remove redundant util test_operator_gpu.py:_test_in_separate_process().
    
    * Consolidate references to env vars that set op-bulking policy.
    
    * Test for effect of MXNET_EXEC_BULK_EXEC_TRAIN=0.
    
    * Fix python2 print() issue.
    
    * Trigger CI.
    
    * Consolidate similar op bulking routines.
    
    * Trigger CI.
    
    * Trigger CI.
    
    * Add instrumentation to debug failing CI.
---
 docs/faq/env_var.md                   |   8 ++-
 include/mxnet/imperative.h            |  23 +++++++-
 src/executor/graph_executor.cc        | 106 ++++++++--------------------------
 src/executor/graph_executor.h         |   6 +-
 src/imperative/cached_op.cc           |  11 +++-
 src/imperative/cached_op.h            |   4 +-
 tests/python/gpu/test_gluon_gpu.py    |  78 +++++++++++++++++++++++++
 tests/python/gpu/test_operator_gpu.py |  99 +++++++++++++++++++++++++------
 tests/python/unittest/common.py       |  50 ++++++++++++++++
 9 files changed, 274 insertions(+), 111 deletions(-)
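
In short: forward- and backward-pass op bulking during training can now be capped independently, and Variable nodes no longer count against the caps. A minimal sketch of exercising the new variables (the variable names come from this commit; the surrounding script is hypothetical, and the values must be set before MXNet first reads them):

    import os

    # Cap bulked segments at 15 ops in the forward pass and 8 in the
    # backward pass (values here are illustrative).
    os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = '15'
    os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = '8'

    import mxnet as mx
    # ... build and train as usual; Variable nodes (e.g. learned weights)
    # are ignored when counting ops toward these segment-size caps.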

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index f49cb19..095c214 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -115,7 +115,13 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
   - If set to `1`, during training MXNet executes the computation graph as several subgraphs in bulk mode.
 * MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN
   - Values: Int ```(default=15)```
-  - The maximum number of nodes in the subgraph executed in bulk during training(not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training.
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference). Setting this to a larger number may reduce the degree of parallelism for multi-GPU training.
+* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD
+  - Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the forward pass.
+* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD
+  - Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
+  - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass.
 
 ## Control the Data Communication
 
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 7ea60df..52cedb2 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -129,14 +129,31 @@ class Imperative {
                                  bool create_graph);
   /*! \return AutogradRuntime singleton */
   static Imperative* Get();
+  /*! \brief Should op execution bulking be employed during inference. */
+  static bool PreferBulkExecInference() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+  }
+  /*! \brief Should op execution bulking be employed during training. */
+  static bool PreferBulkExecTrain() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", true);
+  }
+  /*! \brief The max number of op nodes in a bulk during forward pass of training. */
+  static int BulkExecMaxNodeTrainFwd() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD",
+                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
+  }
+  /*! \brief The max number of op nodes in a bulk during backward pass of training. */
+  static int BulkExecMaxNodeTrainBwd() {
+    return dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD",
+                        dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15));
+  }
 
  private:
   friend class NDArray;
   /*! \brief make constructor protected. */
   Imperative() {
-    if (dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)) {
-      backward_bulk_size_ =  dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
-    }
+    if (PreferBulkExecTrain())
+      backward_bulk_size_ = BulkExecMaxNodeTrainBwd();
   }
   /*! \brief find the input/output ndarrays that are needed for backward */
   void GetBackwardDependency(
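
The two accessors above encode a fallback chain: the per-direction variable wins, then the legacy shared MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN, then the built-in default of 15. The same lookup rendered in Python for reference (function name is ours):

    import os

    def bulk_exec_max_node_train(direction):
        # direction is 'FWD' or 'BWD'; mirrors BulkExecMaxNodeTrainFwd/Bwd:
        # per-direction env var, else the legacy shared var, else 15.
        legacy = int(os.environ.get('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN', 15))
        return int(os.environ.get(
            'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_' + direction, legacy))
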
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 436eae3..3d74bfb 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1191,105 +1191,49 @@ void GraphExecutor::InitOpSegs() {
   cached_seg_opr_.resize(total_num_nodes, p);
   if (monitor_callback_) return;
 
+  // Symbolic bulking is set by the same environment variables as Imperative bulking.
   // Generate segments based on the graph structure
-  bool prefer_bulk_exec_inference = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_INFERENCE", true);
+  bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference();
   // Whether to perform bulk exec for training
   const profiler::Profiler *prof = profiler::Profiler::Get();
-  bool prefer_bulk_exec = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_TRAIN", 1)
-                          && (!prof || !prof->AggregateEnabled());
+  bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
+                                && (!prof || !prof->AggregateEnabled());
 
   bool is_training = num_forward_nodes_ != total_num_nodes;
 
-  if (prefer_bulk_exec  && is_training) {
-    this->BulkTrainingOpSegs(total_num_nodes);
+  if (prefer_bulk_exec_train && is_training) {
+    // Bulk the forward portion of the graph per the bulk segment max size for forward training
+    this->BulkOpSegs(0, num_forward_nodes_, Imperative::BulkExecMaxNodeTrainFwd());
+    // Bulk the backward portion of the graph per the bulk segment max size for backward training
+    this->BulkOpSegs(num_forward_nodes_, total_num_nodes, Imperative::BulkExecMaxNodeTrainBwd());
   }
 
   if (prefer_bulk_exec_inference && !is_training) {
-    this->BulkInferenceOpSegs();
+    // Bulk the entire graph as one bulk segment if possible
+    this->BulkOpSegs(0, total_num_nodes, total_num_nodes);
   }
 }
 
 
-void GraphExecutor::BulkTrainingOpSegs(size_t total_num_nodes) {
-  // The maximum number of node in a segment executed in bulk
-  size_t num_nodes_threshold = dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15);
-
-  // create forward segments for training
-  size_t topo_start = 0;
-  for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
-    auto &node = graph_.indexed_graph()[nid].source;
-    auto &op_node = op_nodes_[nid];
-    // check if the segment relies on external input, or exceeds maxinum number of node,
-    // or requires async ops
-    if (node->is_variable() || nid - topo_start > num_nodes_threshold ||
-        op_node.exec->exec_type() != ExecType::kSync) {
-      // create a new segment for the previous nodes if the current one cannot be bulked
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-      topo_start = nid + 1;
-    }
-  }
-  // the last segment
-  if (topo_start != num_forward_nodes_) {
-    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
-  }
-
-  // create backward segments for training
-  // get all gradient variables
-  std::unordered_set<engine::VarHandle> grad_vars;
-  for (auto &kv : grad_store_) {
-    grad_vars.insert(kv.second.var());
-  }
-  auto &idx = graph_.indexed_graph();
-  topo_start = num_forward_nodes_;
-  for (size_t nid = num_forward_nodes_; nid < total_num_nodes; nid++) {
-    auto &op_node = op_nodes_[nid];
-    if (op_node.skip_exec_node || op_node.exec == nullptr) {
-      continue;
-    }
-    if (idx[nid].source->is_variable() || nid - topo_start > num_nodes_threshold ||
-        op_node.exec->exec_type() != ExecType::kSync) {
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-      topo_start = nid + 1;
-    } else {
-      // If it produces output gradient, don't include it in the segment
-      bool output_gradient = false;
-      for (auto &out_arr : op_node.exec->out_array) {
-        if (grad_vars.find(out_arr.var()) != grad_vars.end()) {
-          output_gradient = true;
-        }
-      }
-      if (output_gradient) {
-        cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
-        topo_start = nid + 1;
-      }
-    }
-  }
-  // last segment for backward
-  if (topo_start < total_num_nodes) {
-    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, total_num_nodes);
-  }
-}
-
-void GraphExecutor::BulkInferenceOpSegs() {
-  // Attempt to bulk the whole graph for inference.  We will only create new segments when
-  // required for non-kSync operations.
-  size_t topo_start = 0;
-  for (size_t nid = 0; nid < num_forward_nodes_; nid++) {
+void GraphExecutor::BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max) {
+  size_t topo_start = from_node;
+  size_t segment_node_count = 0;
+  for (size_t nid = from_node; nid < up_to_node; nid++) {
     auto &node = graph_.indexed_graph()[nid].source;
     auto &op_node = op_nodes_[nid];
-
-    // Variables do not need to be segmented at inference time.
-    if (node->is_variable()) continue;
-
-    if (op_node.exec->exec_type() != ExecType::kSync) {
-      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, nid);
+    // Variables, such as learned weights, are ignored in the segment_node_count
+    bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr;
+    if (!ignore_node)
+      segment_node_count++;
+    bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync;
+    // check if we need to create the segment based on properties of this node
+    if (!can_bulk || nid == up_to_node - 1 || segment_node_count >= segment_num_nodes_max) {
+      // Create a new segment for the previous nodes; also include this node if it's bulkable
+      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid);
       topo_start = nid + 1;
+      segment_node_count = 0;
     }
   }
-  // The last segment
-  if (topo_start != num_forward_nodes_) {
-    cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, num_forward_nodes_);
-  }
 }
 
 void GraphExecutor::ExecuteMonInputCallback(size_t nid) {
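
BulkOpSegs above folds the former training and inference routines into one rule: Variable and skipped nodes are invisible to the size counter, a non-kSync op forces a segment cut in front of itself, and a segment also closes at the size cap or at the end of the region. A standalone Python sketch of the cut-point logic over a simplified node model (all names here are ours, not MXNet API):

    def bulk_op_segments(nodes, max_nodes):
        # nodes: objects with .is_variable, .skip_exec and .is_sync flags.
        # Returns the [start, end) ranges that BulkOpSegs would hand to
        # CreateCachedSegOpr.
        segments = []
        topo_start, count = 0, 0
        for nid, node in enumerate(nodes):
            ignored = node.is_variable or node.skip_exec
            if not ignored:
                count += 1                      # only real ops count toward the cap
            can_bulk = ignored or node.is_sync  # async ops cannot join a segment
            if not can_bulk or nid == len(nodes) - 1 or count >= max_nodes:
                # Close the segment; include this node only if it is bulkable.
                segments.append((topo_start, nid + 1 if can_bulk else nid))
                topo_start, count = nid + 1, 0
        return segments
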
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index ed49e5b..b556a2b 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -213,10 +213,8 @@ class GraphExecutor : public Executor {
   void ExecuteMonInputCallback(size_t nid);
   // run the monitor callback for output of node `nid`
   void ExecuteMonOutputCallback(size_t nid);
-  // peform bulking and segmentation on an inference graph
-  void BulkInferenceOpSegs();
-  // perform bulking and segmentation on a training graph
-  void BulkTrainingOpSegs(size_t total_num_nodes);
+  // perform bulking and segmentation on the region [from_node, up_to_node) of a graph
+  void BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max);
   // indicate whether there is a backward graph for gradients.
   bool need_grad_;
   // internal graph
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 61dfb9c..c9215c5 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -619,9 +619,18 @@ void CachedOp::StaticInitExec(
       SetupOpExec(g, i, state.execs[i], state.arrays, state.array_reqs);
     }
 
+    // Init bulk_size for Inference mode with bulking enabled (= entire forward graph).
     size_t bulk_size = idx.num_nodes();
     if (recording || keep_fwd) {
-      bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
+      // Training mode
+      if (!Imperative::PreferBulkExecTrain())
+        bulk_size = 0;
+      else
+        bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
+    } else {
+      // Inference mode
+      if (!Imperative::PreferBulkExecInference())
+        bulk_size = 0;
     }
 
     CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
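
The bulk_size selection above is a small decision table: training mode (recording or keep_fwd) uses the configured per-direction sizes, inference mode bulks the whole graph, and either collapses to 0 when the corresponding enable flag is off. Condensed into Python for clarity (names ours):

    def choose_bulk_size(num_nodes, recording, keep_fwd,
                         prefer_train, prefer_inference,
                         fwd_size, bwd_size):
        if recording or keep_fwd:  # training mode
            if not prefer_train:
                return 0
            return bwd_size if keep_fwd else fwd_size
        # inference mode: bulk the entire graph, or nothing if disabled
        return num_nodes if prefer_inference else 0
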
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 5a0351a..b3192dc 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -53,10 +53,10 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
     .set_default(2)
     .describe("Maximum number of operators that can be inlined.");
     DMLC_DECLARE_FIELD(forward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .set_default(Imperative::BulkExecMaxNodeTrainFwd())
     .describe("Segment size of bulk execution during forward pass.");
     DMLC_DECLARE_FIELD(backward_bulk_size)
-    .set_default(dmlc::GetEnv("MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN", 15))
+    .set_default(Imperative::BulkExecMaxNodeTrainBwd())
     .describe("Segment size of bulk execution during backward pass.");
     DMLC_DECLARE_FIELD(data_indices)
     .set_default(nnvm::Tuple<uint32_t>())
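
Because forward_bulk_size and backward_bulk_size are ordinary CachedOpConfig parameters (the change above only rewires their defaults to the new accessors), they should remain overridable per block from Gluon. A hedged sketch, assuming hybridize() forwards extra keyword arguments to the CachedOp flags as it does in MXNet of this era:

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    # Hypothetical per-block override of the bulking caps; without these
    # flags the env-var-derived defaults above apply.
    net.hybridize(static_alloc=True, static_shape=True,
                  forward_bulk_size=8, backward_bulk_size=8)
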
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 54bfcee..88b436a 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -38,6 +38,7 @@ from mxnet.test_utils import rand_ndarray
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
+from common import run_in_spawned_process
 from test_gluon import *
 from test_loss import *
 from test_gluon_rnn import *
@@ -408,6 +409,83 @@ def test_large_models():
         # Evaluate model
         net(data_in).asnumpy()
 
+# isolated execution bulking test function to be invoked with different env var settings
+def _test_bulking_in_process(seed, time_per_iteration):
+    # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused.
+    class Flip(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(Flip, self).__init__(**kwargs)
+
+        def hybrid_forward(self, F, x):
+            return F.flip(x, axis=0)
+
+    def get_net(num_ops):
+        net = nn.HybridSequential()
+        with net.name_scope():
+            for _ in range(num_ops):
+                net.add(Flip())
+        return net
+
+    data_shape = (10,)
+    num_ops = 1000
+    num_iterations = 20
+
+    # build model
+    x = mx.ndarray.zeros(data_shape)
+    x.attach_grad()
+    dy = mx.ndarray.ones(data_shape)
+    net = get_net(num_ops)
+    net.hybridize(static_alloc=True, static_shape=True)
+
+    # time a number of forward() and backward() executions after some warm-up iterations
+    warmups = 1
+    for i in range(num_iterations+warmups):
+        with autograd.record():
+            if i == warmups:
+                start = time.time()
+            y = net(x)
+            y.backward(dy)
+            x.grad.wait_to_read()
+
+    time_per_iteration.value = (time.time() - start) / num_iterations
+
+@with_seed()
+def test_bulking():
+    # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training)
+    test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)]
+    times = {}
+    times_str = ''
+    for seg_sizes in test_cases:
+        # Create shared variable to return measured time from test process
+        time_per_iteration = mp.Manager().Value('d', 0.0)
+        if not run_in_spawned_process(_test_bulking_in_process,
+                                  {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0],
+                                   'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1],
+                                   'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]},
+                                  time_per_iteration):
+            # skip test since the python version can't run it properly.  Warning msg was logged.
+            return
+        times[seg_sizes] = time_per_iteration.value
+        times_str += \
+            '\n    runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format(
+                seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes])
+
+    fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)])
+    slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)])
+    fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)])
+    fully_bulked_time = times[(15,15,True)]
+
+    print(times_str)
+    # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same,
+    # slower than both half-bulked times[0,15,True] and times[15,0,True]
+    assert slowest_half_bulked_time < fastest_non_bulked_time, \
+        'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \
+            .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str)
+    # The fully bulked times[15,15,True] should be faster than both half-bulked runs
+    assert fully_bulked_time < fastest_half_bulked_time, \
+        'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \
+            .format(fully_bulked_time - fastest_half_bulked_time, times_str)
+
 
 if __name__ == '__main__':
     import nose
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f329916..7d7c2ed 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -33,6 +33,7 @@ from numpy.testing import assert_allclose
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
 from common import setup_module, with_seed, teardown, assert_raises_cudnn_not_satisfied
+from common import run_in_spawned_process
 from test_operator import *
 from test_optimizer import *
 from test_random import *
@@ -521,24 +522,6 @@ def test_convolution_options():
     check_consistency_NxM([sym, sym_no_cudnn], ctx_list)
 
 
-# Helper function to run tests in a subprocess to avoid save/restore of os.environ.
-# Also avoids issues of cached environment variable lookups in the backend.
-def _test_in_separate_process(func, env, *args):
-    try:
-        mpctx = mp.get_context('spawn')
-    except:
-        print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
-              sys.version_info[0:2], file=sys.stderr, end='')
-    else:
-        seed = np.random.randint(0,1024*1024*1024)
-        for (key, value) in env.items():
-            os.environ[key] = str(value)
-        # Prepend seed as first arg
-        p = mpctx.Process(target=func, args=(seed,)+args)
-        p.start()
-        p.join()
-        assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__)
-
 def _conv_with_num_streams(seed):
     with random_seed(seed):
         # Try to expose timing-dependent improper workspace sharing by parallel dgrad and wgrad
@@ -576,8 +559,10 @@ def test_convolution_multiple_streams():
 
     for num_streams in [1, 2]:
         for engine in engines:
-            _test_in_separate_process(_conv_with_num_streams,
+            print("Starting engine %s with %d streams." % (engine, num_streams), file=sys.stderr)
+            run_in_spawned_process(_conv_with_num_streams,
                 {'MXNET_GPU_WORKER_NSTREAMS' : num_streams, 'MXNET_ENGINE_TYPE' : engine})
+            print("Finished engine %s with %d streams." % (engine, num_streams), file=sys.stderr)
 
 
 # This test is designed to expose an issue with cudnn v7.1.4 algo find() when invoked with large c.
@@ -2127,6 +2112,82 @@ def test_bilinear_sampler_versions():
                     assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5)
 
 
+@with_seed()
+def test_bulking():
+    # Return the execution time of a model with the specified limits to the bulked op segments
+    def test_bulking_helper(data_shape, num_ops, num_iterations,
+                            max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training):
+        orig_environ = os.environ.copy()
+        try:
+            # Explore different ways of setting the env vars.
+            # The framework does not cache the bulked seg size env var lookups during symbolic.
+            os.environ['MXNET_EXEC_BULK_EXEC_TRAIN'] = str(enable_bulking_in_training)
+            if max_fwd_segment_size == max_bwd_segment_size:
+                os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN'] = str(max_fwd_segment_size)
+                os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD', None)
+                os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD', None)
+            else:
+                os.environ.pop('MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN', None)
+                os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD'] = str(max_fwd_segment_size)
+                os.environ['MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD'] = str(max_bwd_segment_size)
+
+            ctx = default_context()
+            # build symbol
+            X = mx.sym.Variable('X')
+            sym = mx.sym.flip(X, axis=0)
+            for _ in range(num_ops-1):
+                sym = mx.sym.flip(sym, axis=0)
+            x = mx.ndarray.zeros(data_shape)
+            dx = mx.ndarray.zeros(data_shape)
+            dy = mx.ndarray.ones(data_shape)
+            exe = sym.bind(ctx=ctx, args=[x], args_grad = {'X':dx})
+
+            # time a number of forward() and backward() executions after some warm-up iterations
+            warmups = 1
+            for i in range(num_iterations+warmups):
+                if i == warmups:
+                    start = time.time()
+                exe.forward(is_train=True)
+                exe.backward(dy)
+                dx.wait_to_read()
+            time_per_iteration = (time.time() - start) / num_iterations
+        finally:
+            os.environ.clear()
+            os.environ.update(orig_environ)
+        return time_per_iteration
+
+    data_shape = (10,)
+    num_ops = 1000
+    num_iterations = 20
+
+    # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training)
+    test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)]
+    times = {}
+    times_str = ''
+    for seg_sizes in test_cases:
+        times[seg_sizes] = test_bulking_helper(data_shape, num_ops, num_iterations,
+                                               seg_sizes[0], seg_sizes[1], seg_sizes[2])
+        times_str +=\
+            '\n    runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format(
+            seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes])
+
+    fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)])
+    slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)])
+    fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)])
+    fully_bulked_time = times[(15,15,True)]
+
+    print(times_str)
+    # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same,
+    # slower than both half-bulked times[0,15,True] and times[15,0,True]
+    assert slowest_half_bulked_time < fastest_non_bulked_time,\
+        'A half-bulked exec time is slower than the non-bulked time by {} secs! {}'\
+            .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str)
+    # The fully bulked times[15,15,True] should be faster than both half-bulked runs
+    assert fully_bulked_time < fastest_half_bulked_time,\
+        'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}'\
+            .format(fully_bulked_time - fastest_half_bulked_time, times_str)
+
+
 def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
     assert mx.context.num_gpus() > 0
diff --git a/tests/python/unittest/common.py b/tests/python/unittest/common.py
index abfba73..7cd637d 100644
--- a/tests/python/unittest/common.py
+++ b/tests/python/unittest/common.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import print_function
 import sys, os, logging
+import multiprocessing as mp
 import mxnet as mx
 import numpy as np
 import random
@@ -39,6 +41,7 @@ def assertRaises(expected_exception, func, *args, **kwargs):
         # Did not raise exception
         assert False, "%s did not raise %s" % (func.__name__, expected_exception.__name__)
 
+
 def default_logger():
     """A logger used to output seed information to nosetests logs."""
     logger = logging.getLogger(__name__)
@@ -51,6 +54,7 @@ def default_logger():
             logger.setLevel(logging.INFO)
     return logger
 
+
 @contextmanager
 def random_seed(seed=None):
     """
@@ -181,6 +185,7 @@ def with_seed(seed=None):
         return test_new
     return test_helper
 
+
 def setup_module():
     """
     A function with a 'magic name' executed automatically before each nosetests module
@@ -265,3 +270,48 @@ def teardown():
     It waits for all operations in one file to finish before carrying on the next.
     """
     mx.nd.waitall()
+
+
+def run_in_spawned_process(func, env, *args):
+    """
+    Helper function to run a test in its own process.
+
+    Avoids issues with Singleton- or otherwise-cached environment variable lookups in the backend.
+    Adds a seed as first arg to propagate determinism.
+
+    Parameters
+    ----------
+
+    func : function to run in a spawned process.
+    env : dict of additional environment values to set temporarily in the environment before exec.
+    args : args to pass to the function.
+
+    Returns
+    -------
+    Whether the python version supports running the function as a spawned process.
+
+    This routine calculates a random seed and passes it into the test as a first argument.  If the
+    test uses random values, it should include an outer 'with random_seed(seed):'.  If the
+    test needs to return values to the caller, consider use of shared variable arguments.
+    """
+    try:
+        mpctx = mp.get_context('spawn')
+    except:
+        print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
+              sys.version_info[0:2], file=sys.stderr, end='')
+        return False
+    else:
+        seed = np.random.randint(0,1024*1024*1024)
+        orig_environ = os.environ.copy()
+        try:
+            for (key, value) in env.items():
+                os.environ[key] = str(value)
+            # Prepend seed as first arg
+            p = mpctx.Process(target=func, args=(seed,)+args)
+            p.start()
+            p.join()
+            assert p.exitcode == 0, "Non-zero exit code %d from %s()." % (p.exitcode, func.__name__)
+        finally:
+            os.environ.clear()
+            os.environ.update(orig_environ)
+    return True
\ No newline at end of file