You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2020/03/10 03:37:29 UTC

[incubator-mxnet] branch master updated: Implement storage tagging, the first half of the memory profiler (#17656)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4dddb08  Implement storage tagging, the first half of the memory profiler (#17656)
4dddb08 is described below

commit 4dddb087293965af7e7af8d0eaead60a9051f0fc
Author: Bojian Zheng <bo...@mail.utoronto.ca>
AuthorDate: Mon Mar 9 23:36:30 2020 -0400

    Implement storage tagging, the first half of the memory profiler (#17656)
---
 CMakeLists.txt                                  |  14 +++
 Makefile                                        |   7 ++
 ci/docker/runtime_functions.sh                  |   8 +-
 cmake/Modules/FindNVML.cmake                    |  84 +++++++++++++
 include/mxnet/c_api.h                           |   7 ++
 include/mxnet/libinfo.h                         |   4 +
 include/mxnet/ndarray.h                         |  11 +-
 include/mxnet/resource.h                        |  70 ++++++++---
 include/mxnet/storage.h                         |  11 +-
 python/mxnet/gluon/block.py                     |  24 +++-
 python/mxnet/operator.py                        |   8 +-
 python/mxnet/optimizer/updater.py               |   4 +-
 python/mxnet/profiler.py                        |  52 ++++++++
 python/mxnet/symbol/register.py                 |  11 ++
 python/mxnet/symbol/symbol.py                   |  11 +-
 src/c_api/c_api.cc                              |  34 ++++--
 src/c_api/c_api_ndarray.cc                      |   5 +
 src/c_api/c_api_profile.cc                      |  19 +++
 src/c_api/c_api_symbolic.cc                     |  14 ++-
 src/common/cuda_utils.h                         |  19 +++
 src/common/utils.h                              |  20 ++-
 src/executor/graph_executor.cc                  |  36 +++++-
 src/imperative/cached_op.cc                     |  40 ++++++
 src/imperative/cached_op.h                      |   8 +-
 src/imperative/imperative_utils.h               |  23 +++-
 src/io/iter_image_recordio_2.cc                 |   7 ++
 src/kvstore/comm.h                              |  19 ++-
 src/kvstore/comm_tree.h                         |  12 +-
 src/kvstore/kvstore_local.h                     |  22 ++++
 src/ndarray/ndarray.cc                          |  27 +++-
 src/operator/linalg_impl.h                      |  47 ++++---
 src/operator/nn/cudnn/cudnn_convolution-inl.h   |   5 +-
 src/operator/nn/cudnn/cudnn_deconvolution-inl.h |   5 +-
 src/operator/rnn-inl.h                          |   4 +
 src/profiler/profiler.cc                        |  18 +++
 src/profiler/profiler.h                         |  14 ++-
 src/profiler/storage_profiler.cc                | 130 ++++++++++++++++++++
 src/profiler/storage_profiler.h                 |  95 ++++++++++++++-
 src/resource.cc                                 |  30 +++--
 src/storage/gpu_device_storage.h                |   8 +-
 src/storage/pinned_memory_storage.h             |   5 +
 src/storage/pooled_storage_manager.h            |  11 ++
 src/storage/storage.cc                          |   2 +-
 tests/python/unittest/test_profiler.py          | 156 +++++++++++++++++++++++-
 tests/python/unittest/test_symbol.py            |   3 +-
 tests/python/unittest/test_thread_local.py      |   1 +
 46 files changed, 1066 insertions(+), 99 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 037a692..5806883 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,6 +33,7 @@ Format: Auto | Common | All | LIST(ARCH_AND_PTX ...)
 option(USE_NCCL "Use NVidia NCCL with CUDA" OFF)
 option(USE_OPENCV "Build with OpenCV support" ON)
 option(USE_OPENMP "Build with Openmp support" ON)
+cmake_dependent_option(USE_NVML "Build with nvml support if found" ON "USE_CUDA" OFF)
 cmake_dependent_option(USE_CUDNN "Build with cudnn support" ON "USE_CUDA" OFF) # one could set CUDNN_ROOT for search path
 cmake_dependent_option(USE_NVTX "Build with nvtx support if found" ON "USE_CUDA" OFF)
 cmake_dependent_option(USE_SSE "Build with x86 SSE instruction support" ON "NOT ARM" OFF)
@@ -630,6 +631,19 @@ if(USE_CUDA)
   list(APPEND SOURCE ${CUDA})
   add_definitions(-DMXNET_USE_CUDA=1)
 
+  if(UNIX)
+    if(USE_NVML)
+      find_package(NVML)
+      if(NVML_FOUND)
+        include_directories(${NVML_INCLUDE_DIRS})
+        list(APPEND mxnet_LINKER_LIBS ${NVML_LIBRARIES})
+        add_definitions(-DMXNET_USE_NVML=1)
+      else()
+        add_definitions(-DMXNET_USE_NVML=0)
+        message(WARNING "Could not find NVML libraries")
+      endif()
+    endif()
+  endif()
   if(USE_NCCL)
     find_package(NCCL)
     if(NCCL_FOUND)
diff --git a/Makefile b/Makefile
index 90303ae..49ba8fe 100644
--- a/Makefile
+++ b/Makefile
@@ -526,6 +526,12 @@ ifeq ($(USE_CUDA), 1)
 	# Make sure to add stubs as fallback in order to be able to build
 	# without full CUDA install (especially if run without nvidia-docker)
 	LDFLAGS += -L/usr/local/cuda/lib64/stubs
+	ifeq ($(USE_NVML), 1)
+		LDFLAGS += -lnvidia-ml
+		CFLAGS += -DMXNET_USE_NVML=1
+	else
+		CFLAGS += -DMXNET_USE_NVML=0
+	endif
 	ifeq ($(USE_NCCL), 1)
 		ifneq ($(USE_NCCL_PATH), NONE)
 			CFLAGS += -I$(USE_NCCL_PATH)/include
@@ -537,6 +543,7 @@ ifeq ($(USE_CUDA), 1)
 		CFLAGS += -DMXNET_USE_NCCL=0
 	endif
 else
+	CFLAGS += -DMXNET_USE_NVML=0
 	CFLAGS += -DMXNET_USE_NCCL=0
 endif
 
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index c5ecdd7..ffd871a 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -701,7 +701,7 @@ build_ubuntu_gpu_mkldnn() {
     set -ex
     cd /work/build
     cmake \
-        -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+        -DCMAKE_BUILD_TYPE=Release \
         -DUSE_MKL_IF_AVAILABLE=OFF \
         -DUSE_TVM_OP=ON \
         -DUSE_CUDA=ON \
@@ -715,7 +715,7 @@ build_ubuntu_gpu_mkldnn_nocudnn() {
     set -ex
     cd /work/build
     cmake \
-        -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+        -DCMAKE_BUILD_TYPE=Release \
         -DUSE_MKL_IF_AVAILABLE=OFF \
         -DUSE_TVM_OP=ON \
         -DUSE_CUDA=ON \
@@ -730,7 +730,7 @@ build_ubuntu_gpu_cuda101_cudnn7() {
     set -ex
     cd /work/build
     cmake \
-        -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+        -DCMAKE_BUILD_TYPE=Release \
         -DUSE_MKL_IF_AVAILABLE=OFF \
         -DUSE_TVM_OP=ON \
         -DUSE_CUDA=ON \
@@ -786,7 +786,7 @@ build_ubuntu_gpu_cuda101_cudnn7_no_tvm_op() {
     set -ex
     cd /work/build
     cmake \
-        -DCMAKE_BUILD_TYPE="RelWithDebInfo" \
+        -DCMAKE_BUILD_TYPE=Release \
         -DUSE_MKL_IF_AVAILABLE=OFF \
         -DUSE_TVM_OP=OFF \
         -DUSE_CUDA=ON \
diff --git a/cmake/Modules/FindNVML.cmake b/cmake/Modules/FindNVML.cmake
new file mode 100644
index 0000000..e098132
--- /dev/null
+++ b/cmake/Modules/FindNVML.cmake
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Find the NVML (NVIDIA Management Library) header and library.
+#
+# The following variables are optionally searched for defaults
+#  NVML_ROOT_DIR: Base directory where all NVML components are found
+#  NVML_INCLUDE_DIR: Directory where NVML header is found
+#  NVML_LIB_DIR: Directory where NVML library is found
+#  ENV{NVML_DIR}: Base directory searched under include/ and lib/
+#
+# ENV{NVML_ROOT_DIR} is deprecated (emits a warning) and is used only when the
+# NVML_ROOT_DIR variable is unset. NVML_ROOT is a find_package() prefix when CMP0074 is NEW.
+#
+# The following are set after configuration is done:
+#  NVML_FOUND, NVML_INCLUDE_DIRS, NVML_LIBRARIES
+#
+# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks
+# install NVML in the same location as the CUDA toolkit.
+# See https://github.com/caffe2/caffe2/issues/1601
+
+set(_nvml_root_dir "${NVML_ROOT_DIR}")
+# `if($ENV{NVML_ROOT_DIR})` would test the path string itself (false for paths).
+if(DEFINED ENV{NVML_ROOT_DIR})
+  message(WARNING "NVML_ROOT_DIR is deprecated. Please set NVML_ROOT instead.")
+  if(NOT NVML_ROOT_DIR)
+    set(_nvml_root_dir "$ENV{NVML_ROOT_DIR}")
+  endif()
+endif()
+
+find_path(NVML_INCLUDE_DIRS
+  NAMES nvml.h
+  HINTS
+  ${NVML_INCLUDE_DIR}
+  ${_nvml_root_dir}
+  ${_nvml_root_dir}/include
+  ${CUDA_TOOLKIT_ROOT_DIR}/include
+  $ENV{NVML_DIR}/include
+  )
+
+find_library(NVML_LIBRARIES
+  NAMES nvidia-ml
+  HINTS
+  ${NVML_LIB_DIR}
+  ${_nvml_root_dir}
+  ${_nvml_root_dir}/lib
+  ${_nvml_root_dir}/lib/x86_64-linux-gnu
+  ${_nvml_root_dir}/lib64
+  ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs
+  $ENV{NVML_DIR}/lib
+  )
+
+# UNIX fallback: the default CUDA location (a no-op once a result is cached).
+if(UNIX)
+  set(_nvml_search_paths "/usr/local/cuda")
+  find_path(NVML_INCLUDE_DIRS NAMES nvml.h
+    PATHS ${_nvml_search_paths} PATH_SUFFIXES include)
+  find_library(NVML_LIBRARIES NAMES nvidia-ml
+    PATHS ${_nvml_search_paths} PATH_SUFFIXES lib64/stubs)
+endif()
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NVML DEFAULT_MSG NVML_INCLUDE_DIRS NVML_LIBRARIES)
+
+if(NVML_FOUND)
+  message(STATUS "Found NVML (include: ${NVML_INCLUDE_DIRS}, library: ${NVML_LIBRARIES})")
+endif()
+mark_as_advanced(NVML_ROOT_DIR NVML_INCLUDE_DIRS NVML_LIBRARIES)
+unset(_nvml_root_dir)
+unset(_nvml_search_paths)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bb2a568..efa0033 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -314,6 +314,13 @@ MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
 MXNET_DLL int MXSetProfilerState(int state);
 
 /*!
+ * \brief Set the scope of profiler for current process
+ * \param scope indicate the working scope of profiler
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetProfilerScope(const char* scope);
+
+/*!
  * \brief Save profile and stop profiler
  * \param finished true if stat output should stop after this point
  * \param profile_process an int,
diff --git a/include/mxnet/libinfo.h b/include/mxnet/libinfo.h
index 1972688..8b31b2d 100644
--- a/include/mxnet/libinfo.h
+++ b/include/mxnet/libinfo.h
@@ -55,6 +55,10 @@
 #define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
 #endif
 
+#ifndef MXNET_USE_NVML
+#define MXNET_USE_NVML 0
+#endif
+
 #ifndef MXNET_USE_NCCL
 #define MXNET_USE_NCCL 0
 #endif
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index c55e49e..fd7cc38 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -361,6 +361,9 @@ class NDArray {
     CheckAndAlloc();
     return ptr_->shandle;
   }
+  /*! \brief assign profiler scope and name to the storage handles */
+  void AssignStorageInfo(const std::string& profiler_scope,
+                         const std::string& name);
   /*!
    * \brief Block until all the pending write operations with respect
    *    to current NDArray are finished, and read can be performed.
@@ -989,7 +992,7 @@ class NDArray {
     /*! \brief check if delay alloc is on, do alloc if not yet done */
     inline void CheckAndAlloc(void) {
       if (delay_alloc) {
-        shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
+        Storage::Get()->Alloc(&shandle);
 #if MXNET_USE_MKLDNN == 1
         mkl_mem_ = nullptr;
 #endif
@@ -1004,7 +1007,8 @@ class NDArray {
           << "CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
       dbytes = std::max(dbytes, static_cast<uint64_t>(shandle.size));
       if (delay_alloc) {
-        shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+        shandle.size = dbytes;
+        Storage::Get()->Alloc(&shandle);
 #if MXNET_USE_MKLDNN == 1
         mkl_mem_ = nullptr;
 #endif
@@ -1013,7 +1017,8 @@ class NDArray {
         // free storage
         Storage::Get()->Free(shandle);
         // init storage
-        shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+        shandle.size = dbytes;
+        Storage::Get()->Alloc(&shandle);
 #if MXNET_USE_MKLDNN == 1
         mkl_mem_ = nullptr;
 #endif
diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h
index 0114c3b..d766faa 100644
--- a/include/mxnet/resource.h
+++ b/include/mxnet/resource.h
@@ -26,6 +26,7 @@
 #define MXNET_RESOURCE_H_
 
 #include <dmlc/logging.h>
+#include <string>
 #include "./base.h"
 #include "./engine.h"
 #include "./random_generator.h"
@@ -62,6 +63,28 @@ struct ResourceRequest {
       : type(type) {}
 };
 
+namespace {
+/// \brief Given a path, extract the filename.
+inline std::string __extract_fname(const std::string& path) {
+  std::size_t last_dir_pos = path.find_last_of("/\\");
+  if (last_dir_pos == std::string::npos) {
+    return path;
+  }
+  return path.substr(last_dir_pos + 1);
+}
+}  // anonymous namespace
+
+#if (defined(__GNUC__) || defined(__GNUG__)) && !defined(__clang__)
+#define MXNET_RESOURCE_DEFAULT_NAME_FARG(tag) \
+    std::string(tag) \
+    + " (" + __extract_fname(__builtin_FILE()) \
+    + " +" +  std::to_string(__builtin_LINE()) + ")"
+#else  // !__GNUC__ || __clang__
+#define MXNET_RESOURCE_DEFAULT_NAME_FARG(tag) \
+    std::string(tag) \
+    + " (" + __extract_fname(__FILE__) \
+    + " +" +  std::to_string(__LINE__) + ")"
+#endif  // __GNUC__ && !__clang__
 
 /*!
  * \brief Resources used by mxnet operations.
@@ -120,16 +143,18 @@ struct Resource {
    *  when running on device, so the launched kernels that depend on the temp space
    *  can finish correctly.
    *
-   * \param shape the Shape of returning tensor.
-   * \param stream the stream of retruning tensor.
+   * \param shape   the shape of returning tensor.
+   * \param stream  the stream of returning tensor.
+   * \param name    the name of the operator requesting the resource.
    * \return the mshadow tensor requested.
-   * \tparam xpu the device type of random number generator.
-   * \tparam ndim the number of dimension of the tensor requested.
+   * \tparam xpu   the device type of random number generator.
+   * \tparam ndim  the number of dimension of the tensor requested.
    */
   template<typename xpu, int ndim>
   inline mshadow::Tensor<xpu, ndim, real_t> get_space(
-      mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
-    return get_space_typed<xpu, ndim, real_t>(shape, stream);
+      mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream,
+      const std::string &name = MXNET_RESOURCE_DEFAULT_NAME_FARG("temp_space")) const {
+    return get_space_typed<xpu, ndim, real_t>(shape, stream, name);
   }
   /*!
    * \brief Get cpu space requested as mshadow Tensor.
@@ -148,33 +173,37 @@ struct Resource {
    * \brief Get space requested as mshadow Tensor in specified type.
    *  The caller can request arbitrary size.
    *
-   * \param shape the Shape of returning tensor.
-   * \param stream the stream of retruning tensor.
+   * \param shape   the shape of returning tensor.
+   * \param stream  the stream of returning tensor.
+   * \param name    the name of the operator requesting the resource.
    * \return the mshadow tensor requested.
-   * \tparam xpu the device type of random number generator.
-   * \tparam ndim the number of dimension of the tensor requested.
+   * \tparam xpu   the device type of random number generator.
+   * \tparam ndim  the number of dimension of the tensor requested.
    */
   template<typename xpu, int ndim, typename DType>
   inline mshadow::Tensor<xpu, ndim, DType> get_space_typed(
-      mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream) const {
+      mshadow::Shape<ndim> shape, mshadow::Stream<xpu> *stream,
+      const std::string &name = MXNET_RESOURCE_DEFAULT_NAME_FARG("temp_space")) const {
     CHECK_EQ(req.type, ResourceRequest::kTempSpace);
     return mshadow::Tensor<xpu, ndim, DType>(
-        reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
+        reinterpret_cast<DType*>(get_space_internal(
+          shape.Size() * sizeof(DType), name)),
         shape, shape[ndim - 1], stream);
   }
 #if MXNET_USE_CUDNN == 1
   /*!
-   * \brief Get cudnn dropout descriptor from shared state space.
+   * \brief Get cuDNN dropout descriptor from shared state space.
    *
-   * \param dropout_desc reference to previously created cudnn dropout descriptor.
-   * \param stream the stream of retruning tensor.
+   * \param dropout_desc  reference to previously created cuDNN dropout descriptor.
+   * \param stream  the stream of returning tensor.
+   * \param name    the name of the operator requesting the resource.
    * \return the mshadow tensor requested.
    */
   void get_cudnn_dropout_desc(
-      cudnnDropoutDescriptor_t* dropout_desc,
+      cudnnDropoutDescriptor_t *dropout_desc,
       mshadow::Stream<gpu> *stream,
-      const float dropout,
-      uint64_t seed) const;
+      const float dropout, uint64_t seed,
+      const std::string &name = MXNET_RESOURCE_DEFAULT_NAME_FARG("cudnn_dropout_state")) const;
 #endif  // MXNET_USE_CUDNN == 1
 
   /*!
@@ -195,10 +224,11 @@ struct Resource {
   }
   /*!
    * \brief internal function to get space from resources.
-   * \param size The size of the space.
+   * \param size the Size of the space.
+   * \param name the Name of the operator requesting the resource.
    * \return The allocated space.
    */
-  void* get_space_internal(size_t size) const;
+  void* get_space_internal(size_t size, const std::string &name) const;
   /*!
    * \brief internal function to get cpu space from resources.
    * \param size The size of space.
diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h
index 4d1fc3d..0e90d17 100644
--- a/include/mxnet/storage.h
+++ b/include/mxnet/storage.h
@@ -26,10 +26,14 @@
 #define MXNET_STORAGE_H_
 
 #include <memory>
+#include <string>
 #include "./base.h"
 
 namespace mxnet {
 
+#define MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR  "<unk>:"
+#define MXNET_STORAGE_DEFAULT_NAME_CSTR  "unknown"
+
 /*!
  * \brief Storage manager across multiple devices.
  */
@@ -55,7 +59,12 @@ class Storage {
      * \brief Id for IPC shared memory
      */
     int shared_pid{-1};
-    int shared_id{-1};
+    int shared_id {-1};
+    /*!
+     * \brief Attributes for tracking storage allocations.
+     */
+    std::string profiler_scope{MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR};
+    std::string name{MXNET_STORAGE_DEFAULT_NAME_CSTR};
   };
   /*!
    * \brief Allocate a new contiguous memory for a given size.
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index e925b31..a944880 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -32,6 +32,7 @@ from .. import symbol, ndarray, initializer, np_symbol
 from ..symbol import Symbol
 from ..ndarray import NDArray
 from .. import name as _name
+from .. import profiler as _profiler
 from .parameter import Parameter, ParameterDict, DeferredInitializationError
 from .utils import _indent, _brief_print_list, HookHandle
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays
@@ -53,18 +54,24 @@ class _BlockScope(object):
 
     @staticmethod
     def create(prefix, params, hint):
-        """Creates prefix and params for new `Block`."""
+        """
+        Creates prefix, params, and profiler scope name for new `Block`.
+        The profiler scope is to support the GPU memory profiler.
+        """
         current = getattr(_BlockScope._current, "value", None)
         if current is None:
             if prefix is None:
                 if not hasattr(_name.NameManager._current, "value"):
                     _name.NameManager._current.value = _name.NameManager()
                 prefix = _name.NameManager._current.value.get(None, hint) + '_'
+            # replace the trailing underscore with colon
+            profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \
+                                   else prefix) + ":"
             if params is None:
                 params = ParameterDict(prefix)
             else:
                 params = ParameterDict(params.prefix, params)
-            return prefix, params
+            return prefix, params, profiler_scope_name
 
         if prefix is None:
             count = current._counter.get(hint, 0)
@@ -75,7 +82,11 @@ class _BlockScope(object):
             params = ParameterDict(parent.prefix+prefix, parent._shared)
         else:
             params = ParameterDict(params.prefix, params)
-        return current._block.prefix+prefix, params
+        # replace the trailing underscore with colon
+        profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \
+                               else prefix) + ":"
+        return current._block.prefix + prefix, params, \
+               current._block._profiler_scope_name + profiler_scope_name
 
     def __enter__(self):
         if self._block._empty_prefix:
@@ -84,6 +95,8 @@ class _BlockScope(object):
         _BlockScope._current.value = self
         self._name_scope = _name.Prefix(self._block.prefix)
         self._name_scope.__enter__()
+        self._profiler_scope = _profiler.Scope(self._block._profiler_scope_name)
+        self._profiler_scope.__enter__()
         return self
 
     def __exit__(self, ptype, value, trace):
@@ -91,6 +104,8 @@ class _BlockScope(object):
             return
         self._name_scope.__exit__(ptype, value, trace)
         self._name_scope = None
+        self._profiler_scope.__exit__(ptype, value, trace)
+        self._profiler_scope = None
         _BlockScope._current.value = self._old_scope
 
 
@@ -274,7 +289,8 @@ class Block(object):
     """
     def __init__(self, prefix=None, params=None):
         self._empty_prefix = prefix == ''
-        self._prefix, self._params = _BlockScope.create(prefix, params, self._alias())
+        self._prefix, self._params, self._profiler_scope_name = \
+                _BlockScope.create(prefix, params, self._alias())
         self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
         self._scope = _BlockScope(self)
         self._children = OrderedDict()
diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py
index eff7320..04efbee 100644
--- a/python/mxnet/operator.py
+++ b/python/mxnet/operator.py
@@ -733,7 +733,13 @@ def register(reg_name):
         def creator(op_type, argc, keys, vals, ret):
             """internal function"""
             assert py_str(op_type) == reg_name
-            kwargs = dict([(py_str(keys[i]), py_str(vals[i])) for i in range(argc)])
+            kwargs = {}
+            for i in range(argc):
+                key = py_str(keys[i])
+                if key not in ['__ctx_group__', '__lr_mult__', '__wd_mult__',
+                               '__force_mirroring__',
+                               '__mirror_stage__', '__profiler_scope__']:
+                    kwargs[key] = py_str(vals[i])
             op_prop = prop_cls(**kwargs)
 
             def infer_shape_entry(num_tensor, tensor_dims,
diff --git a/python/mxnet/optimizer/updater.py b/python/mxnet/optimizer/updater.py
index 62b7004..a969614 100644
--- a/python/mxnet/optimizer/updater.py
+++ b/python/mxnet/optimizer/updater.py
@@ -21,6 +21,7 @@ import pickle
 import numpy
 from ..base import py_str
 from ..ndarray import NDArray
+from ..profiler import Scope
 from ..util import is_np_array
 from .utils import _as_classic
 
@@ -54,7 +55,8 @@ class Updater(object):
                 indices[i] = py_str(idx)
                 idx = indices[i]
             if idx not in self.states:
-                self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i])
+                with Scope("updater:optimizer_state"):
+                    self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i])
                 self.states_synced[idx] = True
             elif not self.states_synced[idx]:
                 self.states[idx] = \
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py
index 1a12f6f..3b8830d 100644
--- a/python/mxnet/profiler.py
+++ b/python/mxnet/profiler.py
@@ -20,6 +20,7 @@
 # pylint: disable=too-many-branches, too-many-statements
 """Profiler setting methods."""
 import ctypes
+import threading
 import warnings
 from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle
 
@@ -36,6 +37,8 @@ def set_config(**kwargs):
     ----------
     filename : string,
         output file for profile data
+    gpu_memory_profile_filename_prefix : string
+        filename prefix for the GPU memory profile
     profile_all : boolean,
         all profile types enabled
     profile_symbolic : boolean,
@@ -497,3 +500,52 @@ class Marker(object):
             Default is `process`.
         """
         check_call(_LIB.MXProfileSetMarker(self.domain.handle, c_str(self.name), c_str(scope)))
+
+
+class Scope(object):
+    """Context manager for the GPU memory profiler scope (entered implicitly by Gluon).
+
+    Entering passes the name to ``MXSetProfilerScope``; exiting restores the previous scope.
+
+    Parameters
+    ----------
+    name : str, default '<unk>:'. A trailing ':' is appended if missing.
+    append_mode : bool, default False. If True, prepend the current scope's name (at construction).
+    """
+    _current = threading.local()
+
+    def __init__(self, name='<unk>:', append_mode=False):
+        self._name = name + ":" if not name.endswith(":") else name
+        self._old_scope = None
+        if append_mode:
+            if not hasattr(Scope._current, "value"):
+                Scope._current.value = Scope()
+            self._name = Scope._current.value.name + self._name
+
+    def __enter__(self):
+        if not hasattr(Scope._current, "value"):
+            Scope._current.value = Scope()
+        self._old_scope = Scope._current.value
+        Scope._current.value = self
+        # Invoke the C API to propagate the profiler scope information to the
+        # C++ backend.
+        check_call(_LIB.MXSetProfilerScope(c_str(self.name)))
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        Scope._current.value = self._old_scope
+        # If the old profiler scope is also of type `profiler.Scope`, invoke the
+        # C API once again to recover the previous scope information. Otherwise,
+        # the default scope `<unk>:` will be set.
+        if isinstance(self._old_scope, Scope):
+            check_call(_LIB.MXSetProfilerScope(c_str(self._old_scope.name)))
+        else:
+            check_call(_LIB.MXSetProfilerScope(c_str("<unk>:")))
+
+    @property
+    def name(self):
+        return self._name
+
+# Default scope for the importing thread; other threads create theirs lazily (hasattr checks).
+Scope._current.value = Scope()
diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py
index 6b02e6d..e1ce4d4 100644
--- a/python/mxnet/symbol/register.py
+++ b/python/mxnet/symbol/register.py
@@ -28,6 +28,7 @@ from ..base import mx_uint, check_call, _LIB, py_str
 from ..symbol_doc import _build_doc
 from ..base import _Null, _init_op_module, _is_np_op, _output_is_list
 from ..name import NameManager
+from ..profiler import Scope
 # pylint: enable=unused-import
 
 
@@ -195,6 +196,11 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name))
             key_var_num_args, key_var_num_args))
 
             code.append("""
+    if 'profiler_scope' not in keys:
+        keys.append('profiler_scope')
+        if not hasattr(Scope._current, "value"):
+            Scope._current.value = Scope()
+        vals.append(Scope._current.value.name)
     return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s, %s)"""%(
                 handle.value, str(is_np_op), str(output_is_list)))
     else:
@@ -252,6 +258,11 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
     if not hasattr(NameManager._current, "value"):
         NameManager._current.value = NameManager()
     name = NameManager._current.value.get(name, '%s')
+    if 'profiler_scope' not in _keys:
+        _keys.append('profiler_scope')
+        if not hasattr(Scope._current, "value"):
+            Scope._current.value = Scope()
+        _vals.append(Scope._current.value.name)
     return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s, %s)"""%(
         func_name.lower(), handle.value, str(is_np_op), str(output_is_list)))
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index a4599c8..706152f 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -44,6 +44,7 @@ from . import _internal
 from . import op
 from ._internal import SymbolBase, _set_symbol_class
 from ..util import is_np_shape
+from ..profiler import Scope
 
 __all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
            "pow", "power", "maximum", "minimum", "hypot", "eye", "zeros",
@@ -2742,7 +2743,7 @@ class Symbol(SymbolBase):
         raise NotImplementedForSymbol(self.backward, None)
 
 def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
-        init=None, stype=None, **kwargs):
+        init=None, stype=None, profiler_scope=None, **kwargs):
     """Creates a symbolic variable with specified name.
 
     Example
@@ -2777,6 +2778,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
         Initializer for this variable to (optionally) override the default initializer.
     stype : str
         The storage type of the variable, such as 'row_sparse', 'csr', 'default', etc
+    profiler_scope : str
+        The profiler scope for input variable.
     kwargs : Additional attribute variables
         Additional attributes must start and end with double underscores.
 
@@ -2812,6 +2815,12 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
         attr['__init__'] = init
     if stype is not None:
         attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[stype])
+    if profiler_scope is not None:
+        attr['__profiler_scope__'] = profiler_scope
+    else:
+        if not hasattr(Scope._current, "value"):
+            Scope._current.value = Scope()
+        attr['__profiler_scope__'] = Scope._current.value.name
     for k, v in kwargs.items():
         if k.startswith('__') and k.endswith('__'):
             attr[k] = str(v)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 31b9d84..ad30350 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -56,6 +56,7 @@
 #include "../operator/subgraph/partitioner/custom_subgraph_property.h"
 #include "../operator/subgraph/subgraph_property.h"
 #include "../common/utils.h"
+#include "../profiler/profiler.h"
 #include "nnvm/pass_functions.h"
 
 using namespace mxnet;
@@ -989,9 +990,12 @@ void CreateNDArray(const DataType* shape,
               "[CreateNDArray] Size of tensor you are trying to allocate is larger than "
               "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
   }
-  *out = new NDArray(requested_shape,
-                     Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
-                     delay_alloc != 0, dtype);
+  NDArray* nd = new NDArray(requested_shape,
+                            Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+                            delay_alloc != 0, dtype);
+  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
+                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
+  *out = nd;
 }
 
 int MXNDArrayCreate(const uint32_t *shape,
@@ -1001,9 +1005,12 @@ int MXNDArrayCreate(const uint32_t *shape,
                     int delay_alloc,
                     NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(mxnet::TShape(shape, shape + ndim),
-                     Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
-                     delay_alloc != 0);
+  NDArray* nd = new NDArray(mxnet::TShape(shape, shape + ndim),
+                            Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+                            delay_alloc != 0);
+  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
+                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
+  *out = nd;
   API_END();
 }
 
@@ -1054,12 +1061,15 @@ void CreateSparseNDArray(int storage_type,
     aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]);
     shape_start += aux_ndims[i];
   }
-  *out = new NDArray(
+  NDArray* nd = new NDArray(
       NDArrayStorageType(storage_type),
       mxnet::TShape(shape, shape + ndim),
       Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
       delay_alloc != 0,
       dtype, aux_types, aux_shapes);
+  nd->AssignStorageInfo(profiler::ProfilerScope::Get()->GetCurrentProfilerScope(),
+                        MXNET_STORAGE_DEFAULT_NAME_CSTR);
+  *out = nd;
 }
 
 int MXNDArrayCreateSparseEx(int storage_type,
@@ -2462,14 +2472,20 @@ int MXNDArrayGetSharedMemHandle(NDArrayHandle handle, int* shared_pid, int* shar
 int MXNDArrayCreateFromSharedMem(int shared_pid, int shared_id, const uint32_t *shape,
                                  uint32_t ndim, int dtype, NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
+  const std::string profiler_scope = profiler::ProfilerScope::Get()->GetCurrentProfilerScope();
+  auto* nd = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
+  nd->AssignStorageInfo(profiler_scope, MXNET_STORAGE_DEFAULT_NAME_CSTR);
+  *out = nd;
   API_END();
 }
 
 int MXNDArrayCreateFromSharedMemEx(int shared_pid, int shared_id, const int *shape,
                                    int ndim, int dtype, NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
+  const std::string profiler_scope = profiler::ProfilerScope::Get()->GetCurrentProfilerScope();
+  auto* nd = new NDArray(shared_pid, shared_id, mxnet::TShape(shape, shape + ndim), dtype);
+  nd->AssignStorageInfo(profiler_scope, MXNET_STORAGE_DEFAULT_NAME_CSTR);
+  *out = nd;
   API_END();
 }
 
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index b88eea4..ef03fe6 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -38,6 +38,7 @@
 #include "../imperative/imperative_utils.h"
 #include "../imperative/cached_op.h"
 #include "../imperative/cached_op_threadsafe.h"
+#include "../profiler/profiler.h"
 
 using namespace mxnet;
 
@@ -98,6 +99,10 @@ void MXImperativeInvokeImpl(AtomicSymbolCreator creator,
 
   nnvm::NodeAttrs attrs = imperative::ParseAttrs(op, num_inputs, num_params,
                                                  param_keys, param_vals);
+  attrs.dict["__profiler_scope__"] = profiler::ProfilerScope::Get()->GetCurrentProfilerScope();
+  if (attrs.op) {
+    attrs.name = attrs.op->name;
+  }
 
   int infered_num_outputs;
   int num_visible_outputs;
diff --git a/src/c_api/c_api_profile.cc b/src/c_api/c_api_profile.cc
index 5eb219a..d0ad6a1 100644
--- a/src/c_api/c_api_profile.cc
+++ b/src/c_api/c_api_profile.cc
@@ -32,6 +32,7 @@
 #include <mxnet/kvstore.h>
 #include <stack>
 #include "./c_api_common.h"
+#include "../profiler/storage_profiler.h"
 #include "../profiler/profiler.h"
 
 namespace mxnet {
@@ -209,6 +210,7 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
   bool profile_memory;
   bool profile_api;
   std::string filename;
+  std::string gpu_memory_profile_filename_prefix;
   bool continuous_dump;
   float dump_period;
   bool aggregate_stats;
@@ -226,6 +228,10 @@ struct ProfileConfigParam : public dmlc::Parameter<ProfileConfigParam> {
       .describe("Profile C API.  Default is True.");
     DMLC_DECLARE_FIELD(filename).set_default("profile.json")
       .describe("File name to write profiling info.");
+#if MXNET_USE_CUDA
+    DMLC_DECLARE_FIELD(gpu_memory_profile_filename_prefix).set_default("gpu_memory_profile")
+      .describe("File name prefix to write GPU memory profile info.");
+#endif  // MXNET_USE_CUDA
     DMLC_DECLARE_FIELD(continuous_dump).set_default(true)
       .describe("Periodically dump (and append) profiling data to file while running. "
                 "Default is True.");
@@ -298,6 +304,10 @@ int MXSetProcessProfilerConfig(int num_params, const char* const* keys, const ch
                                            param.continuous_dump,
                                            param.dump_period,
                                            param.aggregate_stats);
+#if MXNET_USE_CUDA
+      profiler::GpuDeviceStorageProfiler::Get()->SetConfig(
+          param.gpu_memory_profile_filename_prefix);
+#endif  // MXNET_USE_CUDA
     }
   API_END();
 }
@@ -354,6 +364,9 @@ int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStor
     CHECK(profiler->IsEnableOutput())
       << "Profiler hasn't been run. Config and start profiler first";
     profiler->DumpProfile(finished != 0);
+#if MXNET_USE_CUDA
+    profiler::GpuDeviceStorageProfiler::Get()->DumpProfile();
+#endif  // MXNET_USE_CUDA
   }
   API_END()
 }
@@ -362,6 +375,15 @@ int MXSetProfilerState(int state) {
   return MXSetProcessProfilerState(state, static_cast<int>(ProfileProcess::kWorker), nullptr);
 }
 
+// Forwards `scope` to profiler::ProfilerScope::SetCurrentProfilerScope().
+// A NULL `scope` is rejected with a CHECK failure before it reaches the profiler.
+int MXSetProfilerScope(const char* const scope) {
+  API_BEGIN();
+  CHECK(scope != nullptr) << "MXSetProfilerScope: scope must not be NULL";
+  profiler::ProfilerScope::Get()->SetCurrentProfilerScope(scope);
+  API_END();
+}
+
 int MXSetProcessProfilerState(int state, int profile_process, KVStoreHandle kvStoreHandle) {
   mxnet::IgnoreProfileCallScope ignore;
   // state, kNotRunning: 0, kRunning: 1
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 8f78fc1..9042dfa 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -41,10 +41,20 @@ void RegisterLegacyOpProp();
 void RegisterLegacyNDFunc();
 }
 const std::vector<std::string> kHiddenKeys = {
-  "ctx_group", "lr_mult", "wd_mult", "force_mirroring", "mirror_stage"
+  "ctx_group",  // same order as kReplacedHiddenKeys below; keep the two lists aligned
+  "lr_mult",
+  "wd_mult",
+  "force_mirroring",
+  "mirror_stage",
+  "profiler_scope",
 };
 const std::vector<std::string> kReplacedHiddenKeys = {
-  "__ctx_group__", "__lr_mult__", "__wd_mult__", "__force_mirroring__", "__mirror_stage__"
+  "__ctx_group__",  // entry i is "__" + kHiddenKeys[i] + "__"
+  "__lr_mult__",
+  "__wd_mult__",
+  "__force_mirroring__",
+  "__mirror_stage__",
+  "__profiler_scope__",
 };
 const char *kNamespaceSeparator = "$";
 
diff --git a/src/common/cuda_utils.h b/src/common/cuda_utils.h
index ccf0931..0971cfd 100644
--- a/src/common/cuda_utils.h
+++ b/src/common/cuda_utils.h
@@ -56,6 +56,9 @@ extern __cuda_fake_struct blockIdx;
 #include <cuda_runtime.h>
 #include <cublas_v2.h>
 #include <curand.h>
+#if MXNET_USE_NVML
+#include <nvml.h>
+#endif  // MXNET_USE_NVML
 
 #include <vector>
 
@@ -175,6 +178,22 @@ inline __device__ bool __is_supported_cuda_architecture() {
   }
 
 
+#if MXNET_USE_NVML
+/*!
+ * \brief Protected NVML call: evaluates `func` once and CHECK-fails with
+ *        nvmlErrorString() of the returned status unless it equals NVML_SUCCESS.
+ *        Expands to a bare { } block, so avoid it as an un-braced if/else body.
+ * \param func Expression to call; must yield an nvmlReturn_t.
+ */
+#define NVML_CALL(func)                                 \
+  {                                                     \
+    nvmlReturn_t nvml_call_result = (func);             \
+    CHECK_EQ(nvml_call_result, NVML_SUCCESS)            \
+      << #func " failed with error "                    \
+      << nvmlErrorString(nvml_call_result);             \
+  }
+#endif  // MXNET_USE_NVML
+
 #if !defined(_MSC_VER)
 #define CUDA_UNROLL _Pragma("unroll")
 #define CUDA_NOUNROLL _Pragma("nounroll")
diff --git a/src/common/utils.h b/src/common/utils.h
index 31e6dea..40d5a55 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -31,6 +31,7 @@
 #include <nnvm/node.h>
 #include <mxnet/engine.h>
 #include <mxnet/ndarray.h>
+#include <mxnet/storage.h>
 #include <mxnet/op_attr_types.h>
 #include <mxnet/graph_attr_types.h>
 #include <nnvm/graph_attr_types.h>
@@ -731,8 +732,10 @@ inline NDArray InitZeros(const NDArrayStorageType stype, const mxnet::TShape &sh
 /*!
  * \brief Helper to add a NDArray of zeros to a std::vector.
  */
-inline void EmplaceBackZeros(const NDArrayStorageType stype, const mxnet::TShape &shape,
-                             const Context &ctx, const int dtype,
+inline void EmplaceBackZeros(const NDArrayStorageType stype,
+                             const mxnet::TShape &shape,
+                             const Context &ctx,
+                             const int dtype,
                              std::vector<NDArray> *vec) {
   // NDArray with default storage
   if (stype == kDefaultStorage) {
@@ -915,6 +918,19 @@ inline int np_binary_out_infer_type(const int type1, const int type2) {
   return get_more_precise_type(type1, type2);
 }
 
+/*!
+ * \brief Look up the profiler scope stored in a node's "__profiler_scope__" attribute.
+ * \param attrs Node attributes to inspect.
+ * \return The recorded scope, or MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR if none is set.
+ */
+inline std::string NodeAttrsGetProfilerScope(const nnvm::NodeAttrs& attrs) {
+  const auto iter = attrs.dict.find("__profiler_scope__");
+  if (iter == attrs.dict.end()) {
+    return MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR;
+  }
+  return iter->second;
+}
+
 }  // namespace common
 }  // namespace mxnet
 #endif  // MXNET_COMMON_UTILS_H_
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 13bab2e..02ce818 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -342,7 +342,9 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
   if (!need_grad_) return g;
   for (size_t i = 0; i < g.outputs.size(); ++i) {
     NodeEntry ngrad(nnvm::Node::Create(), 0, 0);
-    ngrad.node->attrs.name = "_head_grad_" + std::to_string(i);
+    const nnvm::NodeAttrs& attrs = g.outputs[i].node->attrs;
+    ngrad.node->attrs.name = attrs.name + "_head_grad";
+    ngrad.node->attrs.dict["__profiler_scope__"] = common::NodeAttrsGetProfilerScope(attrs);
     head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
     head_grad_map_[ngrad.node.get()] = i;
   }
@@ -517,9 +519,11 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
     const int inferred_dtype = inferred_dtypes[eid];
     const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
     const std::string& arg_name = idx[nid].source->attrs.name;
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(idx[nid].source->attrs);
     if (mutable_nodes.count(nid)) {  // aux_states
       EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                        inferred_dtype, aux_state_vec);
+      aux_state_vec->back().AssignStorageInfo(profiler_scope + "aux_state:", arg_name);
       data_entry_[eid] = aux_state_vec->back();
       aux_state_map_.emplace(arg_name, aux_state_vec->back());
       ++aux_top;
@@ -530,6 +534,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
     } else {  // in_args
       EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                        inferred_dtype, in_arg_vec);
+      in_arg_vec->back().AssignStorageInfo(profiler_scope + "in_arg:", arg_name);
       data_entry_[eid] = in_arg_vec->back();
       if (log_verbose_) {
         LOG(INFO) << "\tassign data entry\t" << eid << "\tas "
@@ -545,6 +550,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
         auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
         EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                          inferred_dtype, arg_grad_vec);
+        arg_grad_vec->back().AssignStorageInfo(profiler_scope + "arg_grad:", arg_name);
         if (log_verbose_) {
           LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas "
                     << common::stype_string(grad_stype);
@@ -589,6 +595,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
     const int inferred_dtype = inferred_dtypes[eid];
     const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
     const std::string& arg_name = idx[nid].source->attrs.name;
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(idx[nid].source->attrs);
     // aux_states
     if (mutable_nodes.count(nid)) {
       if (nullptr != shared_exec) {
@@ -611,6 +618,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
       } else {
         EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                          inferred_dtype, aux_state_vec);
+        aux_state_vec->back().AssignStorageInfo(profiler_scope + "aux_state:", arg_name);
       }  // if (has_shared_exec)
       data_entry_[eid] = aux_state_vec->back();
       aux_state_map_.emplace(arg_name, aux_state_vec->back());
@@ -648,6 +656,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
           // doesn't have shared_exec, or non-default storage
           EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                            inferred_dtype, in_arg_vec);
+          in_arg_vec->back().AssignStorageInfo(profiler_scope + "in_arg:", arg_name);
         }
         // gradient for model parameter
         if (kNullOp == grad_req_types[arg_top]) {
@@ -664,6 +673,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
             // no need to reuse memory from shared_exec for gradient of non-default storage
             EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                              inferred_dtype, arg_grad_vec);
+            arg_grad_vec->back().AssignStorageInfo(profiler_scope + "arg_grad:", arg_name);
           }
           grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
         }
@@ -673,6 +683,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
         in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
                                                  inferred_stype, in_arg_ctxes[arg_top],
                                                  shared_buffer, enable_row_sparse_sharing));
+        in_arg_vec->back().AssignStorageInfo(profiler_scope + "in_arg:", arg_name);
         // gradient for model parameter, row_sparse ndarray sharing disabled
         if (kNullOp == grad_req_types[arg_top]) {
           arg_grad_vec->emplace_back();
@@ -685,6 +696,7 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                                      inferred_dtype, grad_stype,
                                                      arg_grad_ctxes[arg_top], shared_buffer,
                                                      enable_row_sparse_sharing));
+          arg_grad_vec->back().AssignStorageInfo(profiler_scope + "arg_grad:", arg_name);
           grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
         }  // if (kNullOp == grad_req_types[arg_top])
       }  // if (shared_arg_names.count(arg_name))
@@ -1077,12 +1089,17 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
   CHECK_EQ(data_entry_.size(), vshape.size());
   std::vector<Context> data_context(idx.num_node_entries());
   std::vector<NDArrayStorageType> data_storage_type(idx.num_node_entries(), kUndefinedStorage);
+  std::vector<std::string> data_storage_profiler_scope(idx.num_node_entries());
+  std::vector<std::string> data_storage_name(idx.num_node_entries());
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(idx[nid].source->attrs);
     for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) {
       auto eid = idx.entry_id(nid, i);
       data_context[eid] = vctx[nid];
       CHECK_NE(vstorage_type[eid], kUndefinedStorage);
       data_storage_type[eid] = (NDArrayStorageType) vstorage_type[eid];
+      data_storage_profiler_scope[eid] = profiler_scope;
+      data_storage_name[eid] = idx[nid].source->attrs.name;
     }
   }
 
@@ -1091,6 +1108,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     Context ctx;
     size_t bytes;
     NDArrayStorageType stype;
+    std::string profiler_scope;
+    std::string name;
   };
   std::vector<PoolEntry> pool_info;
 
@@ -1111,6 +1130,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     } else {
       data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
     }
+    data_entry_[data_eid].AssignStorageInfo(data_storage_profiler_scope[data_eid],
+                                            data_storage_name[data_eid]);
     if (log_verbose_) {
       LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
                 << common::stype_string(stype);
@@ -1129,11 +1150,14 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     if (storage_id < 0) continue;
     size_t sid = static_cast<size_t>(storage_id);
     if (sid >= pool_info.size()) {
-      pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage});
+      pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage,
+                                          MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR,
+                                          MXNET_STORAGE_DEFAULT_NAME_CSTR});
     }
     PoolEntry& info = pool_info[sid];
     if (info.bytes == 0) {
-      info = PoolEntry{data_context[i], bytes, data_storage_type[i]};
+      info = PoolEntry{data_context[i], bytes, data_storage_type[i],
+                       data_storage_profiler_scope[i], data_storage_name[i]};
     } else {
       info.bytes = std::max(info.bytes, bytes);
     }
@@ -1183,6 +1207,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
       // TODO(junwu): adding delay_alloc=true to create nd
       // is a temporary solution.
       NDArray nd(shape, ctx, true);
+      nd.AssignStorageInfo(pool_info[i].profiler_scope,
+                           pool_info[i].name);
       data_pool_[i] = nd;
       // put the new allocated arrays to shared pool
       if (shared_pool != nullptr)  {
@@ -1201,6 +1227,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     if (storage_type == kDefaultStorage) {
       if (!shape_is_known(vshape[i])) {
         data_entry_[i] = NDArray(data_context[i], vdtype[i]);
+        data_entry_[i].AssignStorageInfo(data_storage_profiler_scope[i],
+                                         data_storage_name[i]);
       } else {
         CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
         const NDArray& src = data_pool_.at(storage_id);
@@ -1209,6 +1237,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
     } else {
       data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
                                true, vdtype[i]);
+      data_entry_[i].AssignStorageInfo(data_storage_profiler_scope[i],
+                                       data_storage_name[i]);
     }
     if (log_verbose_) {
       LOG(INFO) << "\tinit data entry\t" << i << "\tas " << common::stype_string(storage_type);
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index 89dabac..0f8b86a 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -683,6 +683,9 @@ OpStatePtr CachedOp::StaticForward(
     if (!outputs[i]->is_none()) continue;
     *outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                           shapes[eid], default_ctx, true, dtypes[eid]);
+    const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs;
+    outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs),
+                                  attrs.name);
   }
 
   StaticRunOps(default_ctx, g, state_ptr, arrays, 0, idx.num_nodes());
@@ -766,6 +769,22 @@ OpStatePtr CachedOp::Forward(
   static const auto cached_op = nnvm::Op::Get("_CachedOp");
 
   CHECK_EQ(inputs.size(), num_inputs());
+  // Assign the storage information for the input arguments. Similar to the
+  // implementation in `graph_executor.cc`, we use `mutable_input_nodes()` to
+  // distinguish between weight parameters and auxiliary states.
+  const auto& fwd_idx = fwd_graph_.indexed_graph();
+  const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes();
+  for (size_t i = 0; i < fwd_idx.input_nodes().size(); ++i) {
+    const uint32_t nid = fwd_idx.input_nodes().at(i);
+    const nnvm::NodeAttrs& attrs = fwd_idx[nid].source->attrs;
+    const std::string& arg_name = attrs.name;
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs);
+    if (mutable_input_nodes.count(nid)) {
+      inputs[i]->AssignStorageInfo(profiler_scope + "aux_state:", arg_name);
+    } else {
+      inputs[i]->AssignStorageInfo(profiler_scope + "in_arg:", arg_name);
+    }
+  }
 
   Context default_ctx = inputs[0]->ctx();
   {
@@ -993,6 +1012,27 @@ void CachedOp::Backward(
     const std::vector<NDArray*>& inputs,
     const std::vector<OpReqType>& reqs,
     const std::vector<NDArray*>& outputs) {
+  const auto& fwd_idx = fwd_graph_.indexed_graph();
+  const auto& full_idx = full_graph_.indexed_graph();
+  const auto& mutable_input_nodes = fwd_idx.mutable_input_nodes();
+  for (size_t i = 0, j = 0; i < fwd_idx.input_nodes().size(); ++i) {
+    const uint32_t nid = fwd_idx.input_nodes().at(i);
+    const std::string& arg_name = fwd_idx[nid].source->attrs.name;
+    const std::string profiler_scope =
+        common::NodeAttrsGetProfilerScope(fwd_idx[nid].source->attrs);
+    if (mutable_input_nodes.count(nid)) {
+      continue;
+    }
+    outputs[j++]->AssignStorageInfo(profiler_scope + "arg_grad:", arg_name);
+  }
+  for (size_t i = fwd_idx.input_nodes().size(), j = 0;
+       i < full_idx.input_nodes().size(); ++i) {
+    const nnvm::NodeAttrs& attrs = full_idx[full_idx.input_nodes().at(i)].source->attrs;
+    const std::string& entry_name = attrs.name;
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(attrs);
+    inputs[j++]->AssignStorageInfo(profiler_scope, entry_name);
+  }
+
   using namespace imperative;
   CHECK(!Imperative::Get()->is_recording())
       << "CachedOp does not support higher order gradients. "
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index c56d8cf..16a5d30 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -96,6 +96,10 @@ void CreateGraphNDs(const nnvm::Graph& g,
       continue;
     *((*arrays)[eid]) = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                                 shapes[eid], default_ctx, true, dtypes[eid]);
+    const nnvm::NodeAttrs& attrs = idx[idx.outputs()[i].node_id].source->attrs;
+    (*arrays)[eid]->AssignStorageInfo(
+        common::NodeAttrsGetProfilerScope(attrs),
+        attrs.name);
   }
 }
 
@@ -136,7 +140,9 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph,
   ograd_entries->reserve(fwd_graph->outputs.size());
   for (size_t i = 0; i < fwd_graph->outputs.size(); ++i) {
     nnvm::ObjectPtr np = Node::Create();
-    np->attrs.name = "_head_grad_" + std::to_string(i);
+    const nnvm::NodeAttrs& attrs = fwd_graph->outputs[i].node->attrs;
+    np->attrs.name = attrs.name + "_head_grad";
+    np->attrs.dict["__profiler_scope__"] = common::NodeAttrsGetProfilerScope(attrs);
     ograd_entries->emplace_back(np);
   }
 
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 1560138..21d5298 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -189,7 +189,7 @@ inline void SetShapeType(const Context& ctx,
   } else {
     // if infer storage attr is not present, apply the default infer storage function
     infer_stype_success = common::DefaultStorageType(attrs, ctx.dev_mask(), dispatch_mode,
-                                                   &in_storage_types, &out_storage_types);
+                                                     &in_storage_types, &out_storage_types);
   }
   CHECK(infer_stype_success) << "Operator not implemented: "
      << common::operator_stype_string(attrs, ctx.dev_mask(), in_storage_types, out_storage_types);
@@ -206,10 +206,13 @@ inline void SetShapeType(const Context& ctx,
       if (is_dynamic_shape_existing) {
         // once there is dynamic shape somewhere, we could not pre-determine the shape.
         *outputs[i] = NDArray(ctx, out_types[i]);
+        outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
       } else if (storage_type == kDefaultStorage) {
         *outputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
+        outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
       } else {
         *outputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]);
+        outputs[i]->AssignStorageInfo(common::NodeAttrsGetProfilerScope(attrs), attrs.name);
       }
     } else {
       CHECK_EQ(outputs[i]->shape(), out_shapes[i])
@@ -894,15 +897,31 @@ inline std::multimap<size_t, NDArray> AllocateMemory(
   const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
   const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
   const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  std::vector<std::string> data_entry_profiler_scopes(entry_end - entry_start);
+  std::vector<std::string> data_entry_names(entry_end - entry_start);
 
   std::multimap<size_t, NDArray> new_pool;
 
+  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
+    const std::string profiler_scope = common::NodeAttrsGetProfilerScope(idx[nid].source->attrs);
+    for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) {
+      uint32_t eid = idx.entry_id(nid, i);
+      if (eid < entry_start || eid >= entry_end) {
+        continue;
+      }
+      data_entry_profiler_scopes[eid - entry_start] = profiler_scope;
+      data_entry_names[eid - entry_start] = idx[nid].source->attrs.name;
+    }
+  }
+
   for (uint32_t i = entry_start; i < entry_end; ++i) {
     if (mem_plan[i].storage_id == exec::kExternalStorageID) continue;
     CHECK(arrays[i]->is_none());
     if (mem_plan[i].storage_id == exec::kDynamicStorageID) {
       *arrays[i] = NDArray(static_cast<NDArrayStorageType>(stypes[i]),
                            shapes[i], default_ctx, true, dtypes[i]);
+      arrays[i]->AssignStorageInfo(data_entry_profiler_scopes[i - entry_start],
+                                   data_entry_names[i - entry_start]);
       continue;
     }
     CHECK_EQ(stypes[i], kDefaultStorage);
@@ -915,6 +934,8 @@ inline std::multimap<size_t, NDArray> AllocateMemory(
       } else {
         NDArray buff(mxnet::TShape({static_cast<nnvm::dim_t>(mem_plan[i].size)}),
                      default_ctx, true, mshadow::kUint8);
+        buff.AssignStorageInfo(data_entry_profiler_scopes[i - entry_start],
+                               data_entry_names[i - entry_start]);
         *arrays[i] = buff.AsArray(shapes[i], dtypes[i]);
         new_pool.insert({mem_plan[i].size, buff});
       }
diff --git a/src/io/iter_image_recordio_2.cc b/src/io/iter_image_recordio_2.cc
index 565b671..ad6e2a1 100644
--- a/src/io/iter_image_recordio_2.cc
+++ b/src/io/iter_image_recordio_2.cc
@@ -42,6 +42,7 @@
 #include "./image_iter_common.h"
 #include "./inst_vector.h"
 #include "../common/utils.h"
+#include "../profiler/profiler.h"
 
 namespace mxnet {
 
@@ -296,10 +297,16 @@ inline bool ImageRecordIOParser2<DType>::ParseNext(DataBatch *out) {
     if (dev_id != -1) {
       ctx = Context::CPUPinned(dev_id);
     }
+
+    const std::string profiler_scope =
+        profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "image_io:";
+
     out->data.at(0) = NDArray(data_shape, ctx, false,
       mshadow::DataType<DType>::kFlag);
+    out->data.at(0).AssignStorageInfo(profiler_scope, "data");
     out->data.at(1) = NDArray(label_shape, ctx, false,
       mshadow::DataType<real_t>::kFlag);
+    out->data.at(1).AssignStorageInfo(profiler_scope, "label");
     unit_size_[0] = param_.data_shape.Size();
     unit_size_[1] = param_.label_width;
   }
diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h
index 88e363b..b03b74c 100644
--- a/src/kvstore/comm.h
+++ b/src/kvstore/comm.h
@@ -34,6 +34,7 @@
 #include "gradient_compression.h"
 #include "../ndarray/ndarray_function.h"
 #include "../operator/tensor/sparse_retain-inl.h"
+#include "../profiler/profiler.h"
 #include "./kvstore_utils.h"
 namespace mxnet {
 namespace kvstore {
@@ -532,9 +533,12 @@ class CommDevice : public Comm {
         // NDArray.Slice or gpu direct memory access. for the latter, we need to
         // remove some ctx check, and also it reduces 20% perf
         buf.copy_buf.resize(src.size()-1);
+        const std::string profiler_scope =
+            profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "comm_dev:";
         for (size_t i = 0; i < src.size()-1; ++i) {
           buf.copy_buf[i] = NDArray(
             buf_merged.shape(), buf_merged.ctx(), false, buf_merged.dtype());
+          buf.copy_buf[i].AssignStorageInfo(profiler_scope, "copy_buf");
         }
       }
       for (size_t i = 0; i < src.size()-1; ++i) {
@@ -560,18 +564,23 @@ class CommDevice : public Comm {
       buf.compressed_recv_buf.resize(src.size());
       buf.compressed_send_buf.resize(src.size());
       buf.residual.resize(src.size());
-
+      const std::string profiler_scope =
+          profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "comm_dev:";
       for (size_t i = 0; i < src.size(); ++i) {
         buf.copy_buf[i] = NDArray(buf.merged.shape(), buf.merged.ctx(),
                                   false, buf.merged.dtype());
+        buf.copy_buf[i].AssignStorageInfo(profiler_scope, "copy_buf");
         buf.residual[i] = NDArray(buf.merged.shape(), src[i].ctx(),
                                   false, buf.merged.dtype());
+        buf.residual[i].AssignStorageInfo(profiler_scope, "residual");
         buf.residual[i] = 0;
         int64_t small_size = gc_->GetCompressedSize(buf.merged.shape().Size());
         buf.compressed_recv_buf[i] = NDArray(mxnet::TShape{small_size}, buf.merged.ctx(),
-                                        false, buf.merged.dtype());
+                                             false, buf.merged.dtype());
+        buf.compressed_recv_buf[i].AssignStorageInfo(profiler_scope, "compressed_recv_buf");
         buf.compressed_send_buf[i] = NDArray(mxnet::TShape{small_size}, src[i].ctx(),
-                                        false, buf.merged.dtype());
+                                             false, buf.merged.dtype());
+        buf.compressed_send_buf[i].AssignStorageInfo(profiler_scope, "compressed_send_buf");
       }
     }
 
@@ -686,6 +695,9 @@ class CommDevice : public Comm {
       ctx_info[d.dev_id] = std::make_pair(d, 0);
     }
 
+    const std::string profiler_scope =  // computed once, shared by every merge buffer below
+        profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:comm_dev:";
+
     for (auto& sorted_key_attr : sorted_key_attrs_) {
       const int key  = std::get<0>(sorted_key_attr);
       const mxnet::TShape& shape = std::get<1>(sorted_key_attr);
@@ -705,6 +717,7 @@ class CommDevice : public Comm {
       if (buf.merged.is_none()) {
         bool delay_alloc = true;
         buf.merged = NDArray(shape, ctx, delay_alloc, type);
+        buf.merged.AssignStorageInfo(profiler_scope, "merge_buf_" + std::to_string(key));
       }
       ctx_info[ctx.dev_id].second += shape.Size();
     }
diff --git a/src/kvstore/comm_tree.h b/src/kvstore/comm_tree.h
index 11ca2d6..07ffba5 100644
--- a/src/kvstore/comm_tree.h
+++ b/src/kvstore/comm_tree.h
@@ -400,6 +400,9 @@ class CommDeviceTree : public CommDevice {
     bool delay_alloc = true;
     std::map<int, int> key_dist;
 
+    const std::string profiler_scope =  // "kvstore:" prefix, as in CommDevice
+        profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:comm_dev_tree:";
+
     for (auto& tree_sorted_key_attr : tree_sorted_key_attrs_) {
       const int key  = std::get<0>(tree_sorted_key_attr);
       const mxnet::TShape& shape = std::get<1>(tree_sorted_key_attr);
@@ -457,6 +460,8 @@ class CommDeviceTree : public CommDevice {
               if (row == devs_.size()-1)
                 shape_copy[0] = last_slice;
               buf.merged[row] = NDArray(shape_copy, ctx, delay_alloc, type);
+              buf.merged[row].AssignStorageInfo(
+                  profiler_scope, "merged_" + std::to_string(key));  // name carries the key
               buf.copy_buf.emplace_back();
               if (buf.copy_buf[row].empty()) {
                 buf.copy_buf[row].resize(kBranch-1);
@@ -465,18 +470,23 @@ class CommDeviceTree : public CommDevice {
                                                    buf.merged[row].ctx(),
                                                    delay_alloc,
                                                    buf.merged[row].dtype());
+                  buf.copy_buf[row][col].AssignStorageInfo(profiler_scope, "copy_buf");
                 }
               }
             }
           } else {
             buf.merged.emplace_back(shape, ctx, false, type);
+            buf.merged.back().AssignStorageInfo(
+                profiler_scope, "merged_" + std::to_string(key));  // name carries the key
             if (buf.copy_buf.empty()) {
               buf.copy_buf.emplace_back();
               buf.copy_buf[0].resize(kBranch-1);
               for (size_t col = 0; col < buf.copy_buf[0].size(); ++col) {
                 buf.copy_buf[0][col] = NDArray(buf.merged[0].shape(),
-                                               buf.merged[0].ctx(), delay_alloc,
+                                               buf.merged[0].ctx(),
+                                               delay_alloc,
                                                buf.merged[0].dtype());
+                buf.copy_buf[0][col].AssignStorageInfo(profiler_scope, "copy_buf");
               }
             }
           }
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 38bd7d9..bc4e933 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -37,6 +37,7 @@
 #include "./comm_tree.h"
 #include "./kvstore_utils.h"
 #include "../ndarray/ndarray_function.h"
+#include "../profiler/profiler.h"
 
 namespace mxnet {
 namespace kvstore {
@@ -246,6 +247,11 @@ class KVStoreLocal : public KVStore {
       int key = uniq_keys[i];
       const NDArray& merged = comm_->Reduce(key, grouped_vals[i], priority);
       NDArray& local = local_[key];
+      // tag by the user-facing string key when the store uses string keys
+      local.AssignStorageInfo(
+          profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:push:",
+          key_type_ == kStringKey ? reverse_str_key_dict_[key]
+                                  : "local_" + std::to_string(key));
       if (updater_ != nullptr) {
         CHECK(!local.is_none()) << "key " << key << " has not been inited";
         // if merged is on gpu, we may need copy weight from cpu to gpu
@@ -288,6 +294,15 @@ class KVStoreLocal : public KVStore {
       const NDArray& local = local_[key];
       CHECK(!local.is_none()) << "key " << key << " has not been inited";
       comm_->Broadcast(key, local, grouped_vals[i], priority);
+      // tag every pulled output; the scope string is looked up once per key
+      const std::string pull_scope =
+          profiler::ProfilerScope::Get()->GetCurrentProfilerScope() + "kvstore:pull:";
+      for (NDArray* val : grouped_vals[i]) {
+        val->AssignStorageInfo(
+            pull_scope,
+            key_type_ == kStringKey ? reverse_str_key_dict_[key]
+                                    : "grouped_vals_" + std::to_string(key));
+      }
     }
   }
 
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index d16b38e..f851383 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -36,6 +36,7 @@
 #include "../operator/tensor/matrix_op-inl.h"
 #include "../operator/tensor/init_op.h"
 #include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+#include "../profiler/storage_profiler.h"
 
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
@@ -93,6 +94,25 @@ NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Con
         dtype, aux_types, aux_shapes);
 }
 
+void NDArray::AssignStorageInfo(const std::string& profiler_scope,
+                                const std::string& name) {  // tags data and aux handles
+  if (is_none()) {
+    return;
+  }
+  ptr_->shandle.profiler_scope = profiler_scope;
+  ptr_->shandle.name = name;
+#if MXNET_USE_CUDA  // GpuDeviceStorageProfiler is declared only in CUDA builds
+  profiler::GpuDeviceStorageProfiler::Get()->UpdateStorageInfo(ptr_->shandle);
+#endif  // MXNET_USE_CUDA
+  for (Storage::Handle& aux_handle : ptr_->aux_handles) {
+    aux_handle.profiler_scope = profiler_scope;
+    aux_handle.name = name + "_aux_data";  // NOTE(review): same name for every aux handle
+#if MXNET_USE_CUDA
+    profiler::GpuDeviceStorageProfiler::Get()->UpdateStorageInfo(aux_handle);
+#endif  // MXNET_USE_CUDA
+  }
+}
+
 void NDArray::SetShapeFromChunk() {
   if (Imperative::Get()->is_np_shape() ||
       !(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
@@ -148,7 +168,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
     // free storage
     Storage::Get()->Free(shandle);
     // init storage
-    shandle = Storage::Get()->Alloc(dbytes, ctx);
+    shandle.size = dbytes;  // reuses shandle instead of a fresh handle
+    Storage::Get()->Alloc(&shandle);
 #if MXNET_USE_MKLDNN == 1
     mkl_mem_ = nullptr;
 #endif
@@ -1869,9 +1890,9 @@ void NDArray::Load(dmlc::Stream* fi,
 NDArray NDArray::Copy(Context ctx) const {
   NDArray ret;
   if (kDefaultStorage == storage_type()) {
-    ret = NDArray(shape(), ctx, true, dtype_);
+    ret = NDArray(shape(), ctx, false, dtype_);  // delay_alloc = false
   } else if (kUndefinedStorage != storage_type()) {
-    ret = NDArray(storage_type(), shape(), ctx, true, dtype_,
+    ret = NDArray(storage_type(), shape(), ctx, false, dtype_,  // delay_alloc = false
                   ptr_->aux_types, ptr_->aux_shapes, storage_shape());
   } else {
     LOG(FATAL) << "NDArray::Copy cannot copy undefined storage-type ndarray to ctx.dev_type="
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index d83eb0d..e441e27 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -39,6 +39,13 @@ inline void linalg_check_batch_size(int A, int B, int C) {
   CHECK_GT(A, 0) << "Zero batch size for arguments to linear algebra operator";
 }
 
+#ifdef __CUDACC__  // declares a tagged scratch GPU handle; callers supply the ';'
+#define EPHEMERAL_GPU_STORAGE_ALLOC(func, var, dtype, size) \
+  Storage::Handle var = Storage::Get()->Alloc(sizeof(dtype) * (size), Context::GPU()); \
+  var.profiler_scope = "<ephemeral>:"; \
+  var.name = #func"_"#var  // NOTE(review): tags set after Alloc; no UpdateStorageInfo call
+#endif  // __CUDACC__
+
 //////////////////////////////// GEMM ////////////////////////////////////////////
 
 // CPU/GPU-versions of BLAS3 function "gemm". Please refer to the BLAS3-documentation
@@ -725,8 +732,9 @@ void linalg_potrf<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream
   CHECK_NOTNULL(s); \
   check_potrf(A, lower); \
   int buffsize(linalg_potrf_buffsize(A, lower, s)); \
-  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_potrf, buffer, /* cuSOLVER workspace */ \
+      DType, buffsize); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_potrf, info, int, 1); \
   CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                 (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
                 A.size(0), A.dptr_, A.stride_, static_cast<DType *>(buffer.dptr), buffsize, \
@@ -746,8 +754,9 @@ void linalg_batch_potrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower,
   CHECK_GT(A.size(0), 0); \
   check_potrf(A[0], lower); \
   int buffsize(linalg_potrf_buffsize(A[0], lower, s)); \
-  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*buffsize, Context::GPU()); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_potrf, buffer, \
+      DType, buffsize); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_potrf, info, int, 1); \
   for (mshadow::index_t i = 0; i < A.size(0); ++i) { \
     CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                  (lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER), \
@@ -818,7 +827,8 @@ void linalg_potri<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream
   using namespace mxnet; \
   CHECK_NOTNULL(s); \
   check_potri(A, lower); \
-  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_potri, buffer, \
+      DType, A.MSize()); \
   using namespace mshadow::cuda; \
   int ngrid = std::min(kMaxGridNum, \
                        static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
@@ -842,7 +852,8 @@ void linalg_batch_potri<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower,
   CHECK_NOTNULL(s); \
   CHECK_GT(A.size(0), 0); \
   check_potri(A[0], lower); \
-  Storage::Handle buffer = Storage::Get()->Alloc(sizeof(DType)*A.MSize(), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_potri, buffer, \
+      DType, A.MSize()); \
   using namespace mshadow::cuda; \
   int ngrid = std::min(kMaxGridNum, \
                        static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
@@ -1024,7 +1035,7 @@ void linalg_gelqf<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
   check_gelqf(A, work); \
   int m(A.size(0)); \
   int lwork(work.size(0) - m); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_gelqf, info, int, 1); \
   CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                 A.size(1), m, A.dptr_ , A.stride_, work.dptr_, \
                 work.dptr_ + m, lwork, static_cast<int *>(info.dptr))); \
@@ -1048,7 +1059,7 @@ void linalg_orglq<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
   check_gelqf(A, work); \
   int m(A.size(0)); \
   int lwork(work.size(0) - m); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_orglq, info, int, 1); \
   CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                 A.size(1), m, m, A.dptr_ , A.stride_, work.dptr_, \
                 work.dptr_ + m, lwork, static_cast<int *>(info.dptr))); \
@@ -1084,7 +1095,8 @@ int linalg_gelqf_workspace_query<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
   CUSOLVER_CALL(cusolverDn##prefix##geqrf_bufferSize(Stream<gpu>::GetSolverHandle(s), \
                 A.size(1), m, A.dptr_ , A.stride_, &work1)); \
   int work2(0);  \
-  Storage::Handle tau = Storage::Get()->Alloc(sizeof(DType), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_gelqf_workspace_query, \
+      tau, DType, 1); \
   CUSOLVER_CALL(cusolverDn##prefix##orgqr_bufferSize(Stream<gpu>::GetSolverHandle(s), \
                 A.size(1), m, m, A.dptr_ , A.stride_, static_cast<DType *>(tau.dptr), &work2)); \
   Storage::Get()->Free(tau); \
@@ -1182,7 +1194,7 @@ void linalg_syevd<gpu, DType>(const Tensor<gpu, 2, DType>& A, \
   using mshadow::gpu; \
   CHECK_NOTNULL(s); \
   check_syevd(A, L); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_syevd, info, int, 1); \
   CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                 CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, \
                 A.size(0), A.dptr_ , A.stride_, L.dptr_, work.dptr_, \
@@ -1302,7 +1314,7 @@ void linalg_gesvd<gpu, DType>(const Tensor<gpu, 2, DType>& UT, \
   using namespace mxnet; \
   using mshadow::gpu; \
   check_gesvd(UT, L, V); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_gesvd, info, int, 1); \
   CUSOLVER_CALL(cusolver##fname(Stream<gpu>::GetSolverHandle(s), \
                 'O', 'S', V.size(1), V.size(0), V.dptr_, V.stride_, L.dptr_, V.dptr_, V.stride_, \
                 UT.dptr_, UT.stride_, work.dptr_, work.size(0), \
@@ -1422,8 +1434,9 @@ void linalg_batch_getrf<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
   using namespace mxnet; \
   using namespace mxnet::op::mxnet_op; \
   CHECK_NOTNULL(s); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \
-  Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_getrf, info, int, A.size(0)); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_getrf, A_ptr_buf, /* one DType* per batch entry */ \
+      DType *, A.size(0)); \
   DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
   Kernel<set_matrix, gpu>::Launch(s, A.size(0), \
                                   A_ptr, A.dptr_, \
@@ -1511,10 +1524,12 @@ void linalg_batch_getri<gpu, DType>(const Tensor<gpu, 3, DType>& A, \
   using namespace mxnet; \
   using namespace mxnet::op::mxnet_op; \
   CHECK_NOTNULL(s); \
-  Storage::Handle info = Storage::Get()->Alloc(sizeof(int) * A.size(0), Context::GPU()); \
-  Storage::Handle A_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_getri, info, int, A.size(0)); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_getri, \
+      A_ptr_buf, DType *, A.size(0)); \
   DType **A_ptr = static_cast<DType **>(A_ptr_buf.dptr); \
-  Storage::Handle LU_ptr_buf = Storage::Get()->Alloc(sizeof(DType *) * A.size(0), Context::GPU()); \
+  EPHEMERAL_GPU_STORAGE_ALLOC(linalg_batch_getri, \
+      LU_ptr_buf, DType *, A.size(0)); \
   DType **LU_ptr = static_cast<DType **>(LU_ptr_buf.dptr); \
   Kernel<set_matrix, gpu>::Launch(s, A.size(0), \
                                   A_ptr, A.dptr_, \
diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h
index ddd0729..056f93b 100644
--- a/src/operator/nn/cudnn/cudnn_convolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h
@@ -760,8 +760,11 @@ class CuDNNConvolutionOp {
   // cudaMalloc() calls by (say) cudnnFind().  `elements` spec the alloc size in DTypes, not bytes.
   void ReserveElements(const std::vector<size_t> &elements) {
     std::vector<Storage::Handle> handles;
-    for (size_t alloc_element : elements)
+    for (size_t alloc_element : elements) {
         handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU()));
+        handles.back().profiler_scope = "<ephemeral>:";  // DirectFree'd just below
+        handles.back().name = "reserve_elements";
+    }
     for (auto &handle : handles)
         Storage::Get()->DirectFree(handle);
   }
diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
index 9783adc..b701883 100644
--- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
+++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h
@@ -763,8 +763,11 @@ class CuDNNDeconvolutionOp {
   // cudaMalloc() calls by (say) cudnnFind().  `elements` spec the alloc size in DTypes, not bytes.
   void ReserveElements(const std::vector<size_t> &elements) {
     std::vector<Storage::Handle> handles;
-    for (size_t alloc_element : elements)
+    for (size_t alloc_element : elements) {
         handles.push_back(Storage::Get()->Alloc(alloc_element * sizeof(DType), Context::GPU()));
+        handles.back().profiler_scope = "<ephemeral>:";  // DirectFree'd just below
+        handles.back().name = "reserve_elements";
+    }
     for (auto &handle : handles)
         Storage::Get()->DirectFree(handle);
   }
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index d41b5b4..461f90b 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -41,6 +41,7 @@
 #include "./math_functions-inl.h"
 #include "./operator_common.h"
 #include "./rnn_impl.h"
+#include "../profiler/storage_profiler.h"
 
 #if MXNET_USE_CUDNN == 1
 STATIC_ASSERT_CUDNN_VERSION_GE(7000);
@@ -1400,6 +1401,9 @@ class RNNOp {
       workspace_size_ = workspace_byte_ / sizeof(DType);
       // Allocate the reserve space
       reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
+      reserve_space_.profiler_scope = "cudnn_rnn:";  // tags are pushed to the profiler below
+      reserve_space_.name = "reserve_space";
+      profiler::GpuDeviceStorageProfiler::Get()->UpdateStorageInfo(reserve_space_);
       // Check that number of params are correct
       size_t cudnn_param_size;
       CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
diff --git a/src/profiler/profiler.cc b/src/profiler/profiler.cc
index c429661..13ab462 100644
--- a/src/profiler/profiler.cc
+++ b/src/profiler/profiler.cc
@@ -294,5 +294,23 @@ void Profiler::SetContinuousProfileDump(bool continuous_dump, float delay_in_sec
   }
 }
 
+ProfilerScope* ProfilerScope::Get() {  // lazily creates the process-wide instance
+  static std::mutex mtx;
+  static std::shared_ptr<ProfilerScope> prof_scope = nullptr;
+  std::unique_lock<std::mutex> lk(mtx);  // guards only the lazy creation
+  if (!prof_scope) {
+    prof_scope = std::make_shared<ProfilerScope>();
+  }
+  return prof_scope.get();
+}
+
+void ProfilerScope::SetCurrentProfilerScope(const std::string& scope) {
+  current_profiler_scope_ = scope;  // NOTE(review): no lock; confirm writers never race readers
+}
+
+std::string ProfilerScope::GetCurrentProfilerScope() const {
+  return current_profiler_scope_;  // returns a copy
+}
+
 }  // namespace profiler
 }  // namespace mxnet
diff --git a/src/profiler/profiler.h b/src/profiler/profiler.h
index f9f997f..132a9f9 100644
--- a/src/profiler/profiler.h
+++ b/src/profiler/profiler.h
@@ -601,7 +601,7 @@ struct ProfileCounter : public ProfileObject {
   }
   /*! \brief operator: object -= v */
   inline uint64_t operator -=(int64_t v) {
-    CHECK_GE(value_, v);
+    CHECK(v < 0 || value_ >= static_cast<uint64_t>(v));  // only a decrement can underflow
     if (v >= 0) {
       return DecrementValue(static_cast<uint64_t>(v));
     } else {
@@ -1314,6 +1314,18 @@ inline void Profiler::AddProfileStat<ProfileOperator::OprExecStat>(
 
 #undef VTUNE_ONLY_CODE  // This macro not meant to be used outside of this file
 
+class ProfilerScope {  // process-wide scope string prepended to storage tags
+ public:
+  /*! \brief Get the profiler scope instance */
+  static ProfilerScope* Get();
+  /*! \brief Set the current profiler scope */
+  void SetCurrentProfilerScope(const std::string& scope);
+  /*! \brief Get the current profiler scope */
+  std::string GetCurrentProfilerScope() const;
+ private:
+  std::string current_profiler_scope_ = MXNET_STORAGE_DEFAULT_PROFILER_SCOPE_CSTR;
+};
+
 }  // namespace profiler
 }  // namespace mxnet
 #endif  // MXNET_PROFILER_PROFILER_H_
diff --git a/src/profiler/storage_profiler.cc b/src/profiler/storage_profiler.cc
new file mode 100644
index 0000000..873cc67
--- /dev/null
+++ b/src/profiler/storage_profiler.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "./storage_profiler.h"
+
+#if MXNET_USE_NVML
+#include <nvml.h>
+#endif  // MXNET_USE_NVML
+#include <fstream>
+#include <map>
+#include <unordered_map>
+#include <vector>
+#include "./profiler.h"
+#include "../common/utils.h"
+#include "../common/cuda_utils.h"
+
+namespace mxnet {
+namespace profiler {
+
+#if MXNET_USE_CUDA
+
+GpuDeviceStorageProfiler* GpuDeviceStorageProfiler::Get() {
+  static std::mutex mtx;
+  static std::shared_ptr<GpuDeviceStorageProfiler> gpu_dev_storage_profiler = nullptr;
+  std::unique_lock<std::mutex> lk(mtx);
+  if (!gpu_dev_storage_profiler) {
+    gpu_dev_storage_profiler = std::make_shared<GpuDeviceStorageProfiler>();
+  }
+  return gpu_dev_storage_profiler.get();
+}
+
+void GpuDeviceStorageProfiler::DumpProfile() const {  // writes <prefix>-pid_<pid>.csv
+  size_t current_pid = common::current_process_id();
+  std::ofstream fout((filename_prefix_ + "-pid_" + std::to_string(current_pid)
+                      + ".csv").c_str());
+  if (!fout.is_open()) {
+    return;
+  }
+  struct AllocEntryDumpFmt {
+    size_t requested_size;
+    int dev_id;
+    size_t actual_size;
+    bool reuse;
+  };
+  // order the GPU memory allocation entries by their attribute name
+  std::multimap<std::string, AllocEntryDumpFmt> gpu_mem_ordered_alloc_entries;
+  // map the GPU device ID to the total amount of allocations
+  std::unordered_map<int, size_t> gpu_dev_id_total_alloc_map;
+  for (const std::pair<void* const, AllocEntry>& alloc_entry :
+       gpu_mem_alloc_entries_) {  // const key type: binds without copying the entry
+    gpu_mem_ordered_alloc_entries.emplace(
+        alloc_entry.second.profiler_scope +
+        alloc_entry.second.name, AllocEntryDumpFmt{
+          alloc_entry.second.requested_size,
+          alloc_entry.second.dev_id,
+          alloc_entry.second.actual_size,
+          alloc_entry.second.reuse});
+    gpu_dev_id_total_alloc_map[alloc_entry.second.dev_id] = 0;
+  }
+  fout << "\"Attribute Name\",\"Requested Size\","
+          "\"Device\",\"Actual Size\",\"Reuse?\"" << std::endl;
+  for (const std::pair<const std::string, AllocEntryDumpFmt>& alloc_entry :
+       gpu_mem_ordered_alloc_entries) {
+    fout << "\"" << alloc_entry.first << "\","
+         << "\"" << alloc_entry.second.requested_size << "\","
+         << "\"" << alloc_entry.second.dev_id << "\","
+         << "\"" << alloc_entry.second.actual_size << "\","
+         << "\"" << alloc_entry.second.reuse << "\"" << std::endl;
+    gpu_dev_id_total_alloc_map[alloc_entry.second.dev_id] +=
+        alloc_entry.second.actual_size;
+  }
+#if MXNET_USE_NVML
+  // If NVML is enabled, append one "nvml_amend" row per device: the part of
+  // the device memory NVML attributes to this process that is not covered by
+  // the allocation entries dumped above.
+  nvmlDevice_t nvml_device;
+
+  NVML_CALL(nvmlInit());
+  for (std::pair<const int, size_t>& dev_id_total_alloc_pair :
+       gpu_dev_id_total_alloc_map) {
+    unsigned info_count = 0;
+    std::vector<nvmlProcessInfo_t> infos(info_count);
+
+    // NOTE(review): the CUDA device ID is used as the NVML device index; the two
+    // orderings can differ (e.g. under CUDA_VISIBLE_DEVICES) -- confirm.
+    NVML_CALL(nvmlDeviceGetHandleByIndex(dev_id_total_alloc_pair.first, &nvml_device));
+    // The first call to `nvmlDeviceGetComputeRunningProcesses` only queries the
+    // number of entries and is expected to return `NVML_ERROR_INSUFFICIENT_SIZE`,
+    // so it is deliberately not wrapped with `NVML_CALL`.
+    nvmlDeviceGetComputeRunningProcesses(nvml_device, &info_count, infos.data());
+    infos = std::vector<nvmlProcessInfo_t>(info_count);
+    NVML_CALL(nvmlDeviceGetComputeRunningProcesses(nvml_device, &info_count, infos.data()));
+
+    bool amend_made = false;
+
+    for (unsigned i = 0; i < info_count; ++i) {
+      if (current_pid == infos[i].pid) {
+        amend_made = true;
+        // Both sizes are unsigned: clamp at zero instead of wrapping around
+        // when NVML reports less memory than the tracked allocations.
+        const uint64_t used = infos[i].usedGpuMemory;
+        const uint64_t tracked = dev_id_total_alloc_pair.second;
+        const uint64_t amend = used > tracked ? used - tracked : 0;
+        fout << "\"" << "nvml_amend" << "\","
+             << "\"" << amend << "\","
+             << "\"" << dev_id_total_alloc_pair.first << "\","
+             << "\"" << amend << "\","
+             << "\"0\"" << std::endl;
+        break;
+      }
+    }
+    if (!amend_made) {
+      LOG(INFO) << "NVML is unable to make amendment to the GPU memory profile "
+                   "since it is unable to locate the current process ID. "
+                   "Are you working in Docker without setting --pid=host?";
+    }
+  }  // for (dev_id_total_alloc_pair : gpu_dev_id_total_alloc_map)
+  // Each nvmlInit() must be matched by an nvmlShutdown(); NVML keeps a reference count.
+  NVML_CALL(nvmlShutdown());
+#endif  // MXNET_USE_NVML
+}
+
+#endif  // MXNET_USE_CUDA
+
+}  // namespace profiler
+}  // namespace mxnet
diff --git a/src/profiler/storage_profiler.h b/src/profiler/storage_profiler.h
index 5ab5983..ad87bf1 100644
--- a/src/profiler/storage_profiler.h
+++ b/src/profiler/storage_profiler.h
@@ -19,13 +19,16 @@
 #ifndef MXNET_PROFILER_STORAGE_PROFILER_H_
 #define MXNET_PROFILER_STORAGE_PROFILER_H_
 
+#include <mxnet/libinfo.h>  // presumably for build flags such as MXNET_USE_CUDA -- confirm
 #include <mxnet/storage.h>
 #include <string>
+#include <tuple>
 #include <vector>
+#include <unordered_map>  // for GpuDeviceStorageProfiler::gpu_mem_alloc_entries_
 #include "./profiler.h"
 
 namespace mxnet {
-namespace storage {
+namespace profiler {  // was namespace storage
 
 /*!
  * \brief Storage allocation/deallocation profiling via ProfileCounters
@@ -106,7 +109,90 @@ class DeviceStorageProfiler {
   std::vector<std::shared_ptr<profiler::ProfileCounter>> mem_counters_;
 };
 
-}  // namespace storage
+#if MXNET_USE_CUDA
+
+/*!
+ * \brief GPU storage allocation/deallocation profiling
+ */
+class GpuDeviceStorageProfiler {
+ public:
+  /*! \brief get the global instance to record an allocation entry */
+  static GpuDeviceStorageProfiler* Get();
+  /*!
+   * \brief Similar functions to the `DeviceStorageProfiler` methods above.
+   *        However, in the case of the `GpuDeviceStorageProfiler`, we are
+   *        recording an extra piece of information on the actual allocation
+   *        size and whether the allocation is a reuse or not.
+   */
+  void OnAlloc(const Storage::Handle &handle,
+               const size_t actual_size, const bool reuse) {
+    if (handle.size > 0) {
+      profiler::Profiler *prof = profiler::Profiler::Get();
+      if (prof->IsProfiling(profiler::Profiler::kMemory)) {
+        // The explicitly-typed AllocEntry{...} form is accepted by all
+        // supported compilers (including MSVC), so no #ifdef is needed.
+        gpu_mem_alloc_entries_[handle.dptr] = AllocEntry{
+            handle.profiler_scope,
+            handle.name,
+            handle.size,
+            handle.ctx.dev_id,
+            actual_size, reuse};
+      }
+    }
+  }
+
+  void OnFree(const Storage::Handle &handle) {
+    if (handle.size > 0) {
+      profiler::Profiler *prof = profiler::Profiler::Get();
+      if (prof->IsProfiling(profiler::Profiler::kMemory)) {
+        // erase() by key is a no-op for a pointer that was never recorded
+        // (e.g. allocated before memory profiling was switched on).
+        gpu_mem_alloc_entries_.erase(handle.dptr);
+      }
+    }
+  }
+
+  /*! \brief copy the handle's profiler scope and name into its recorded entry, if any */
+  void UpdateStorageInfo(const Storage::Handle &handle) {
+    if (handle.size > 0) {
+      profiler::Profiler *prof = profiler::Profiler::Get();
+      if (prof->IsProfiling(profiler::Profiler::kMemory)) {
+        auto entry_iter = gpu_mem_alloc_entries_.find(handle.dptr);
+        if (entry_iter != gpu_mem_alloc_entries_.end()) {
+          entry_iter->second.profiler_scope = handle.profiler_scope;
+          entry_iter->second.name = handle.name;
+        }
+      }
+    }
+  }
+
+  /*! \brief set the dumping filename */
+  void SetConfig(const std::string& filename_prefix) {
+    filename_prefix_ = filename_prefix;
+  }
+  /*! \brief dump the allocation entries to file */
+  void DumpProfile() const;
+
+ private:
+  /*! \brief prefix of the CSV file name written by DumpProfile() */
+  std::string filename_prefix_ = "gpu_memory_profile";
+  /*! \brief a single recorded GPU allocation */
+  struct AllocEntry {
+    std::string profiler_scope;  // profiler scope of the storage handle
+    std::string name;            // name of the storage handle
+    size_t requested_size;       // requested size of the storage handle
+    int dev_id;                  // device ID of the storage handle
+    size_t actual_size;          // actual allocation size
+    bool reuse;                  // whether the allocation is a reuse
+  };
+  /*! \brief entries recorded while memory profiling is on, keyed by device pointer */
+  // NOTE(review): accessed without a lock -- confirm callers serialize access.
+  std::unordered_map<void*, AllocEntry> gpu_mem_alloc_entries_;
+};
+
+#endif  // MXNET_USE_CUDA
+
+}  // namespace profiler
 }  // namespace mxnet
 
 #endif  // MXNET_PROFILER_STORAGE_PROFILER_H_
diff --git a/src/resource.cc b/src/resource.cc
index 3f46124..65af53d 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -29,12 +29,12 @@
 #include <mxnet/engine.h>
 #include <mxnet/random_generator.h>
 #include <mxnet/resource.h>
-#include <mxnet/storage.h>
 #include <limits>
 #include <atomic>
 #include "./common/lazy_alloc_array.h"
 #include "./common/utils.h"
 #include "./common/cuda_utils.h"
+#include "./profiler/storage_profiler.h"
 
 namespace mxnet {
 namespace resource {
@@ -65,11 +65,16 @@ struct SpaceAllocator {
     host_handle.size = 0;
   }
 
-  inline void* GetSpace(size_t size) {
+  inline void* GetSpace(size_t size, const std::string &name) {  // name tags new buffers only
     if (handle.size >= size) return handle.dptr;
 
     Storage::Get()->DirectFree(handle);
     handle = Storage::Get()->Alloc(size, ctx);
+    handle.profiler_scope = "resource:";  // NOTE(review): ignores ProfilerScope, unlike kvstore
+    handle.name = name;
+#if MXNET_USE_CUDA
+    profiler::GpuDeviceStorageProfiler::Get()->UpdateStorageInfo(handle);
+#endif  // MXNET_USE_CUDA
     return handle.dptr;
   }
 
@@ -410,8 +415,9 @@ class ResourceManagerImpl : public ResourceManager {
 };
 }  // namespace resource
 
-void* Resource::get_space_internal(size_t size) const {
-  return static_cast<resource::SpaceAllocator*>(ptr_)->GetSpace(size);
+void* Resource::get_space_internal(size_t size,
+    const std::string &name) const {
+  return static_cast<resource::SpaceAllocator*>(ptr_)->GetSpace(size, name);
 }
 
 void* Resource::get_host_space_internal(size_t size) const {
@@ -420,27 +426,25 @@ void* Resource::get_host_space_internal(size_t size) const {
 
 #if MXNET_USE_CUDNN == 1
 void Resource::get_cudnn_dropout_desc(
-    cudnnDropoutDescriptor_t* dropout_desc,
+    cudnnDropoutDescriptor_t *dropout_desc,
     mshadow::Stream<gpu> *stream,
-    const float dropout,
-    uint64_t seed) const {
+    const float dropout, uint64_t seed,
+    const std::string &name) const {  // tag for the dropout state buffer
 
   CHECK_EQ(req.type, ResourceRequest::kCuDNNDropoutDesc);
   auto state_space = static_cast<resource::SpaceAllocator*>(ptr_);
   CHECK_EQ(state_space->ctx.dev_id, stream->dev_id)
-    << "The device id of cudnn dropout state space doesn't match that from stream.";
+    << "The device id of cuDNN dropout state space doesn't match that from stream.";
   if (!state_space->handle.size) {
     // not initialized yet.
     size_t dropout_state_size;
     CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size));
     // reserve GPU space
-    Storage::Get()->DirectFree(
-      Storage::Get()->Alloc(dropout_state_size, state_space->ctx));
+    Storage::Get()->DirectFree(Storage::Get()->Alloc(dropout_state_size, state_space->ctx));
     CUDNN_CALL(cudnnSetDropoutDescriptor(*dropout_desc, stream->dnn_handle_,
                                          dropout,
-                                         state_space->GetSpace(dropout_state_size),
-                                         dropout_state_size,
-                                         seed));
+                                         state_space->GetSpace(dropout_state_size, name),
+                                         dropout_state_size, seed));
   } else {
     // cudnnRestoreDropoutDescriptor() introduced with cuDNN v7
     STATIC_ASSERT_CUDNN_VERSION_GE(7000);
diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h
index 5e09561..3eabe1b 100644
--- a/src/storage/gpu_device_storage.h
+++ b/src/storage/gpu_device_storage.h
@@ -28,6 +28,7 @@
 #include "mxnet/base.h"
 #include "mxnet/storage.h"
 #include "../common/cuda_utils.h"
+#include "../profiler/storage_profiler.h"
 #if MXNET_USE_CUDA
 #include <cuda_runtime.h>
 #endif  // MXNET_USE_CUDA
@@ -64,8 +65,13 @@ inline void GPUDeviceStorage::Alloc(Storage::Handle* handle) {
   std::lock_guard<std::mutex> l(Storage::Get()->GetMutex(Context::kGPU));
 #endif  // MXNET_USE_NCCL
   cudaError_t e = cudaMalloc(&handle->dptr, size);
-  if (e != cudaSuccess && e != cudaErrorCudartUnloading)
+  if (e != cudaSuccess && e != cudaErrorCudartUnloading) {
     LOG(FATAL) << "CUDA: " << cudaGetErrorString(e);
+  }
+  // record successful allocations only (not one cut short by cudaErrorCudartUnloading)
+  if (e == cudaSuccess) {
+    profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, false);
+  }
 #else   // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
 #endif  // MXNET_USE_CUDA
@@ -83,6 +89,8 @@ inline void GPUDeviceStorage::Free(Storage::Handle handle) {
   if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
     LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
   }
+  // record the deallocation event in the memory profiler
+  profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle);
 #else   // MXNET_USE_CUDA
   LOG(FATAL) << "Please compile with CUDA enabled";
 #endif  // MXNET_USE_CUDA
diff --git a/src/storage/pinned_memory_storage.h b/src/storage/pinned_memory_storage.h
index 13573d9..5d03fd1 100644
--- a/src/storage/pinned_memory_storage.h
+++ b/src/storage/pinned_memory_storage.h
@@ -30,6 +30,7 @@
 #include "mxnet/base.h"
 #include "mxnet/storage.h"
 #include "../common/cuda_utils.h"
+#include "../profiler/storage_profiler.h"
 
 namespace mxnet {
 namespace storage {
@@ -60,6 +61,8 @@ inline void PinnedMemoryStorage::Alloc(Storage::Handle* handle) {
   mxnet::common::cuda::DeviceStore device_store(handle->ctx.real_dev_id(), true);
   // make the memory available across all devices
   CUDA_CALL(cudaHostAlloc(&handle->dptr, size, cudaHostAllocPortable));
+  // record the allocation event in the memory profiler
+  profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, false);
 }
 
 inline void PinnedMemoryStorage::Free(Storage::Handle handle) {
@@ -72,6 +75,8 @@ inline void PinnedMemoryStorage::Free(Storage::Handle handle) {
   if (err != cudaSuccess && err != cudaErrorCudartUnloading) {
     LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
   }
+  // record the deallocation event in the memory profiler
+  profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle);
 }
 
 }  // namespace storage
diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h
index d9d7277..80358fc 100644
--- a/src/storage/pooled_storage_manager.h
+++ b/src/storage/pooled_storage_manager.h
@@ -39,6 +39,7 @@
 #include "./storage_manager.h"
 #include "../common/cuda_utils.h"
 #include "../common/utils.h"
+#include "../profiler/storage_profiler.h"
 
 
 namespace mxnet {
@@ -97,6 +98,7 @@ class GPUPooledStorageManager final : public StorageManager {
       LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
     }
     used_memory_ -= size;
+    profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle);
   }
 
   // Round a value 'x' up to the next multiple of 'multiple'
@@ -166,11 +168,15 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) {
     }
     used_memory_ += size;
     handle->dptr = ret;
+    // record the allocation event in the memory profiler
+    profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, false);
   } else {
     auto&& reuse_pool = reuse_it->second;
     auto ret = reuse_pool.back();
     reuse_pool.pop_back();
     handle->dptr = ret;
+    // record the allocation event in the memory profiler
+    profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, true);
   }
 }
 
@@ -292,6 +298,7 @@ class GPUPooledRoundedStorageManager final : public StorageManager {
       LOG(FATAL) << "CUDA: " << cudaGetErrorString(err);
     }
     used_memory_ -= size;
+    profiler::GpuDeviceStorageProfiler::Get()->OnFree(handle);
   }
 
  private:
@@ -349,10 +356,14 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) {
     }
     used_memory_ += size;
     handle->dptr = ret;
+    // record the allocation event in the memory profiler
+    profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, false);
   } else {
     auto ret = reuse_pool.back();
     reuse_pool.pop_back();
     handle->dptr = ret;
+    // record the allocation event in the memory profiler
+    profiler::GpuDeviceStorageProfiler::Get()->OnAlloc(*handle, size, true);
   }
 }
 
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index 7a59a77..c0903d2 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -53,7 +53,7 @@ class StorageImpl : public Storage {
   // internal storage managers
   std::array<common::LazyAllocArray<storage::StorageManager>,
              kMaxNumberOfDevices> storage_managers_;
-  storage::DeviceStorageProfiler profiler_;
+  profiler::DeviceStorageProfiler profiler_;
 };  // struct Storage::Impl
 #if MXNET_USE_CUDA
 int StorageImpl::num_gpu_device = 0;
diff --git a/tests/python/unittest/test_profiler.py b/tests/python/unittest/test_profiler.py
index 8e8209f..5a0baca 100644
--- a/tests/python/unittest/test_profiler.py
+++ b/tests/python/unittest/test_profiler.py
@@ -18,12 +18,15 @@
 from __future__ import print_function
 import time
 import os
+import csv
 import json
 import unittest
+import numpy as np
 from collections import OrderedDict
 
 import mxnet as mx
 from mxnet import profiler
+from mxnet.gluon import nn
 from common import run_in_spawned_process
 
 
@@ -34,8 +37,7 @@ def enable_profiler(profile_filename, run=True, continuous_dump=False, aggregate
                         profile_api=True,
                         filename=profile_filename,
                         continuous_dump=continuous_dump,
-                        aggregate_stats=aggregate_stats
-                        )
+                        aggregate_stats=aggregate_stats)
     if run is True:
         profiler.set_state('run')
 
@@ -459,6 +461,144 @@ def test_custom_operator_profiling_naive_engine():
             'test_custom_operator_profiling_multiple_custom_ops_symbolic_naive.json')
 
 
+def _load_gpu_memory_profile():
+    """Load this process's GPU memory profile, keyed by attribute name.
+
+    Reads the ``gpu_memory_profile-pid_<pid>.csv`` file produced by
+    ``profiler.dump`` once and maps each 'Attribute Name' to the first row
+    carrying it (attribute names are not guaranteed to be unique).
+    """
+    profile = {}
+    with open('gpu_memory_profile-pid_%d.csv' % (os.getpid()), mode='r') as csv_file:
+        for row in csv.DictReader(csv_file):
+            profile.setdefault(row['Attribute Name'], row)
+    return profile
+
+
+def _check_requested_size(profile, attr_name, expected_size):
+    """Assert that ``attr_name`` is in ``profile`` with ``expected_size`` bytes requested."""
+    assert attr_name in profile, \
+        "Entry for attr_name={} has not been found".format(attr_name)
+    requested_size = profile[attr_name]['Requested Size']
+    assert requested_size == expected_size, \
+        "requested size={} is not equal to the expected size={}" \
+        .format(requested_size, expected_size)
+
+
+@unittest.skipIf(mx.context.num_gpus() == 0, "GPU memory profiler records allocation on GPUs only")
+def test_gpu_memory_profiler_symbolic():
+    iter_num = 5
+
+    enable_profiler(profile_filename='test_profiler.json',
+                    run=True, continuous_dump=False)
+
+    with profiler.Scope("tensordot"):
+        A = mx.sym.Variable('A')
+        B = mx.sym.Variable('B')
+        C = mx.symbol.dot(A, B, name='dot')
+
+    executor = C.simple_bind(mx.gpu(), 'write', A=(4096, 4096), B=(4096, 4096))
+
+    a = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
+    b = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096))
+
+    a.copyto(executor.arg_dict['A'])
+    b.copyto(executor.arg_dict['B'])
+
+    for _ in range(iter_num):
+        executor.forward()
+        c = executor.outputs[0]
+        mx.nd.waitall()
+    profiler.set_state('stop')
+    profiler.dump(True)
+
+    # (attribute name, requested size in bytes); sizes assume 4-byte elements.
+    expected_alloc_entries = [
+        ('tensordot:in_arg:A', str(4 * a.size)),
+        ('tensordot:in_arg:B', str(4 * b.size)),
+        ('tensordot:arg_grad:A', str(4 * a.size)),
+        ('tensordot:arg_grad:B', str(4 * b.size)),
+        ('tensordot:dot', str(4 * c.size)),
+        ('tensordot:dot_head_grad', str(4 * c.size))]
+
+    # Sample gpu_memory_profile-pid_<pid>.csv:
+    # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
+    # "tensordot:arg_grad:A","67108864","0","67108864","0"
+    # "tensordot:arg_grad:B","67108864","0","67108864","0"
+    # "tensordot:dot","67108864","0","67108864","0"
+    # "tensordot:dot_head_grad","67108864","0","67108864","0"
+    # "tensordot:in_arg:A","67108864","0","67108864","0"
+    # "tensordot:in_arg:B","67108864","0","67108864","0"
+
+    profile = _load_gpu_memory_profile()
+    for attr_name, expected_size in expected_alloc_entries:
+        _check_requested_size(profile, attr_name, expected_size)
+
+
+@unittest.skipIf(mx.context.num_gpus() == 0, "GPU memory profiler records allocation on GPUs only")
+def test_gpu_memory_profiler_gluon():
+    # run=True already switches the profiler to the 'run' state.
+    enable_profiler(profile_filename='test_profiler.json',
+                    run=True, continuous_dump=True)
+
+    model = nn.HybridSequential(prefix='net_')
+    with model.name_scope():
+        model.add(nn.Dense(128, activation='tanh'))
+        model.add(nn.Dropout(0.5))
+        model.add(nn.Dense(64, activation='tanh'),
+                  nn.Dense(32, in_units=64))
+        model.add(nn.Activation('relu'))
+    model.initialize(ctx=mx.gpu())
+    model.hybridize()
+
+    with mx.autograd.record():
+        out = model(mx.nd.zeros((16, 10), ctx=mx.gpu()))
+    out.backward()
+    mx.nd.waitall()
+    profiler.set_state('stop')
+    profiler.dump(True)
+
+    # Sample gpu_memory_profile-pid_<pid>.csv:
+    # "Attribute Name","Requested Size","Device","Actual Size","Reuse?"
+    # "<unk>:in_arg:data","640","0","4096","0"
+    # "net:arg_grad:net_dense0_bias","512","0","4096","0"
+    # "net:arg_grad:net_dense0_weight","5120","0","8192","0"
+    # "net:arg_grad:net_dense1_bias","256","0","4096","0"
+    # "net:arg_grad:net_dense1_weight","32768","0","32768","0"
+    # "net:arg_grad:net_dense2_bias","128","0","4096","0"
+    # "net:arg_grad:net_dense2_weight","8192","0","8192","0"
+    # "net:dense0:net_dense0_fwd","8192","0","8192","0"
+    # "net:dense0:tanh:net_dense0_tanh_fwd","8192","0","8192","0"
+    # "net:dense1:net_dense1_fwd","4096","0","4096","0"
+    # "net:dense1:tanh:net_dense1_tanh_fwd","4096","0","4096","0"
+    # "net:dense2:net_dense2_fwd","2048","0","4096","0"
+    # "net:dense2:net_dense2_fwd_backward","4096","0","4096","0"
+    # "net:dropout0:net_dropout0_fwd","8192","0","8192","0"
+    # "net:dropout0:net_dropout0_fwd","8192","0","8192","0"
+    # "net:in_arg:net_dense0_bias","512","0","4096","0"
+    # "net:in_arg:net_dense0_weight","5120","0","8192","0"
+    # "net:in_arg:net_dense1_bias","256","0","4096","0"
+    # "net:in_arg:net_dense1_weight","32768","0","32768","0"
+    # "net:in_arg:net_dense2_bias","128","0","4096","0"
+    # "net:in_arg:net_dense2_weight","8192","0","8192","0"
+    # "net:relu0:net_relu0_fwd","2048","0","4096","0"
+    # "net:relu0:net_relu0_fwd_backward","8192","0","8192","0"
+    # "net:relu0:net_relu0_fwd_head_grad","2048","0","4096","0"
+    # "resource:cudnn_dropout_state (dropout-inl.h +258)","1671168","0","1671168","0"
+    # "resource:temp_space (fully_connected-inl.h +316)","34816","0","36864","0"
+
+    # Only the parameter allocations (in_arg/arg_grad) are checked here; we
+    # also make sure that there are no unknown entries in the memory profile.
+    profile = _load_gpu_memory_profile()
+    for scope in ['in_arg', 'arg_grad']:
+        for key, param in model.collect_params().items():
+            _check_requested_size(profile, "net:%s:" % scope + key,
+                                  str(4 * np.prod(param.shape)))
+    # Make sure that there is no unknown allocation entry.
+    assert "<unk>:unknown" not in profile and "<unk>:" not in profile, \
+        "Unknown allocation entry has been encountered"
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 8e4fe11..c913f5c 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -296,7 +296,8 @@ def test_load_000800():
         assert k in attr2, k
         v2 = attr2[k]
         for kk, vv1 in v1.items():
-            if kk.startswith('__') and kk.endswith('__'):
+            if kk.startswith('__') and kk.endswith('__') and \
+               kk != '__profiler_scope__':
                 assert kk in v2 and v2[kk] == vv1, k + str(v1) + str(v2)
 
     check_symbol_consistency(sym1, sym2,
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index f0e3c66..63d97f1 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -120,6 +120,7 @@ def test_blockscope():
         def __init__(self, prefix):
             self.prefix = prefix
             self._empty_prefix = False
+            self._profiler_scope_name = '<unk>:'
     blockscope_list = []
     status = [False]
     event = threading.Event()