You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/06/14 05:01:46 UTC

[incubator-mxnet] branch master updated: add import_ for SymbolBlock (#11127)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 66ab27e  add import_ for SymbolBlock (#11127)
66ab27e is described below

commit 66ab27e67a70b1164364a8a52ebbe0def45dc327
Author: Eric Junyuan Xie <pi...@users.noreply.github.com>
AuthorDate: Wed Jun 13 22:01:40 2018 -0700

    add import_ for SymbolBlock (#11127)
    
    * add import_ for SymbolBlock
    
    * fix
    
    * Update block.py
    
    * add save_parameters
    
    * fix
    
    * fix lint
    
    * fix
    
    * fix
    
    * fix
    
    * fix
    
    * fix
    
    * Update save_load_params.md
---
 docs/tutorials/gluon/hybrid.md                    |  2 +-
 docs/tutorials/gluon/naming.md                    |  6 +-
 docs/tutorials/gluon/save_load_params.md          | 16 +---
 example/gluon/dcgan.py                            |  8 +-
 example/gluon/embedding_learning/train.py         |  2 +-
 example/gluon/image_classification.py             |  8 +-
 example/gluon/mnist.py                            |  2 +-
 example/gluon/style_transfer/main.py              |  8 +-
 example/gluon/super_resolution.py                 |  4 +-
 example/gluon/tree_lstm/main.py                   |  2 +-
 example/gluon/word_language_model/train.py        |  4 +-
 python/mxnet/gluon/block.py                       | 90 +++++++++++++++++++++--
 python/mxnet/gluon/model_zoo/vision/alexnet.py    |  2 +-
 python/mxnet/gluon/model_zoo/vision/densenet.py   |  2 +-
 python/mxnet/gluon/model_zoo/vision/inception.py  |  2 +-
 python/mxnet/gluon/model_zoo/vision/mobilenet.py  |  4 +-
 python/mxnet/gluon/model_zoo/vision/resnet.py     |  4 +-
 python/mxnet/gluon/model_zoo/vision/squeezenet.py |  2 +-
 python/mxnet/gluon/model_zoo/vision/vgg.py        |  4 +-
 tests/python/unittest/test_gluon.py               | 54 +++++++++++---
 20 files changed, 164 insertions(+), 62 deletions(-)

diff --git a/docs/tutorials/gluon/hybrid.md b/docs/tutorials/gluon/hybrid.md
index 3554a15..5c8372a 100644
--- a/docs/tutorials/gluon/hybrid.md
+++ b/docs/tutorials/gluon/hybrid.md
@@ -117,7 +117,7 @@ x = mx.sym.var('data')
 y = net(x)
 print(y)
 y.save('model.json')
-net.save_params('model.params')
+net.save_parameters('model.params')
 ```
 
 If your network outputs more than one value, you can use `mx.sym.Group` to
diff --git a/docs/tutorials/gluon/naming.md b/docs/tutorials/gluon/naming.md
index 37b63fa..3606a03 100644
--- a/docs/tutorials/gluon/naming.md
+++ b/docs/tutorials/gluon/naming.md
@@ -203,12 +203,12 @@ except Exception as e:
     Parameter 'model1_dense0_weight' is missing in file 'model.params', which contains parameters: 'model0_mydense_weight', 'model0_dense1_bias', 'model0_dense1_weight', 'model0_dense0_weight', 'model0_dense0_bias', 'model0_mydense_bias'. Please make sure source and target networks have the same prefix.
 
 
-To solve this problem, we use `save_params`/`load_params` instead of `collect_params` and `save`/`load`. `save_params` uses model structure, instead of parameter name, to match parameters.
+To solve this problem, we use `save_parameters`/`load_parameters` instead of `collect_params` and `save`/`load`. `save_parameters` uses model structure, instead of parameter name, to match parameters.
 
 
 ```python
-model0.save_params('model.params')
-model1.load_params('model.params')
+model0.save_parameters('model.params')
+model1.load_parameters('model.params')
 print(mx.nd.load('model.params').keys())
 ```
 
diff --git a/docs/tutorials/gluon/save_load_params.md b/docs/tutorials/gluon/save_load_params.md
index cd87680..f5f4812 100644
--- a/docs/tutorials/gluon/save_load_params.md
+++ b/docs/tutorials/gluon/save_load_params.md
@@ -10,7 +10,7 @@ Parameters of any Gluon model can be saved using the `save_params` and `load_par
 
 **2. Save/load model parameters AND architecture**
 
-The Model architecture of `Hybrid` models stays static and don't change during execution. Therefore both model parameters AND architecture can be saved and loaded using `export`, `load_checkpoint` and `load` methods.
+The model architecture of `Hybrid` models stays static and doesn't change during execution. Therefore both model parameters AND architecture can be saved using the `export` method and loaded back using the `imports` method.
 
 Let's look at the above methods in more detail. Let's start by importing the modules we'll need.
 
@@ -61,7 +61,7 @@ def build_lenet(net):
         net.add(gluon.nn.Dense(512, activation="relu"))
         # Second fully connected layer with as many neurons as the number of classes
         net.add(gluon.nn.Dense(num_outputs))
-        
+
         return net
 
 # Train a given model using MNIST data
@@ -240,18 +240,10 @@ One of the main reasons to serialize model architecture into a JSON file is to l
 
 ### From Python
 
-Serialized Hybrid networks (saved as .JSON and .params file) can be loaded and used inside Python frontend using `mx.model.load_checkpoint` and `gluon.nn.SymbolBlock`. To demonstrate that, let's load the network we serialized above.
+Serialized Hybrid networks (saved as `.json` and `.params` files) can be loaded and used inside the Python frontend using `gluon.nn.SymbolBlock.imports`. To demonstrate that, let's load the network we serialized above.
 
 ```python
-# Load the network architecture and parameters
-sym = mx.sym.load('lenet-symbol.json')
-# Create a Gluon Block using the loaded network architecture.
-# 'inputs' parameter specifies the name of the symbol in the computation graph
-# that should be treated as input. 'data' is the default name used for input when
-# a model architecture is saved to a file.
-deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
-# Load the parameters
-deserialized_net.collect_params().load('lenet-0001.params', ctx=ctx)
+deserialized_net = gluon.nn.SymbolBlock.imports("lenet-symbol.json", ['data'], "lenet-0001.params")
 ```
 
 `deserialized_net` now contains the network we deserialized from files. Let's test the deserialized network to make sure it works.
diff --git a/example/gluon/dcgan.py b/example/gluon/dcgan.py
index 3233f43..8ac9c52 100644
--- a/example/gluon/dcgan.py
+++ b/example/gluon/dcgan.py
@@ -229,8 +229,8 @@ for epoch in range(opt.nepoch):
     logging.info('time: %f' % (time.time() - tic))
 
     if check_point:
-        netG.save_params(os.path.join(outf,'generator_epoch_%d.params' %epoch))
-        netD.save_params(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
+        netG.save_parameters(os.path.join(outf,'generator_epoch_%d.params' %epoch))
+        netD.save_parameters(os.path.join(outf,'discriminator_epoch_%d.params' % epoch))
 
-netG.save_params(os.path.join(outf, 'generator.params'))
-netD.save_params(os.path.join(outf, 'discriminator.params'))
+netG.save_parameters(os.path.join(outf, 'generator.params'))
+netD.save_parameters(os.path.join(outf, 'discriminator.params'))
diff --git a/example/gluon/embedding_learning/train.py b/example/gluon/embedding_learning/train.py
index 46f76b5..b8a5bf2 100644
--- a/example/gluon/embedding_learning/train.py
+++ b/example/gluon/embedding_learning/train.py
@@ -246,7 +246,7 @@ def train(epochs, ctx):
         if val_accs[0] > best_val:
             best_val = val_accs[0]
             logging.info('Saving %s.' % opt.save_model_prefix)
-            net.save_params('%s.params' % opt.save_model_prefix)
+            net.save_parameters('%s.params' % opt.save_model_prefix)
     return best_val
 
 
diff --git a/example/gluon/image_classification.py b/example/gluon/image_classification.py
index 6e2f1d6..b21e943 100644
--- a/example/gluon/image_classification.py
+++ b/example/gluon/image_classification.py
@@ -122,7 +122,7 @@ def get_model(model, ctx, opt):
 
     net = models.get_model(model, **kwargs)
     if opt.resume:
-        net.load_params(opt.resume)
+        net.load_parameters(opt.resume)
     elif not opt.use_pretrained:
         if model in ['alexnet']:
             net.initialize(mx.init.Normal())
@@ -176,12 +176,12 @@ def update_learning_rate(lr, trainer, epoch, ratio, steps):
 def save_checkpoint(epoch, top1, best_acc):
     if opt.save_frequency and (epoch + 1) % opt.save_frequency == 0:
         fname = os.path.join(opt.prefix, '%s_%d_acc_%.4f.params' % (opt.model, epoch, top1))
-        net.save_params(fname)
+        net.save_parameters(fname)
         logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
     if top1 > best_acc[0]:
         best_acc[0] = top1
         fname = os.path.join(opt.prefix, '%s_best.params' % (opt.model))
-        net.save_params(fname)
+        net.save_parameters(fname)
         logger.info('[Epoch %d] Saving checkpoint to %s with Accuracy: %.4f', epoch, fname, top1)
 
 def train(opt, ctx):
@@ -267,7 +267,7 @@ def main():
                 optimizer = 'sgd',
                 optimizer_params = {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum, 'multi_precision': True},
                 initializer = mx.init.Xavier(magnitude=2))
-        mod.save_params('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
+        mod.save_parameters('image-classifier-%s-%d-final.params'%(opt.model, opt.epochs))
     else:
         if opt.mode == 'hybrid':
             net.hybridize()
diff --git a/example/gluon/mnist.py b/example/gluon/mnist.py
index 198d7ca..6aea3ab 100644
--- a/example/gluon/mnist.py
+++ b/example/gluon/mnist.py
@@ -117,7 +117,7 @@ def train(epochs, ctx):
         name, val_acc = test(ctx)
         print('[Epoch %d] Validation: %s=%f'%(epoch, name, val_acc))
 
-    net.save_params('mnist.params')
+    net.save_parameters('mnist.params')
 
 
 if __name__ == '__main__':
diff --git a/example/gluon/style_transfer/main.py b/example/gluon/style_transfer/main.py
index cab8211..dde992a 100644
--- a/example/gluon/style_transfer/main.py
+++ b/example/gluon/style_transfer/main.py
@@ -55,7 +55,7 @@ def train(args):
     style_model.initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
     if args.resume is not None:
         print('Resuming, initializing using weight from {}.'.format(args.resume))
-        style_model.load_params(args.resume, ctx=ctx)
+        style_model.load_parameters(args.resume, ctx=ctx)
     print('style_model:',style_model)
     # optimizer and loss
     trainer = gluon.Trainer(style_model.collect_params(), 'adam',
@@ -121,14 +121,14 @@ def train(args):
                     str(count) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
                     args.content_weight) + "_" + str(args.style_weight) + ".params"
                 save_model_path = os.path.join(args.save_model_dir, save_model_filename)
-                style_model.save_params(save_model_path)
+                style_model.save_parameters(save_model_path)
                 print("\nCheckpoint, trained model saved at", save_model_path)
 
     # save model
     save_model_filename = "Final_epoch_" + str(args.epochs) + "_" + str(time.ctime()).replace(' ', '_') + "_" + str(
         args.content_weight) + "_" + str(args.style_weight) + ".params"
     save_model_path = os.path.join(args.save_model_dir, save_model_filename)
-    style_model.save_params(save_model_path)
+    style_model.save_parameters(save_model_path)
     print("\nDone, trained model saved at", save_model_path)
 
 
@@ -143,7 +143,7 @@ def evaluate(args):
     style_image = utils.preprocess_batch(style_image)
     # model
     style_model = net.Net(ngf=args.ngf)
-    style_model.load_params(args.model, ctx=ctx)
+    style_model.load_parameters(args.model, ctx=ctx)
     # forward
     style_model.set_target(style_image)
     output = style_model(content_image)
diff --git a/example/gluon/super_resolution.py b/example/gluon/super_resolution.py
index 38c3bec..0f2f21f 100644
--- a/example/gluon/super_resolution.py
+++ b/example/gluon/super_resolution.py
@@ -168,13 +168,13 @@ def train(epoch, ctx):
         print('training mse at epoch %d: %s=%f'%(i, name, acc))
         test(ctx)
 
-    net.save_params('superres.params')
+    net.save_parameters('superres.params')
 
 def resolve(ctx):
     from PIL import Image
     if isinstance(ctx, list):
         ctx = [ctx[0]]
-    net.load_params('superres.params', ctx=ctx)
+    net.load_parameters('superres.params', ctx=ctx)
     img = Image.open(opt.resolve_img).convert('YCbCr')
     y, cb, cr = img.split()
     data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
diff --git a/example/gluon/tree_lstm/main.py b/example/gluon/tree_lstm/main.py
index d2fe464..ad5d59f 100644
--- a/example/gluon/tree_lstm/main.py
+++ b/example/gluon/tree_lstm/main.py
@@ -138,7 +138,7 @@ def test(ctx, data_iter, best, mode='validation', num_iter=-1):
         if test_r >= best:
             best = test_r
             logging.info('New optimum found: {}. Checkpointing.'.format(best))
-            net.save_params('childsum_tree_lstm_{}.params'.format(num_iter))
+            net.save_parameters('childsum_tree_lstm_{}.params'.format(num_iter))
             test(ctx, test_iter, -1, 'test')
         return best
 
diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index 9e15263..7f0a916 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -185,7 +185,7 @@ def train():
         if val_L < best_val:
             best_val = val_L
             test_L = eval(test_data)
-            model.save_params(args.save)
+            model.save_parameters(args.save)
             print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
         else:
             args.lr = args.lr*0.25
@@ -193,6 +193,6 @@ def train():
 
 if __name__ == '__main__':
     train()
-    model.load_params(args.save, context)
+    model.load_parameters(args.save, context)
     test_L = eval(test_data)
     print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L)))
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 7406a5d..689b7ab 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -16,7 +16,7 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable= arguments-differ
+# pylint: disable= arguments-differ, too-many-lines
 """Base container class for all neural network models."""
 __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 
@@ -307,7 +307,7 @@ class Block(object):
             ret.update(child._collect_params_with_prefix(prefix + name))
         return ret
 
-    def save_params(self, filename):
+    def save_parameters(self, filename):
         """Save parameters to file.
 
         filename : str
@@ -317,8 +317,23 @@ class Block(object):
         arg_dict = {key : val._reduce() for key, val in params.items()}
         ndarray.save(filename, arg_dict)
 
-    def load_params(self, filename, ctx=None, allow_missing=False,
-                    ignore_extra=False):
+    def save_params(self, filename):
+        """[Deprecated] Please use save_parameters.
+
+        Save parameters to file.
+
+        filename : str
+            Path to file.
+        """
+        warnings.warn("save_params is deprecated. Please use save_parameters.")
+        try:
+            self.collect_params().save(filename, strip_prefix=self.prefix)
+        except ValueError as e:
+            raise ValueError('%s\nsave_params is deprecated. Using ' \
+                              'save_parameters may resolve this error.'%e.message)
+
+    def load_parameters(self, filename, ctx=None, allow_missing=False,
+                        ignore_extra=False):
         """Load parameters from file.
 
         filename : str
@@ -358,6 +373,25 @@ class Block(object):
             if name in params:
                 params[name]._load_init(loaded[name], ctx)
 
+    def load_params(self, filename, ctx=None, allow_missing=False,
+                    ignore_extra=False):
+        """[Deprecated] Please use load_parameters.
+
+        Load parameters from file.
+
+        filename : str
+            Path to parameter file.
+        ctx : Context or list of Context, default cpu()
+            Context(s) initialize loaded parameters on.
+        allow_missing : bool, default False
+            Whether to silently skip loading parameters not represented in the file.
+        ignore_extra : bool, default False
+            Whether to silently ignore parameters from the file that are not
+            present in this Block.
+        """
+        warnings.warn("load_params is deprecated. Please use load_parameters.")
+        self.load_parameters(filename, ctx, allow_missing, ignore_extra)
+
     def register_child(self, block, name=None):
         """Registers block as a child of self. :py:class:`Block` s assigned to self as
         attributes will be registered automatically."""
@@ -771,8 +805,8 @@ class HybridBlock(Block):
         self._infer_attrs('infer_type', 'dtype', *args)
 
     def export(self, path, epoch=0):
-        """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
-        or the C++ interface.
+        """Export HybridBlock to json format that can be loaded by
+        `SymbolBlock.imports`, `mxnet.mod.Module` or the C++ interface.
 
         .. note:: When there are only one input, it will have name `data`. When there
                   Are more than one inputs, they will be named as `data0`, `data1`, etc.
@@ -886,6 +920,50 @@ class SymbolBlock(HybridBlock):
     >>> x = mx.nd.random.normal(shape=(16, 3, 224, 224))
     >>> print(feat_model(x))
     """
+    @staticmethod
+    def imports(symbol_file, input_names, param_file=None, ctx=None):
+        """Import model previously saved by `HybridBlock.export` or
+        `Module.save_checkpoint` as a SymbolBlock for use in Gluon.
+
+        Parameters
+        ----------
+        symbol_file : str
+            Path to symbol file.
+        input_names : list of str
+            List of input variable names
+        param_file : str, optional
+            Path to parameter file.
+        ctx : Context, default None
+            The context to initialize SymbolBlock on.
+
+        Returns
+        -------
+        SymbolBlock
+            SymbolBlock loaded from symbol and parameter files.
+
+        Examples
+        --------
+        >>> net1 = gluon.model_zoo.vision.resnet18_v1(
+        ...     prefix='resnet', pretrained=True)
+        >>> net1.hybridize()
+        >>> x = mx.nd.random.normal(shape=(1, 3, 32, 32))
+        >>> out1 = net1(x)
+        >>> net1.export('net1', epoch=1)
+        >>>
+        >>> net2 = gluon.SymbolBlock.imports(
+        ...     'net1-symbol.json', ['data'], 'net1-0001.params')
+        >>> out2 = net2(x)
+        """
+        sym = symbol.load(symbol_file)
+        if isinstance(input_names, str):
+            input_names = [input_names]
+        inputs = [symbol.var(i) for i in input_names]
+        ret = SymbolBlock(sym, inputs)
+        if param_file is not None:
+            ret.collect_params().load(param_file, ctx=ctx)
+        return ret
+
+
     def __init__(self, outputs, inputs, params=None):
         super(SymbolBlock, self).__init__(prefix=None, params=None)
         self._prefix = ''
diff --git a/python/mxnet/gluon/model_zoo/vision/alexnet.py b/python/mxnet/gluon/model_zoo/vision/alexnet.py
index 5549947..fdb0062 100644
--- a/python/mxnet/gluon/model_zoo/vision/alexnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/alexnet.py
@@ -83,5 +83,5 @@ def alexnet(pretrained=False, ctx=cpu(),
     net = AlexNet(**kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_params(get_model_file('alexnet', root=root), ctx=ctx)
+        net.load_parameters(get_model_file('alexnet', root=root), ctx=ctx)
     return net
diff --git a/python/mxnet/gluon/model_zoo/vision/densenet.py b/python/mxnet/gluon/model_zoo/vision/densenet.py
index 8353367..b03f5ce 100644
--- a/python/mxnet/gluon/model_zoo/vision/densenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/densenet.py
@@ -141,7 +141,7 @@ def get_densenet(num_layers, pretrained=False, ctx=cpu(),
     net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_params(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
+        net.load_parameters(get_model_file('densenet%d'%(num_layers), root=root), ctx=ctx)
     return net
 
 def densenet121(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/inception.py b/python/mxnet/gluon/model_zoo/vision/inception.py
index 6d75050..7c54691 100644
--- a/python/mxnet/gluon/model_zoo/vision/inception.py
+++ b/python/mxnet/gluon/model_zoo/vision/inception.py
@@ -216,5 +216,5 @@ def inception_v3(pretrained=False, ctx=cpu(),
     net = Inception3(**kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_params(get_model_file('inceptionv3', root=root), ctx=ctx)
+        net.load_parameters(get_model_file('inceptionv3', root=root), ctx=ctx)
     return net
diff --git a/python/mxnet/gluon/model_zoo/vision/mobilenet.py b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
index 5b4c9a8..1a2c9b9 100644
--- a/python/mxnet/gluon/model_zoo/vision/mobilenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/mobilenet.py
@@ -213,7 +213,7 @@ def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
         version_suffix = '{0:.2f}'.format(multiplier)
         if version_suffix in ('1.00', '0.50'):
             version_suffix = version_suffix[:-1]
-        net.load_params(
+        net.load_parameters(
             get_model_file('mobilenet%s' % version_suffix, root=root), ctx=ctx)
     return net
 
@@ -245,7 +245,7 @@ def get_mobilenet_v2(multiplier, pretrained=False, ctx=cpu(),
         version_suffix = '{0:.2f}'.format(multiplier)
         if version_suffix in ('1.00', '0.50'):
             version_suffix = version_suffix[:-1]
-        net.load_params(
+        net.load_parameters(
             get_model_file('mobilenetv2_%s' % version_suffix, root=root), ctx=ctx)
     return net
 
diff --git a/python/mxnet/gluon/model_zoo/vision/resnet.py b/python/mxnet/gluon/model_zoo/vision/resnet.py
index 5ee67b5..da279b8 100644
--- a/python/mxnet/gluon/model_zoo/vision/resnet.py
+++ b/python/mxnet/gluon/model_zoo/vision/resnet.py
@@ -386,8 +386,8 @@ def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
     net = resnet_class(block_class, layers, channels, **kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_params(get_model_file('resnet%d_v%d'%(num_layers, version),
-                                       root=root), ctx=ctx)
+        net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
+                                           root=root), ctx=ctx)
     return net
 
 def resnet18_v1(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/squeezenet.py b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
index 09f62a5..aaff4c3 100644
--- a/python/mxnet/gluon/model_zoo/vision/squeezenet.py
+++ b/python/mxnet/gluon/model_zoo/vision/squeezenet.py
@@ -132,7 +132,7 @@ def get_squeezenet(version, pretrained=False, ctx=cpu(),
     net = SqueezeNet(version, **kwargs)
     if pretrained:
         from ..model_store import get_model_file
-        net.load_params(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
+        net.load_parameters(get_model_file('squeezenet%s'%version, root=root), ctx=ctx)
     return net
 
 def squeezenet1_0(**kwargs):
diff --git a/python/mxnet/gluon/model_zoo/vision/vgg.py b/python/mxnet/gluon/model_zoo/vision/vgg.py
index dbae538..a3b1685 100644
--- a/python/mxnet/gluon/model_zoo/vision/vgg.py
+++ b/python/mxnet/gluon/model_zoo/vision/vgg.py
@@ -114,8 +114,8 @@ def get_vgg(num_layers, pretrained=False, ctx=cpu(),
     if pretrained:
         from ..model_store import get_model_file
         batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
-        net.load_params(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
-                                       root=root), ctx=ctx)
+        net.load_parameters(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
+                                           root=root), ctx=ctx)
     return net
 
 def vgg11(**kwargs):
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 8ad86d4..bb61af1 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -202,20 +202,20 @@ def test_parameter_sharing():
     net1.collect_params().initialize()
     net2(mx.nd.zeros((3, 5)))
 
-    net1.save_params('net1.params')
+    net1.save_parameters('net1.params')
 
     net3 = Net(prefix='net3_')
-    net3.load_params('net1.params', mx.cpu())
+    net3.load_parameters('net1.params', mx.cpu())
 
     net4 = Net(prefix='net4_')
     net5 = Net(prefix='net5_', in_units=5, params=net4.collect_params())
     net4.collect_params().initialize()
     net5(mx.nd.zeros((3, 5)))
 
-    net4.save_params('net4.params')
+    net4.save_parameters('net4.params')
 
     net6 = Net(prefix='net6_')
-    net6.load_params('net4.params', mx.cpu())
+    net6.load_parameters('net4.params', mx.cpu())
 
 
 @with_seed()
@@ -777,7 +777,7 @@ def test_export():
     model = gluon.model_zoo.vision.resnet18_v1(
         prefix='resnet', ctx=ctx, pretrained=True)
     model.hybridize()
-    data = mx.nd.random.normal(shape=(1, 3, 224, 224))
+    data = mx.nd.random.normal(shape=(1, 3, 32, 32))
     out = model(data)
 
     model.export('gluon')
@@ -795,6 +795,22 @@ def test_export():
 
     assert_almost_equal(out.asnumpy(), out2.asnumpy())
 
+@with_seed()
+def test_import():
+    ctx = mx.context.current_context()
+    net1 = gluon.model_zoo.vision.resnet18_v1(
+        prefix='resnet', ctx=ctx, pretrained=True)
+    net1.hybridize()
+    data = mx.nd.random.normal(shape=(1, 3, 32, 32))
+    out1 = net1(data)
+
+    net1.export('net1', epoch=1)
+
+    net2 = gluon.SymbolBlock.imports(
+        'net1-symbol.json', ['data'], 'net1-0001.params', ctx)
+    out2 = net2(data)
+
+    assert_almost_equal(out1.asnumpy(), out2.asnumpy())
 
 @with_seed()
 def test_hybrid_stale_cache():
@@ -911,7 +927,7 @@ def test_fill_shape_load():
     net1.hybridize()
     net1.initialize(ctx=ctx)
     net1(mx.nd.ones((2,3,5,7), ctx))
-    net1.save_params('net_fill.params')
+    net1.save_parameters('net_fill.params')
 
     net2 = nn.HybridSequential()
     with net2.name_scope():
@@ -920,7 +936,7 @@ def test_fill_shape_load():
                  nn.Dense(10))
     net2.hybridize()
     net2.initialize()
-    net2.load_params('net_fill.params', ctx)
+    net2.load_parameters('net_fill.params', ctx)
     assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
     assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
     assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
@@ -1066,12 +1082,12 @@ def test_req():
 @with_seed()
 def test_save_load():
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18, pretrained=True)
-    net.save_params('test_save_load.params')
+    net.save_parameters('test_save_load.params')
 
     net = mx.gluon.model_zoo.vision.get_resnet(1, 18)
     net.output = mx.gluon.nn.Dense(1000)
 
-    net.load_params('test_save_load.params')
+    net.load_parameters('test_save_load.params')
 
 @with_seed()
 def test_symbol_block_save_load():
@@ -1096,10 +1112,10 @@ def test_symbol_block_save_load():
     net1.initialize(mx.init.Normal())
     net1.hybridize()
     net1(mx.nd.random.normal(shape=(1, 3, 32, 32)))
-    net1.save_params('./test_symbol_block_save_load.params')
+    net1.save_parameters('./test_symbol_block_save_load.params')
 
     net2 = Net()
-    net2.load_params('./test_symbol_block_save_load.params', ctx=mx.cpu())
+    net2.load_parameters('./test_symbol_block_save_load.params', ctx=mx.cpu())
 
 
 @with_seed()
@@ -1253,6 +1269,22 @@ def test_summary():
     assert_raises(AssertionError, net.summary, mx.nd.ones((32, 3, 224, 224)))
 
 
+@with_seed()
+def test_legacy_save_params():
+    net = gluon.nn.HybridSequential(prefix='')
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(10, (3, 3)))
+        net.add(gluon.nn.Dense(50))
+    net.initialize()
+    net(mx.nd.ones((1,1,50,50)))
+    a = net(mx.sym.var('data'))
+    a.save('test.json')
+    net.save_params('test.params')
+    model = gluon.nn.SymbolBlock(outputs=mx.sym.load_json(open('test.json', 'r').read()),
+                                     inputs=mx.sym.var('data'))
+    model.load_params('test.params', ctx=mx.cpu())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.