You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/09/08 16:56:54 UTC

[GitHub] [incubator-mxnet] leezu commented on a change in pull request #18704: Re-enable the test_gpu_memory_profiler_gluon test case

leezu commented on a change in pull request #18704:
URL: https://github.com/apache/incubator-mxnet/pull/18704#discussion_r485059844



##########
File path: tests/python/gpu/test_profiler_gpu.py
##########
@@ -117,31 +129,64 @@ def test_gpu_memory_profiler_gluon():
     profiler.set_state('stop')
     profiler.dump(True)
 
+    # Sample gpu_memory_profile.csv:
+    # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
+    # <unk>:in_arg:data,640,0,4096,0
+    # hybridsequential:activation0:hybridsequential_activation0_fwd,2048,0,4096,0
+    # hybridsequential:activation0:hybridsequential_activation0_fwd_backward,8192,0,8192,0
+    # hybridsequential:activation0:hybridsequential_activation0_fwd_head_grad,2048,0,4096,0
+    # hybridsequential:dense0:activation0:hybridsequential_dense0_activation0_fwd,8192,0,8192,0
+    # hybridsequential:dense0:arg_grad:bias,512,0,4096,0
+    # hybridsequential:dense0:arg_grad:weight,5120,0,8192,0
+    # hybridsequential:dense0:hybridsequential_dense0_fwd,8192,0,8192,0
+    # hybridsequential:dense0:in_arg:bias,512,0,4096,0
+    # hybridsequential:dense0:in_arg:weight,5120,0,8192,0
+    # hybridsequential:dense1:activation0:hybridsequential_dense1_activation0_fwd,4096,0,4096,0
+    # hybridsequential:dense1:arg_grad:bias,256,0,4096,0
+    # hybridsequential:dense1:arg_grad:weight,32768,0,32768,0
+    # hybridsequential:dense1:hybridsequential_dense1_fwd,4096,0,4096,0
+    # hybridsequential:dense1:in_arg:bias,256,0,4096,0
+    # hybridsequential:dense1:in_arg:weight,32768,0,32768,0
+    # hybridsequential:dense2:arg_grad:bias,128,0,4096,0
+    # hybridsequential:dense2:arg_grad:weight,8192,0,8192,0
+    # hybridsequential:dense2:hybridsequential_dense2_fwd_backward,4096,0,4096,1
+    # hybridsequential:dense2:in_arg:bias,128,0,4096,0
+    # hybridsequential:dense2:in_arg:weight,8192,0,8192,0
+    # hybridsequential:dropout0:hybridsequential_dropout0_fwd,8192,0,8192,0
+    # hybridsequential:dropout0:hybridsequential_dropout0_fwd,8192,0,8192,0
+    # resource:cudnn_dropout_state (dropout-inl.h +256),1474560,0,1474560,0
+    # resource:temp_space (fully_connected-inl.h +316),15360,0,16384,0
+
     # We are only checking for weight parameters here, also making sure that
     # there is no unknown entries in the memory profile.
     with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as csv_file:
         csv_reader = csv.DictReader(csv_file)
+        # TODO: Remove this print statement later on.
         for row in csv_reader:
             print(",".join(list(row.values())))
-        for scope in ['in_arg', 'arg_grad']:
-            for key, nd in model.collect_params().items():
-                expected_arg_name = "%s:%s:" % (model.name, scope) + nd.name
-                expected_arg_size = str(4 * np.prod(nd.shape))
-                csv_file.seek(0)
-                entry_found = False
-                for row in csv_reader:
-                    if row['Attribute Name'] == expected_arg_name:
-                        assert row['Requested Size'] == expected_arg_size, \
-                            "requested size={} is not equal to the expected size={}" \
-                            .format(row['Requested Size'], expected_arg_size)
-                        entry_found = True
-                        break
-                assert entry_found, \
-                    "Entry for attr_name={} has not been found" \
-                    .format(expected_arg_name)
+        for param in model.collect_params().values():
+            expected_arg_name = "%sin_arg:" % param.var().attr('__profiler_scope__') + \
+                                param.name
+            expected_arg_size = str(4 * np.prod(param.shape))
+            csv_file.seek(0)
+            entry_found = False
+            for row in csv_reader:
+                if row['Attribute Name'] == expected_arg_name and \
+                   row['Requested Size'] == expected_arg_size:
+                    entry_found = True
+                    break
+            assert entry_found, \
+                    "Entry for (attr_name={}, alloc_size={}) has not been found" \
+                        .format(expected_arg_name,
+                                expected_arg_size)
         # Make sure that there is no unknown allocation entry.
         csv_file.seek(0)
         for row in csv_reader:
             if row['Attribute Name'] == "<unk>:unknown" or \
                row['Attribute Name'] == "<unk>:":
                 assert False, "Unknown allocation entry has been encountered"
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()

Review comment:
      @ArmageddonKnight please remove this debug print statement (and the associated TODO) before merging

##########
File path: tests/python/gpu/test_profiler_gpu.py
##########
@@ -62,41 +62,53 @@ def test_gpu_memory_profiler_symbolic():
             {'Attribute Name' : 'tensordot:in_arg:B',
              'Requested Size' : str(4 * b.size)},
             {'Attribute Name' : 'tensordot:dot',
-             'Requested Size' : str(4 * c.size)}]
+             'Requested Size' : str(4 * c.size)},
+            {'Attribute Name' : 'init:_random_uniform',
+             'Requested Size' : str(4 * a.size)},
+            {'Attribute Name' : 'init:_random_uniform',
+             'Requested Size' : str(4 * b.size)}]
 
     # Sample gpu_memory_profile.csv:
     # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
-    # "<unk>:_zeros","67108864","0","67108864","0"
-    # "<unk>:_zeros","67108864","0","67108864","0"
-    # "tensordot:dot","67108864","0","67108864","1"
-    # "tensordot:dot","67108864","0","67108864","1"
-    # "tensordot:in_arg:A","67108864","0","67108864","0"
-    # "tensordot:in_arg:B","67108864","0","67108864","0"
-    # "nvml_amend","1074790400","0","1074790400","0"
+    # init:_random_uniform,33554432,0,33554432,1
+    # init:_random_uniform,8388608,0,8388608,1
+    # resource:temp_space (sample_op.h +365),8,0,4096,0
+    # symbol:arg_grad:unknown,8388608,0,8388608,0
+    # symbol:arg_grad:unknown,33554432,0,33554432,0
+    # tensordot:dot,16777216,0,16777216,0
+    # tensordot:dot_backward,33554432,0,33554432,0
+    # tensordot:dot_backward,8388608,0,8388608,0
+    # tensordot:dot_head_grad,16777216,0,16777216,0
+    # tensordot:in_arg:A,8388608,0,8388608,0
+    # tensordot:in_arg:B,33554432,0,33554432,0
 
     with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as csv_file:
         csv_reader = csv.DictReader(csv_file)
+        # TODO: Remove this print statement later on.
+        for row in csv_reader:
+            print(",".join(list(row.values())))

Review comment:
      Please address this TODO: remove the debug print statement rather than leaving it in the test.

##########
File path: src/imperative/imperative.cc
##########
@@ -233,7 +233,7 @@ void Imperative::RecordOp(
 
   nnvm::ObjectPtr node = nnvm::Node::Create();
   node->attrs = std::move(attrs);
-  node->attrs.name = "node_" + std::to_string(node_count_++);
+  node_count_ += 1;

Review comment:
       Why do you disable naming the ops in autograd recording?

##########
File path: src/profiler/storage_profiler.cc
##########
@@ -61,11 +62,17 @@ void GpuDeviceStorageProfiler::DumpProfile() const {
   std::multimap<std::string, AllocEntryDumpFmt> gpu_mem_ordered_alloc_entries;
   // map the GPU device ID to the total amount of allocations
   std::unordered_map<int, size_t> gpu_dev_id_total_alloc_map;
+  std::regex gluon_param_regex("([0-9a-fA-F]{8})_([0-9a-fA-F]{4})_"
+                               "([0-9a-fA-F]{4})_([0-9a-fA-F]{4})_"
+                               "([0-9a-fA-F]{12})_([^ ]*)");
+
   for (const std::pair<void *const, AllocEntry>& alloc_entry :
        gpu_mem_alloc_entries_) {
+    std::string alloc_entry_name
+        = std::regex_replace(alloc_entry.second.name, gluon_param_regex, "$6");

Review comment:
      You are relying on `self._var_name = self._uuid.replace('-', '_') + '_' + self._name` from parameter.py here. This isn't true in the case of SymbolBlock import, where `_var_name` is set manually. Have you tested the profiler feature with SymbolBlock? The current assumption may be fragile.

##########
File path: src/imperative/imperative.cc
##########
@@ -322,7 +322,7 @@ void Imperative::RecordDeferredCompute(nnvm::NodeAttrs &&attrs,
   }
   node->attrs = std::move(attrs);
   // Need to support NameManager in imperative API to better name node->attrs.name
-  node->attrs.name = "node_" + std::to_string(node_count_++);
+  node_count_ += 1;

Review comment:
      Why do you disable naming the ops in deferred compute (dc) recording?

##########
File path: python/mxnet/gluon/parameter.py
##########
@@ -644,7 +644,7 @@ def var(self):
         """Returns a symbol representing this parameter."""
         if self._var is None:
             if self._var_name is None:  # _var_name is set manually in SymbolBlock.import
-                self._var_name = self._uuid
+                self._var_name = self._uuid.replace('-', '_') + '_' + self._name

Review comment:
       You need to document that `src/profiler/storage_profiler.cc` depends on this scheme.

##########
File path: tests/python/gpu/test_profiler_gpu.py
##########
@@ -15,44 +15,44 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import csv
 import os
 import sys
 
+import numpy as np
 import mxnet as mx
 mx.test_utils.set_default_context(mx.gpu(0))
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
 sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-# We import all tests from ../unittest/test_profiler.py
-# They will be detected by test framework, as long as the current file has a different filename
-from test_profiler import *

Review comment:
       Why do you disable the `./unittest/test_profiler.py` tests on GPU?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org