You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/29 17:22:12 UTC

[incubator-mxnet] branch master updated: [MXNET-374] handle row_sparse weight in parameter and trainer (#11001)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 482e50b  [MXNET-374] handle row_sparse weight in parameter and trainer (#11001)
482e50b is described below

commit 482e50bbbc429409a792ac4664127f34a226cea3
Author: Haibin Lin <li...@gmail.com>
AuthorDate: Tue May 29 10:21:39 2018 -0700

    [MXNET-374] handle row_sparse weight in parameter and trainer (#11001)
    
    * add row_sparse (rsp) Parameter support
    
    * draft
    
    * Fix optimizer pickle
    
    * refactor and document
    
    * add test for save load with cast_stype
    
    * refactor trainer tests
    
    * add test
    
    * add back test
    
    * raise error for load params
    
    * add comment
    
    * remove print
    
    * fix doc
    
    * address code review (CR) comments
    
    * address further code review (CR) comments
    
    * change error
    
    * remove cast stype
    
    * fix test
    
    * add reset kvstore to trainer
    
    * lint
    
    * add test to CI
    
    * add more checks
---
 ci/docker/runtime_functions.sh              |   1 +
 python/mxnet/gluon/block.py                 |   9 ++
 python/mxnet/gluon/parameter.py             | 123 ++++++++++++++--
 python/mxnet/gluon/trainer.py               | 118 +++++++++++----
 python/mxnet/model.py                       |  19 +++
 src/operator/tensor/indexing_op.h           |   6 +
 tests/nightly/dist_sync_kvstore.py          |  27 +++-
 tests/python/unittest/test_gluon.py         | 220 ++++++++++++++++------------
 tests/python/unittest/test_gluon_trainer.py | 200 +++++++++++++++++++++++++
 9 files changed, 585 insertions(+), 138 deletions(-)

diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 7abe767..10bca17 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -544,6 +544,7 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py
     ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
     ../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
+    ../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=gluon
 }
 
 test_ubuntu_cpu_python2() {
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index dbe3c5e..4b37f43 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -606,6 +606,7 @@ class HybridBlock(Block):
 
     Refer `Hybrid tutorial <http://mxnet.io/tutorials/gluon/hybrid.html>`_ to see
     the end-to-end usage.
+
     """
     def __init__(self, prefix=None, params=None):
         super(HybridBlock, self).__init__(prefix=prefix, params=params)
@@ -879,6 +880,14 @@ class SymbolBlock(HybridBlock):
                 "Input symbols must be variable, but %s is an output of operators"%str(i)
             input_names.add(i.name)
 
+        # check if any symbol is row_sparse
+        row_sparse_storage = ndarray.ndarray._STORAGE_TYPE_STR_TO_ID['row_sparse']
+        for i in out:
+            for j in i.get_internals():
+                assert(j.attr("__storage_type__") != str(row_sparse_storage)), \
+                    "SymbolBlock doesn't support Parameter '%s' because its storage " \
+                    "type is 'row_sparse'." % j.name
+
         for i in out.list_arguments():
             if i not in input_names:
                 self.params.get(i, allow_deferred_init=True)
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c7cbccc..3265fef 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -81,6 +81,8 @@ class Parameter(object):
         Weight decay multiplier (L2 regularizer coefficient). Works similar to lr_mult.
     init : Initializer, default None
         Initializer of this parameter. Will use the global initializer by default.
+    stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
+        The storage type of the parameter.
     grad_stype: {'default', 'row_sparse', 'csr'}, defaults to 'default'.
         The storage type of the parameter's gradient.
 
@@ -99,12 +101,13 @@ class Parameter(object):
     """
     def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
                  lr_mult=1.0, wd_mult=1.0, init=None, allow_deferred_init=False,
-                 differentiable=True, grad_stype='default'):
+                 differentiable=True, stype='default', grad_stype='default'):
         self._var = None
         self._data = None
         self._grad = None
         self._ctx_list = None
         self._ctx_map = None
+        self._trainer = None
         self._deferred_init = ()
         self._differentiable = differentiable
         self._allow_deferred_init = allow_deferred_init
@@ -116,10 +119,14 @@ class Parameter(object):
         self.wd_mult = wd_mult
         self.grad_req = grad_req
         self.init = init
-        assert grad_stype in ['default', 'row_sparse', 'csr'], \
-            "grad_stype for Parameter '%s' must be one of 'default', 'row_sparse', or 'csr'," \
-            " but got '%s'" % (name, grad_stype)
+        # sparse related storage type information
+        valid_stypes = ['default', 'row_sparse', 'csr']
+        assert grad_stype in valid_stypes, "grad_stype for Parameter '%s' must be " \
+            "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, grad_stype)
+        assert stype in valid_stypes, "stype for Parameter '%s' must be " \
+            "one of 'default', 'row_sparse', or 'csr', but got '%s'" % (name, stype)
         self._grad_stype = grad_stype
+        self._stype = stype
 
 
     def __repr__(self):
@@ -162,6 +169,16 @@ class Parameter(object):
 
         self._shape = new_shape
 
+    def _set_trainer(self, trainer):
+        """ Set the trainer this parameter is associated with. """
+        # trainer cannot be replaced for sparse params
+        if self._stype != 'default' and self._trainer and trainer and self._trainer is not trainer:
+            raise RuntimeError(
+                "Failed to set the trainer for Parameter '%s' because it was already set. " \
+                "More than one trainers for a %s Parameter is not supported." \
+                %(self.name, self._stype))
+        self._trainer = trainer
+
     def _check_and_get(self, arr_list, ctx):
         if arr_list is not None:
             if ctx is list:
@@ -194,6 +211,20 @@ class Parameter(object):
             "because the later does not include Parameters of " \
             "nested child Blocks"%(self.name))
 
+    def _get_row_sparse(self, arr_list, ctx, row_id):
+        """ Get row_sparse data from row_sparse parameters based on row_id. """
+        # get row sparse params based on row ids
+        if not isinstance(row_id, ndarray.NDArray):
+            raise TypeError("row_id must have NDArray type, but %s is given"%(type(row_id)))
+        if not self._trainer:
+            raise RuntimeError("Cannot get row_sparse data for Parameter '%s' when no " \
+                               "Trainer is created with it."%self.name)
+        results = self._check_and_get(arr_list, ctx)
+
+        # fetch row sparse params from the trainer
+        self._trainer._row_sparse_pull(self, results, row_id)
+        return results
+
     def _load_init(self, data, ctx):
         """(Re)initializes by loading from data."""
         if self.shape:
@@ -208,6 +239,8 @@ class Parameter(object):
                 "Failed loading Parameter '%s' from saved params: " \
                 "dtype incompatible expected %s vs saved %s"%(
                     self.name, str(self.dtype), str(data.dtype))
+        if self._stype != data.stype:
+            data = data.tostype(self._stype)
         if isinstance(ctx, Context):
             ctx = [ctx]
         if self._data is None:
@@ -243,7 +276,7 @@ class Parameter(object):
         with autograd.pause():
             if data is None:
                 data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
-                                     ctx=context.cpu())
+                                     ctx=context.cpu(), stype=self._stype)
                 initializer.create(default_init)(
                     initializer.InitDesc(self.name, {'__init__': init}), data)
 
@@ -271,12 +304,18 @@ class Parameter(object):
         self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context,
                                     stype=self._grad_stype) for i in self._data]
 
-        autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
+        autograd.mark_variables(self._check_and_get(self._data, list),
+                                self._grad, self.grad_req)
 
     def _reduce(self):
         """Reduce data from multiple context."""
-        block = self.list_data()
-        data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        if self._stype == 'default':
+            block = self.list_data()
+            data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+        else:
+            # fetch all rows for 'row_sparse' param
+            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
+            data = self.row_sparse_data(all_row_ids)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
@@ -380,12 +419,58 @@ class Parameter(object):
             self._deferred_init = self._deferred_init[:3] + (data,)
             return
 
-        for arr in self.list_data():
+        # if update_on_kvstore, we need to make sure the copy stored in kvstore is in sync
+        if self._trainer and self._trainer._kv_initialized and self._trainer._update_on_kvstore:
+            if self not in self._trainer._params_to_init:
+                self._trainer._reset_kvstore()
+
+        for arr in self._check_and_get(self._data, list):
             arr[:] = data
 
+    def row_sparse_data(self, row_id):
+        """Returns a copy of the 'row_sparse' parameter on the same context as row_id's.
+        The copy only retains rows whose ids occur in provided row ids.
+        The parameter must have been initialized on this context before.
+
+        Parameters
+        ----------
+        row_id: NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        NDArray on row_id's context
+        """
+        if self._stype != 'row_sparse':
+            raise RuntimeError("Cannot return a copy of Parameter %s via row_sparse_data() " \
+                               "because its storage type is %s. Please use data() instead." \
+                               %(self.name, self._stype))
+        return self._get_row_sparse(self._data, row_id.context, row_id)
+
+    def list_row_sparse_data(self, row_id):
+        """Returns copies of the 'row_sparse' parameter on all contexts, in the same order
+        as creation. The copy only retains rows whose ids occur in provided row ids.
+        The parameter must have been initialized before.
+
+        Parameters
+        ----------
+        row_id: NDArray
+            Row ids to retain for the 'row_sparse' parameter.
+
+        Returns
+        -------
+        list of NDArrays
+        """
+        if self._stype != 'row_sparse':
+            raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \
+                               "list_row_sparse_data() because its storage type is %s. Please " \
+                               "use data() instead." % (self.name, self._stype))
+        return self._get_row_sparse(self._data, list, row_id)
+
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
-        initialized on this context before.
+        initialized on this context before. For sparse parameters, use
+        :py:meth:`Parameter.row_sparse_data` instead.
 
         Parameters
         ----------
@@ -396,11 +481,25 @@ class Parameter(object):
         -------
         NDArray on ctx
         """
+        if self._stype != 'default':
+            raise RuntimeError("Cannot return a copy of Parameter '%s' on ctx %s via data() " \
+                               "because its storage type is %s. Please use row_sparse_data() " \
+                               "instead." % (self.name, str(ctx), self._stype))
         return self._check_and_get(self._data, ctx)
 
     def list_data(self):
         """Returns copies of this parameter on all contexts, in the same order
-        as creation."""
+        as creation. For sparse parameters, use :py:meth:`Parameter.list_row_sparse_data`
+        instead.
+
+        Returns
+        -------
+        list of NDArrays
+        """
+        if self._stype != 'default':
+            raise RuntimeError("Cannot return copies of Parameter '%s' on all contexts via " \
+                               "list_data() because its storage type is %s. Please use " \
+                               "row_sparse_data() instead." % (self.name, self._stype))
         return self._check_and_get(self._data, list)
 
     def grad(self, ctx=None):
@@ -447,7 +546,7 @@ class Parameter(object):
         if self._var is None:
             self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype,
                                    lr_mult=self.lr_mult, wd_mult=self.wd_mult,
-                                   init=self.init)
+                                   init=self.init, stype=self._stype)
         return self._var
 
     def cast(self, dtype):
diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index f285b91..ef20109 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -21,7 +21,7 @@
 __all__ = ['Trainer']
 
 from .. import optimizer as opt
-from ..model import _create_kvstore
+from ..model import _create_kvstore, _create_sparse_kvstore
 from .parameter import ParameterDict, Parameter
 
 class Trainer(object):
@@ -68,20 +68,30 @@ class Trainer(object):
                 "First argument must be a list or dict of Parameters, " \
                 "got %s."%(type(params)))
         self._params = []
-        for param in params:
+        # parameters to initialize on the kvstore
+        self._contains_sparse = False
+        self._param2idx = {}
+        for i, param in enumerate(params):
             if not isinstance(param, Parameter):
                 raise ValueError(
                     "First argument must be a list or dict of Parameters, " \
                     "got list of %s."%(type(param)))
+            self._param2idx[param.name] = i
             self._params.append(param)
+            param._set_trainer(self)
+            if param._stype != 'default':
+                self._contains_sparse = True
         self._compression_params = compression_params
         optimizer_params = optimizer_params if optimizer_params else {}
         self._scale = float(optimizer_params.get('rescale_grad', 1.0))
         self._contexts = self._check_contexts()
         self._init_optimizer(optimizer, optimizer_params)
+        self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore}
         self._kv_initialized = False
-        self._kvstore = kvstore
-        self._update_on_kvstore = update_on_kvstore
+        self._kvstore = None
+        self._update_on_kvstore = None
+        self._params_to_init = []
+        self._reset_kvstore()
 
     def _check_contexts(self):
         contexts = None
@@ -109,38 +119,62 @@ class Trainer(object):
         self._updaters = [opt.get_updater(self._optimizer) \
                             for _ in self._contexts]
 
+    def _init_params(self):
+        """Initialize parameters in the KVStore.
+
+        Parameters with incomplete initialization are ignored.
+
+        """
+        assert self._kv_initialized, "Cannot initialize parameters in KVStore " \
+                                     "when KVStore is not initialized."
+        params_to_init = []
+        if self._kvstore:
+            for param in self._params_to_init:
+                if param._deferred_init:
+                    params_to_init.append(param)
+                else:
+                    param_arrays = param._check_and_get(param._data, list)
+                    idx = self._param2idx[param.name]
+                    self._kvstore.init(idx, param_arrays[0])
+                    if param._stype == 'default':
+                        self._kvstore.pull(idx, param_arrays, priority=-idx)
+
+        self._params_to_init = params_to_init
+
+    def _reset_kvstore(self):
+        """Reset kvstore."""
+        if self._kvstore and 'dist' in self._kvstore.type:
+            raise RuntimeError("Cannot reset distributed KVStore.")
+        self._kv_initialized = False
+        self._kvstore = None
+        self._update_on_kvstore = None
+        self._params_to_init = [param for param in self._params]
+
     def _init_kvstore(self):
+        """Create kvstore."""
         arg_arrays = {}
-        contains_sparse = False
-        for param in self._params:
-            arg_arrays[param.name] = param.data(self._contexts[0])
-            if param._grad_stype != 'default':
-                contains_sparse = True
-                # update_on_kvstore is set to False by the user
-                if self._update_on_kvstore is False:
-                    raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
-                                       "gradients and/or sparse weights are present for "
-                                       "Parameter %s." % param.name)
-        kvstore, update_on_kvstore = _create_kvstore(self._kvstore, len(self._contexts),
-                                                     arg_arrays)
-        update_on_kvstore = self._update_on_kvstore if self._update_on_kvstore is not None \
-                            else update_on_kvstore
+        config = self._kvstore_params
+        if self._contains_sparse:
+            kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
+            # update_on_kvstore is set to False by the user
+            if config['update_on_kvstore'] is False:
+                raise RuntimeError("Cannot set update_on_kvstore to False when sparse "
+                                   "gradients and/or sparse weights are present for "
+                                   "Parameter '%s'."%param.name)
+        else:
+            kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
+                                                         arg_arrays)
+            if config['update_on_kvstore'] is not None:
+                update_on_kvstore = config['update_on_kvstore']
         if kvstore:
             if self._compression_params:
                 kvstore.set_gradient_compression(self._compression_params)
             # kv.pull(row_sparse_grad) is not supported
-            if contains_sparse:
-                update_on_kvstore = True
-            else:
-                if 'dist' in kvstore.type:
-                    update_on_kvstore = False
+            if 'dist' in kvstore.type and not self._contains_sparse:
+                update_on_kvstore = False
             if update_on_kvstore:
+                # optimizer preferably needs to be set before init for multiprecision
                 kvstore.set_optimizer(self._optimizer)
-            # optimizer preferably needs to be set before init for multiprecision
-            for i, param in enumerate(self._params):
-                param_arrays = param.list_data()
-                kvstore.init(i, param_arrays[0])
-                kvstore.pull(i, param_arrays, priority=-i)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
         else:
@@ -171,6 +205,15 @@ class Trainer(object):
         else:
             self._optimizer.set_learning_rate(lr)
 
+    def _row_sparse_pull(self, parameter, out, row_id):
+        # initialize kv and params if not already
+        if not self._kv_initialized:
+            self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
+        self._kvstore.row_sparse_pull(self._param2idx[parameter.name], \
+                                      out=out, row_ids=row_id)
+
     def step(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update. Should be called after
         `autograd.backward()` and outside of `record()` scope.
@@ -191,6 +234,8 @@ class Trainer(object):
         """
         if not self._kv_initialized:
             self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
 
         self._optimizer.rescale_grad = self._scale / batch_size
 
@@ -210,6 +255,8 @@ class Trainer(object):
         """
         if not self._kv_initialized:
             self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
         assert not (self._kvstore and self._update_on_kvstore), \
                 'allreduce_grads() when parameters are updated on kvstore ' \
                 'is not supported. Try setting `update_on_kvstore` ' \
@@ -250,6 +297,8 @@ class Trainer(object):
         """
         if not self._kv_initialized:
             self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
         assert not (self._kvstore and self._update_on_kvstore), \
                 'update() when parameters are updated on kvstore ' \
                 'is not supported. Try setting `update_on_kvstore` ' \
@@ -264,7 +313,7 @@ class Trainer(object):
                 continue
 
             if not ignore_stale_grad:
-                for data in param.list_data():
+                for data in param._check_and_get(param._data, list):
                     if not data._fresh_grad:
                         raise UserWarning(
                             "Gradient of Parameter `%s` on context %s has not been updated "
@@ -276,7 +325,10 @@ class Trainer(object):
                             %(param.name, str(data.context)))
 
             if self._kvstore and self._update_on_kvstore:
-                self._kvstore.pull(i, param.list_data(), priority=-i)
+                if param._stype == 'default':
+                    # 'row_sparse' parameters are not pulled immediately - they're pulled
+                    # in `SparseBlock.sparse_forward`
+                    self._kvstore.pull(i, param.list_data(), priority=-i)
                 continue
 
             for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()):
@@ -296,8 +348,12 @@ class Trainer(object):
 
         if not self._kv_initialized:
             self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
 
         if self._update_on_kvstore:
+            assert not self._params_to_init, "Cannot save trainer states when some " \
+                                             "parameters are not yet initialized in kvstore."
             self._kvstore.save_optimizer_states(fname, dump_optimizer=True)
         else:
             with open(fname, 'wb') as fout:
@@ -313,6 +369,8 @@ class Trainer(object):
         """
         if not self._kv_initialized:
             self._init_kvstore()
+        if self._params_to_init:
+            self._init_params()
 
         if self._update_on_kvstore:
             self._kvstore.load_optimizer_states(fname)
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index ae7726d..3a50553 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -55,6 +55,25 @@ BatchEndParam = namedtuple('BatchEndParams',
                             'eval_metric',
                             'locals'])
 
+def _create_sparse_kvstore(kvstore):
+    """Create kvstore assuming some parameters' storage types are row_sparse.
+
+    Parameters
+    ----------
+    kvstore : KVStore or str
+        The kvstore.
+    """
+    # always update on kvstore
+    update_on_kvstore = True
+    if isinstance(kvstore, kvs.KVStore):
+        kv = kvstore
+    elif isinstance(kvstore, str):
+        kv = kvs.create(kvstore)
+    else:
+        raise TypeError("Cannot create '%s' KVStore with row_sparse parameters. "
+                        "The type must be KVStore or str." % kvstore)
+    return (kv, update_on_kvstore)
+
 def _create_kvstore(kvstore, num_device, arg_params):
     """Create kvstore
     This function select and create a proper kvstore if given the kvstore type.
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 28827db..23a866d 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -270,6 +270,12 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = dispatch_mode_assign(dispatch_mode, target_mode);
     }
   }
+  // Print user friendly error message to notify misuses of sparse_grad
+  if (weight_grad_stype != target_stype) {
+    LOG(FATAL) << "Cannot use sparse_grad = " << sparse_grad
+               << ", while stype of gradients w.r.t embedding weight is "
+               << common::stype_string(weight_grad_stype);
+  }
   return dispatched;
 }
 
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 3bf5cbf..32ed2dd 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -24,7 +24,7 @@ import argparse
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, assert_exception
 from test_kvstore import compute_expected_2bit_quantization
 
 def check_diff(A, x, rank=None):
@@ -350,6 +350,20 @@ def test_sync_init(gpu_tests=False):
         check_init(kv, init_test_keys_device_big, big_shape, device=True)
     print('worker ' + str(kv.rank) + ' is initialized')
 
+def test_gluon_trainer_reset():
+    params = mx.gluon.ParameterDict()
+    x = params.get('x', shape=(4, 2), lr_mult=1.0, stype='row_sparse')
+    params.initialize(ctx=mx.cpu(0), init='zeros')
+    trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+    params.save('test_gluon_trainer_reset_' + str(my_rank) + '.params')
+    row_id = mx.nd.arange(0, 4)
+    w = x.row_sparse_data(row_id)
+    assert trainer._kv_initialized and trainer._update_on_kvstore
+    # load would fail to reset kvstore since update_on_kvstore is True
+    assert_exception(params.load, RuntimeError, 'test_gluon_trainer_reset_' + str(my_rank) + '.params')
+    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_reset')
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='test distributed kvstore in dist_sync mode')
     parser.add_argument('--nrepeat', type=int, default=7)
@@ -357,13 +371,16 @@ if __name__ == "__main__":
     parser.add_argument('--no-gpu', dest='gpu', action='store_false')
     parser.add_argument('--no-multiprecision', dest='multiprecision', action='store_false')
     opt = parser.parse_args()
-    if opt.type == 'all' or  opt.type == 'init':
+    if opt.type == 'gluon':
+        test_gluon_trainer_reset()
+    if opt.type == 'all' or opt.type == 'init':
         test_sync_init(opt.gpu)
-    kv = init_kv()
-    if opt.type == 'all' or  opt.type == 'default':
+    if opt.type == 'all' or opt.type == 'default':
+        kv = init_kv()
         kv = set_optimizer(use_multiprecision=opt.multiprecision)
         test_sync_push_pull(opt.nrepeat)
     # dont run non compressed tests after this as kvstore compression will be set here
-    if opt.type == 'all' or  opt.type == 'compressed':
+    if opt.type == 'all' or opt.type == 'compressed':
+        kv = init_kv()
         kv, threshold = init_kv_compressed(kv)
         test_sync_2bit_compression(threshold, opt.nrepeat)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index b1b5fe2..2384812 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -19,7 +19,8 @@ import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
 from mxnet.test_utils import assert_almost_equal
-from common import setup_module, with_seed
+from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
+from common import setup_module, with_seed, assertRaises
 import numpy as np
 from nose.tools import raises, assert_raises
 from copy import deepcopy
@@ -27,8 +28,6 @@ import warnings
 import json
 import unittest
 
-
-
 @with_seed()
 def test_parameter():
     p = gluon.Parameter('weight', shape=(10, 10))
@@ -39,33 +38,122 @@ def test_parameter():
     assert p.data(mx.cpu(0)).shape == (10, 10)
     assert p.var().name == 'weight'
     assert p.grad(mx.cpu(0)).stype == 'default'
+    assert p.data(mx.cpu(0)).stype == 'default'
 
     p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
     assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
 
 @with_seed()
+@raises(AssertionError)
+def test_invalid_parameter_stype():
+    p = gluon.Parameter('weight', shape=(10, 10), stype='invalid')
+
+@with_seed()
+@raises(AssertionError)
+def test_invalid_parameter_grad_stype():
+    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='invalid')
+
+@with_seed()
 def test_sparse_parameter():
-    p = gluon.Parameter('weight', shape=(10, 10), grad_stype='row_sparse')
+    p = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
     p.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
-    assert len(p.list_data()) == 2
+    row_id = mx.nd.arange(0, 10, ctx=mx.cpu(1))
     assert len(p.list_grad()) == 2
-    assert p.data(mx.cpu(1)).context == mx.cpu(1)
-    assert p.data(mx.cpu(0)).shape == (10, 10)
+    # getting row_sparse data without trainer throws an exception
+    assertRaises(RuntimeError, p.list_row_sparse_data, row_id)
+    trainer = mx.gluon.Trainer([p], 'sgd')
+    assert len(p.list_row_sparse_data(row_id)) == 2
+    weight = p.row_sparse_data(row_id)
+    assert weight.context == mx.cpu(1)
+    assert weight.shape == (10, 10)
+    assert weight.stype == 'row_sparse'
     assert p.var().name == 'weight'
+    assert p.var().attr('__storage_type__') == str(_STORAGE_TYPE_STR_TO_ID['row_sparse'])
     assert p.grad(mx.cpu(0)).stype == 'row_sparse'
 
     p.reset_ctx(ctx=[mx.cpu(1), mx.cpu(2)])
     assert p.list_ctx() == [mx.cpu(1), mx.cpu(2)]
 
+@with_seed()
+def test_parameter_invalid_access():
+    # cannot call data on row_sparse parameters
+    p0 = gluon.Parameter('weight', shape=(10, 10), stype='row_sparse', grad_stype='row_sparse')
+    p0.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+    assertRaises(RuntimeError, p0.data)
+    assertRaises(RuntimeError, p0.list_data)
+    row_id = mx.nd.arange(0, 10)
+    # cannot call row_sparse_data on dense parameters
+    p1 = gluon.Parameter('weight', shape=(10, 10))
+    p1.initialize(init='xavier', ctx=[mx.cpu(0), mx.cpu(1)])
+    assertRaises(RuntimeError, p1.row_sparse_data, row_id.copyto(mx.cpu(0)))
+    assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
 
 @with_seed()
 def test_paramdict():
-    params = gluon.ParameterDict('net_')
-    params.get('weight', shape=(10, 10))
-    assert list(params.keys()) == ['net_weight']
-    params.initialize(ctx=mx.cpu())
-    params.save('test.params')
-    params.load('test.params', mx.cpu())
+    params0 = gluon.ParameterDict('net_')
+    params0.get('w0', shape=(10, 10))
+    params0.get('w1', shape=(10, 10), stype='row_sparse')
+    all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu())
+    # check param names
+    assert list(params0.keys()) == ['net_w0', 'net_w1']
+    params0.initialize(ctx=mx.cpu())
+    trainer0 = mx.gluon.Trainer(params0, 'sgd')
+    prev_w0 = params0.get('w0').data(mx.cpu())
+    prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
+    # save params
+    params0.save('test_paramdict.params')
+
+    # load params
+    params1 = gluon.ParameterDict('net_')
+    params1.get('w0', shape=(10, 10))
+    params1.get('w1', shape=(10, 10), stype='row_sparse')
+    params1.load('test_paramdict.params', mx.cpu())
+    trainer1 = mx.gluon.Trainer(params1, 'sgd')
+
+    # compare the values before and after save/load
+    cur_w0 = params1.get('w0').data(mx.cpu())
+    cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
+    mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
+    mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
+
+    # create a new param dict with dense params, and load from the checkpoint
+    # of sparse & dense params
+    params2 = gluon.ParameterDict('net_')
+    params2.get('w0', shape=(10, 10))
+    params2.get('w1', shape=(10, 10))
+    params2.load('test_paramdict.params', mx.cpu())
+
+    # compare the values before and after save/load
+    cur_w0 = params2.get('w0').data(mx.cpu())
+    cur_w1 = params2.get('w1').data(mx.cpu())
+    mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
+    mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
+
+
+@with_seed()
+def test_parameter_row_sparse_data():
+    ctx0 = mx.cpu(1)
+    ctx1 = mx.cpu(2)
+    dim0 = 4
+    x = gluon.Parameter('x', shape=(dim0, 2), stype='row_sparse')
+    x.initialize(init='xavier', ctx=[ctx0, ctx1])
+    trainer = gluon.Trainer([x], 'sgd')
+    x_param = x._data[0].copy()
+    assert x_param.stype == 'row_sparse'
+    row_id_0 = mx.nd.array([0,1], ctx=ctx0)
+    retained_0 = x.row_sparse_data(row_id_0)
+    retained_target_0 = mx.nd.sparse.retain(x_param, row_id_0.as_in_context(ctx0))
+    mx.test_utils.assert_almost_equal(retained_0.asnumpy(), retained_target_0.asnumpy())
+    assert retained_0.context == ctx0
+    row_id_1 = mx.nd.arange(0, dim0, ctx=ctx1)
+    retained_1 = x.row_sparse_data(row_id_1)
+    retained_target_1 = x_param
+    mx.test_utils.assert_almost_equal(retained_1.asnumpy(), retained_target_1.asnumpy())
+    assert retained_1.context == ctx1
+    row_id_2 = mx.nd.array([0,1,2])
+    retained_2 = x.list_row_sparse_data(row_id_2)
+    retained_target_2 = mx.nd.sparse.retain(x_param, row_id_2.as_in_context(ctx0))
+    mx.test_utils.assert_almost_equal(retained_2[0].asnumpy(), retained_target_2.asnumpy())
 
 
 @with_seed()
@@ -246,7 +334,29 @@ def test_symbol_block():
     net.hybridize()
     assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
+@with_seed()
+@raises(AssertionError)
+def test_sparse_symbol_block():  # gluon.SymbolBlock is expected to reject a row_sparse weight
+    data = mx.sym.var('data')
+    weight = mx.sym.var('weight', stype='row_sparse')  # row_sparse storage type on a non-input variable
+    bias = mx.sym.var('bias')
+    out = mx.sym.broadcast_add(mx.sym.dot(data, weight), bias)
+    # an AssertionError is expected when creating a SymbolBlock w/ a sparse param
+    net = gluon.SymbolBlock(out, data)
 
+@with_seed()
+@raises(RuntimeError)
+def test_sparse_hybrid_block():  # forwarding a HybridBlock that holds a row_sparse param should raise RuntimeError
+    params = gluon.ParameterDict('net_')
+    params.get('weight', shape=(5,5), stype='row_sparse', dtype='float32')  # sparse weight declared up front
+    params.get('bias', shape=(5,), dtype='float32')
+    net = gluon.nn.Dense(5, params=params)  # Dense presumably picks up the pre-declared 'weight'/'bias' from params
+    net.initialize()
+    x = mx.nd.ones((2,5))
+    # an exception is expected when forwarding a HybridBlock w/ sparse param
+    y = net(x)
+
+@with_seed()
 def check_layer_forward(layer, dshape):
     layer.collect_params().initialize()
     x = mx.nd.ones(shape=dshape)
@@ -496,80 +606,6 @@ def test_flatten():
     x = mx.nd.zeros((3,))
     assert flatten(x).shape == (3, 1)
 
-
-@with_seed()
-def test_trainer():
-    def dict_equ(a, b):
-        assert set(a) == set(b)
-        for k in a:
-            assert (a[k].asnumpy() == b[k].asnumpy()).all()
-    x = gluon.Parameter('x', shape=(10,))
-    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5})
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-
-    assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
-
-    x.lr_mult = 0.5
-
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-
-    assert (x.data(mx.cpu(1)).asnumpy() == -4).all()
-
-    trainer.save_states('test_trainer.states')
-    states = deepcopy(trainer._kvstore._updater.states) if trainer._update_on_kvstore \
-             else deepcopy(trainer._updaters[0].states)
-    trainer.load_states('test_trainer.states')
-    if trainer._update_on_kvstore:
-        dict_equ(trainer._kvstore._updater.states, states)
-        assert trainer._optimizer == trainer._kvstore._updater.optimizer
-    else:
-        for updater in trainer._updaters:
-            dict_equ(updater.states, states)
-        assert trainer._optimizer == trainer._updaters[0].optimizer
-    assert_raises(AssertionError, trainer.update, 1)
-    assert_raises(AssertionError, trainer.allreduce_grads)
-
-    x = gluon.Parameter('x', shape=(10,))
-    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer2 = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5},
-                             update_on_kvstore=False)
-    with mx.autograd.record():
-        for i, w in enumerate(x.list_data()):
-            y = i*w
-            y.backward()
-    assert (x.grad(mx.cpu(0)).asnumpy() != x.grad(mx.cpu(1)).asnumpy()).all()
-    trainer2.allreduce_grads()
-    assert (x.grad(mx.cpu(0)).asnumpy() == x.grad(mx.cpu(1)).asnumpy()).all()
-    trainer2.update(1)
-
-    assert (x.data(mx.cpu(1)).asnumpy() == -1).all(), x.data(mx.cpu(1)).asnumpy()
-
-@with_seed()
-def test_trainer_save_load():
-    x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
-    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
-    trainer.save_states('test_trainer_save_load.states')
-    trainer.load_states('test_trainer_save_load.states')
-    x.lr_mult = 2.0
-    # check if parameter dict is correctly associated with optimizer after load_state
-    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
-
 @with_seed()
 def test_block_attr_hidden():
     b = gluon.Block()
@@ -900,6 +936,7 @@ def test_inline():
     assert len_1 == len_2 + 2
 
 
+@with_seed()
 def test_activations():
     point_to_validate = mx.nd.array([-0.1, 0.1] * 3)
 
@@ -1013,13 +1050,14 @@ def test_req():
 @with_seed()
 def test_save_load():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
-    net.save_params('test.params')
+    net.save_params('test_save_load.params')  # test-specific file name, presumably so tests do not overwrite each other's params files
 
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
     net.output = mx.gluon.nn.Dense(1000)
 
-    net.load_params('test.params')
+    net.load_params('test_save_load.params')  # load the saved weights into the fresh network whose output layer was replaced above
 
+@with_seed()
 def test_symbol_block_save_load():
     class Net(gluon.HybridBlock):
         def __init__(self):
@@ -1042,10 +1080,10 @@ def test_symbol_block_save_load():
     net1.initialize(mx.init.Normal())
     net1.hybridize()
     net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
-    net1.save_params('./test.params')
+    net1.save_params('./test_symbol_block_save_load.params')
 
     net2 = Net()
-    net2.load_params('./test.params', ctx=mx.cpu())
+    net2.load_params('./test_symbol_block_save_load.params', ctx=mx.cpu())
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
new file mode 100644
index 0000000..c2e11eb
--- /dev/null
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import unittest
+import numpy as np
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.test_utils import assert_almost_equal
+from common import setup_module, with_seed, assertRaises
+from copy import deepcopy
+from nose.tools import raises, assert_raises
+
+@with_seed()
+@raises(RuntimeError)
+def test_multi_trainer():
+    x = gluon.Parameter('x', shape=(10,), stype='row_sparse')
+    x.initialize()
+    # test set trainer
+    trainer0 = gluon.Trainer([x], 'sgd')
+    assert x._trainer is trainer0
+    # test unset trainer
+    x._set_trainer(None)
+    assert x._trainer is None
+    x._set_trainer(trainer0)
+    # attaching multiple trainers to a sparse Parameter is not allowed (RuntimeError)
+    trainer1 = gluon.Trainer([x], 'sgd')
+
+@with_seed()
+def test_trainer():
+    def dict_equ(a, b):
+        assert set(a) == set(b)
+        for k in a:
+            assert (a[k].asnumpy() == b[k].asnumpy()).all()
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5})
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -2).all()
+
+    x.lr_mult = 0.5
+
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -4).all()
+
+    # round-trip the optimizer states through a file; they must come back unchanged
+    trainer.save_states('test_trainer.states')
+    states = deepcopy(trainer._kvstore._updater.states) if trainer._update_on_kvstore \
+             else deepcopy(trainer._updaters[0].states)
+    trainer.load_states('test_trainer.states')
+    if trainer._update_on_kvstore:
+        dict_equ(trainer._kvstore._updater.states, states)
+        assert trainer._optimizer == trainer._kvstore._updater.optimizer
+    else:
+        for updater in trainer._updaters:
+            dict_equ(updater.states, states)
+        assert trainer._optimizer == trainer._updaters[0].optimizer
+    # this trainer is expected to update on the kvstore, so calling update()/allreduce_grads() directly must raise
+    assert_raises(AssertionError, trainer.update, 1)
+    assert_raises(AssertionError, trainer.allreduce_grads)
+
+    x = gluon.Parameter('x', shape=(10,))
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer2 = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5},
+                             update_on_kvstore=False)
+    with mx.autograd.record():
+        for i, w in enumerate(x.list_data()):
+            y = i*w
+            y.backward()
+    assert (x.grad(mx.cpu(0)).asnumpy() != x.grad(mx.cpu(1)).asnumpy()).all()
+    trainer2.allreduce_grads()
+    assert (x.grad(mx.cpu(0)).asnumpy() == x.grad(mx.cpu(1)).asnumpy()).all()
+    trainer2.update(1)
+
+    assert (x.data(mx.cpu(1)).asnumpy() == -1).all(), x.data(mx.cpu(1)).asnumpy()
+
+@with_seed()
+def test_trainer_save_load():
+    x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
+    x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
+    trainer.save_states('test_trainer_save_load.states')
+    trainer.load_states('test_trainer_save_load.states')
+    x.lr_mult = 2.0
+    # check if parameter dict is correctly associated with optimizer after load_state
+    assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2
+
+@with_seed()
+def test_trainer_multi_layer_init():
+    class Net(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            with self.name_scope():
+                # sparse param
+                self.embed_weight = self.params.get('embed_weight', stype='row_sparse',
+                                                    shape=(4,3), grad_stype='row_sparse')
+                # dense param from a hybrid block
+                self.dense0 = nn.Dense(2)
+
+        def forward(self, x):
+            embed_weight = self.embed_weight.row_sparse_data(x)
+            embed = mx.nd.Embedding(data=x, weight=embed_weight,
+                                    input_dim=4, output_dim=3, sparse_grad=True)
+            return self.dense0(embed)
+
+    def check_init(ctxes):
+        net = Net(prefix='net_')
+        net.initialize(mx.init.One(), ctx=ctxes)
+        trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
+        data = mx.nd.array([[0,2], [1,2]])
+        xs = gluon.utils.split_and_load(data, ctxes)
+        ys = []
+        with mx.autograd.record():
+            for x in xs:
+                y = net(x)
+                ys.append(y)
+        for y in ys:
+            y.backward()
+        trainer.step(1)
+        # all parameters should be initialized
+        assert not trainer._params_to_init
+        all_rows = mx.nd.arange(0, 4, ctx=mx.cpu(1))
+        # check the updated weights
+        weight = net.embed_weight.row_sparse_data(all_rows).asnumpy()
+        assert (weight[0] == -1).all()
+        assert (weight[1] == -1).all()
+        assert (weight[2] == -3).all()
+        assert (weight[3] == 1).all()
+
+    check_init([mx.cpu(1), mx.cpu(2)])
+    check_init([mx.cpu(1)])
+
+@with_seed()
+def test_trainer_reset_kv():
+    params = gluon.ParameterDict()
+    x = params.get('x', shape=(10,), lr_mult=1.0)
+    params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
+    params.save('test_trainer_reset_kv.params')
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+    # loading parameters is expected to reset the trainer's kvstore
+    params.load('test_trainer_reset_kv.params')
+    assert trainer._kvstore is None
+    assert trainer._kv_initialized is False
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+    # the updated parameter should be based on the loaded checkpoint
+    assert (x.data(mx.cpu()) == -0.2).asnumpy().all()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.