Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/05 19:44:12 UTC

[incubator-mxnet] branch master updated: example/autoencoder fixes for MXNet 1.0.0 and pylint and addition of README (#9097)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d28593  example/autoencoder fixes for MXNet 1.0.0 and pylint and addition of README (#9097)
2d28593 is described below

commit 2d28593e6d5dd76ee9b95a0205c0b040393fe183
Author: Sina Afrooze <si...@gmail.com>
AuthorDate: Fri Jan 5 11:44:09 2018 -0800

    example/autoencoder fixes for MXNet 1.0.0 and pylint and addition of README (#9097)
    
    * Fixes for MXNet 1.0.0 and pylint and addition of README
    
    * Fixed pylint issues
    * Added README
    * Added --gpu command line option
    * Updated initializer call to new API with InitDesc()
    
    * Simplified code for extracting debug symbols from graph
---
 example/autoencoder/README.md      |  16 ++++++
 example/autoencoder/autoencoder.py | 101 ++++++++++++++++++++++---------------
 example/autoencoder/data.py        |   4 ++
 example/autoencoder/mnist_sae.py   |  40 ++++++++++-----
 example/autoencoder/model.py       |  15 +++---
 example/autoencoder/solver.py      |  47 ++++++++++-------
 6 files changed, 141 insertions(+), 82 deletions(-)

diff --git a/example/autoencoder/README.md b/example/autoencoder/README.md
new file mode 100644
index 0000000..7efa30a
--- /dev/null
+++ b/example/autoencoder/README.md
@@ -0,0 +1,16 @@
+# Autoencoder Example
+
+The autoencoder architecture is often used for unsupervised feature learning; this [tutorial](http://ufldl.stanford.edu/tutorial/unsupervised/Autoencoders/) gives an introduction. This example implements a simple autoencoder that uses a stack of fully-connected layers for both the encoder and the decoder. The number of hidden layers and the size of each hidden layer can be customized with command-line arguments.
+
+## Training Stages
+This example trains in two stages. In the first stage, each encoder layer and its corresponding decoder layer are trained separately in a layer-wise training loop. In the second stage, the entire autoencoder network is fine-tuned end to end.
+
+## Dataset
+This example uses the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, which it downloads with the scikit-learn module.
+
+## Simple autoencoder example
+mnist_sae.py: this script uses a simple autoencoder architecture to encode and decode 28x28-pixel MNIST images. It accepts several command-line arguments; pass -h (or --help) to see all available options. To start training on the CPU with the default options (add --gpu to train on a GPU):
+
+```
+python mnist_sae.py
+```
diff --git a/example/autoencoder/autoencoder.py b/example/autoencoder/autoencoder.py
index 5089a4d..47931e5 100644
--- a/example/autoencoder/autoencoder.py
+++ b/example/autoencoder/autoencoder.py
@@ -15,20 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: skip-file
+# pylint: disable=missing-docstring, arguments-differ
+from __future__ import print_function
+
+import logging
+
 import mxnet as mx
-from mxnet import misc
 import numpy as np
 import model
-import logging
 from solver import Solver, Monitor
-try:
-   import cPickle as pickle
-except:
-   import pickle
+
 
 class AutoEncoderModel(model.MXModel):
-    def setup(self, dims, sparseness_penalty=None, pt_dropout=None, ft_dropout=None, input_act=None, internal_act='relu', output_act=None):
+    def setup(self, dims, sparseness_penalty=None, pt_dropout=None,
+              ft_dropout=None, input_act=None, internal_act='relu', output_act=None):
         self.N = len(dims) - 1
         self.dims = dims
         self.stacks = []
@@ -52,22 +52,26 @@ class AutoEncoderModel(model.MXModel):
             else:
                 encoder_act = internal_act
                 odropout = pt_dropout
-            istack, iargs, iargs_grad, iargs_mult, iauxs = self.make_stack(i, self.data, dims[i], dims[i+1],
-                                                sparseness_penalty, idropout, odropout, encoder_act, decoder_act)
+            istack, iargs, iargs_grad, iargs_mult, iauxs = self.make_stack(
+                i, self.data, dims[i], dims[i+1], sparseness_penalty,
+                idropout, odropout, encoder_act, decoder_act
+            )
             self.stacks.append(istack)
             self.args.update(iargs)
             self.args_grad.update(iargs_grad)
             self.args_mult.update(iargs_mult)
             self.auxs.update(iauxs)
-        self.encoder, self.internals = self.make_encoder(self.data, dims, sparseness_penalty, ft_dropout, internal_act, output_act)
-        self.decoder = self.make_decoder(self.encoder, dims, sparseness_penalty, ft_dropout, internal_act, input_act)
+        self.encoder, self.internals = self.make_encoder(
+            self.data, dims, sparseness_penalty, ft_dropout, internal_act, output_act)
+        self.decoder = self.make_decoder(
+            self.encoder, dims, sparseness_penalty, ft_dropout, internal_act, input_act)
         if input_act == 'softmax':
             self.loss = self.decoder
         else:
             self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data)
 
-    def make_stack(self, istack, data, num_input, num_hidden, sparseness_penalty=None, idropout=None,
-                   odropout=None, encoder_act='relu', decoder_act='relu'):
+    def make_stack(self, istack, data, num_input, num_hidden, sparseness_penalty=None,
+                   idropout=None, odropout=None, encoder_act='relu', decoder_act='relu'):
         x = data
         if idropout:
             x = mx.symbol.Dropout(data=x, p=idropout)
@@ -75,7 +79,8 @@ class AutoEncoderModel(model.MXModel):
         if encoder_act:
             x = mx.symbol.Activation(data=x, act_type=encoder_act)
             if encoder_act == 'sigmoid' and sparseness_penalty:
-                x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_encoder_%d' % istack, penalty=sparseness_penalty)
+                x = mx.symbol.IdentityAttachKLSparseReg(
+                    data=x, name='sparse_encoder_%d' % istack, penalty=sparseness_penalty)
         if odropout:
             x = mx.symbol.Dropout(data=x, p=odropout)
         x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input)
@@ -84,7 +89,8 @@ class AutoEncoderModel(model.MXModel):
         elif decoder_act:
             x = mx.symbol.Activation(data=x, act_type=decoder_act)
             if decoder_act == 'sigmoid' and sparseness_penalty:
-                x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_decoder_%d' % istack, penalty=sparseness_penalty)
+                x = mx.symbol.IdentityAttachKLSparseReg(
+                    data=x, name='sparse_decoder_%d' % istack, penalty=sparseness_penalty)
             x = mx.symbol.LinearRegressionOutput(data=x, label=data)
         else:
             x = mx.symbol.LinearRegressionOutput(data=x, label=data)
@@ -103,16 +109,17 @@ class AutoEncoderModel(model.MXModel):
                      'decoder_%d_bias'%istack: 2.0,}
         auxs = {}
         if encoder_act == 'sigmoid' and sparseness_penalty:
-            auxs['sparse_encoder_%d_moving_avg' % istack] = mx.nd.ones((num_hidden), self.xpu) * 0.5
+            auxs['sparse_encoder_%d_moving_avg' % istack] = mx.nd.ones(num_hidden, self.xpu) * 0.5
         if decoder_act == 'sigmoid' and sparseness_penalty:
-            auxs['sparse_decoder_%d_moving_avg' % istack] = mx.nd.ones((num_input), self.xpu) * 0.5
+            auxs['sparse_decoder_%d_moving_avg' % istack] = mx.nd.ones(num_input, self.xpu) * 0.5
         init = mx.initializer.Uniform(0.07)
-        for k,v in args.items():
-            init(k,v)
+        for k, v in args.items():
+            init(mx.initializer.InitDesc(k), v)
 
         return x, args, args_grad, args_mult, auxs
 
-    def make_encoder(self, data, dims, sparseness_penalty=None, dropout=None, internal_act='relu', output_act=None):
+    def make_encoder(self, data, dims, sparseness_penalty=None, dropout=None, internal_act='relu',
+                     output_act=None):
         x = data
         internals = []
         N = len(dims) - 1
@@ -120,38 +127,45 @@ class AutoEncoderModel(model.MXModel):
             x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1])
             if internal_act and i < N-1:
                 x = mx.symbol.Activation(data=x, act_type=internal_act)
-                if internal_act=='sigmoid' and sparseness_penalty:
-                    x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty)
+                if internal_act == 'sigmoid' and sparseness_penalty:
+                    x = mx.symbol.IdentityAttachKLSparseReg(
+                        data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty)
             elif output_act and i == N-1:
                 x = mx.symbol.Activation(data=x, act_type=output_act)
-                if output_act=='sigmoid' and sparseness_penalty:
-                    x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty)
+                if output_act == 'sigmoid' and sparseness_penalty:
+                    x = mx.symbol.IdentityAttachKLSparseReg(
+                        data=x, name='sparse_encoder_%d' % i, penalty=sparseness_penalty)
             if dropout:
                 x = mx.symbol.Dropout(data=x, p=dropout)
             internals.append(x)
         return x, internals
 
-    def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, internal_act='relu', input_act=None):
+    def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None,
+                     internal_act='relu', input_act=None):
         x = feature
         N = len(dims) - 1
         for i in reversed(range(N)):
             x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i])
             if internal_act and i > 0:
                 x = mx.symbol.Activation(data=x, act_type=internal_act)
-                if internal_act=='sigmoid' and sparseness_penalty:
-                    x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty)
+                if internal_act == 'sigmoid' and sparseness_penalty:
+                    x = mx.symbol.IdentityAttachKLSparseReg(
+                        data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty)
             elif input_act and i == 0:
                 x = mx.symbol.Activation(data=x, act_type=input_act)
-                if input_act=='sigmoid' and sparseness_penalty:
-                    x = mx.symbol.IdentityAttachKLSparseReg(data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty)
+                if input_act == 'sigmoid' and sparseness_penalty:
+                    x = mx.symbol.IdentityAttachKLSparseReg(
+                        data=x, name='sparse_decoder_%d' % i, penalty=sparseness_penalty)
             if dropout and i > 0:
                 x = mx.symbol.Dropout(data=x, p=dropout)
         return x
 
-    def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None, print_every=1000):
+    def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay,
+                           lr_scheduler=None, print_every=1000):
         def l2_norm(label, pred):
             return np.mean(np.square(label-pred))/2.0
-        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
+        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate,
+                        lr_scheduler=lr_scheduler)
         solver.set_metric(mx.metric.CustomMetric(l2_norm))
         solver.set_monitor(Monitor(print_every))
         data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True,
@@ -160,18 +174,21 @@ class AutoEncoderModel(model.MXModel):
             if i == 0:
                 data_iter_i = data_iter
             else:
-                X_i = list(model.extract_feature(self.internals[i-1], self.args, self.auxs,
-                                            data_iter, X.shape[0], self.xpu).values())[0]
+                X_i = list(model.extract_feature(
+                    self.internals[i-1], self.args, self.auxs, data_iter, X.shape[0],
+                    self.xpu).values())[0]
                 data_iter_i = mx.io.NDArrayIter({'data': X_i}, batch_size=batch_size,
                                                 last_batch_handle='roll_over')
-            logging.info('Pre-training layer %d...'%i)
-            solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, self.auxs, data_iter_i,
-                         0, n_iter, {}, False)
+            logging.info('Pre-training layer %d...', i)
+            solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, self.auxs,
+                         data_iter_i, 0, n_iter, {}, False)
 
-    def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None, print_every=1000):
+    def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None,
+                 print_every=1000):
         def l2_norm(label, pred):
-           return np.mean(np.square(label-pred))/2.0
-        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
+            return np.mean(np.square(label-pred))/2.0
+        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate,
+                        lr_scheduler=lr_scheduler)
         solver.set_metric(mx.metric.CustomMetric(l2_norm))
         solver.set_monitor(Monitor(print_every))
         data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True,
@@ -184,6 +201,6 @@ class AutoEncoderModel(model.MXModel):
         batch_size = 100
         data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False,
                                       last_batch_handle='pad')
-        Y = list(model.extract_feature(self.loss, self.args, self.auxs, data_iter,
-                                 X.shape[0], self.xpu).values())[0]
+        Y = list(model.extract_feature(
+            self.loss, self.args, self.auxs, data_iter, X.shape[0], self.xpu).values())[0]
         return np.mean(np.square(Y-X))/2.0
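
The InitDesc change above tracks the MXNet 1.0 initializer API: Initializer.__call__ now dispatches on an InitDesc, which carries the parameter's name (and optional attributes), rather than on a bare string key. A minimal sketch of the new calling convention, with parameter names borrowed from this example:

```python
import mxnet as mx

# MXNet 1.0 initializers take an InitDesc, not a plain name string; the
# name it carries ('..._weight', '..._bias') selects the init rule.
init = mx.initializer.Uniform(0.07)

weight = mx.nd.empty((500, 784))
init(mx.initializer.InitDesc('encoder_0_weight'), weight)  # uniform in [-0.07, 0.07]

bias = mx.nd.empty((500,))
init(mx.initializer.InitDesc('encoder_0_bias'), bias)      # biases are zeroed
```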
diff --git a/example/autoencoder/data.py b/example/autoencoder/data.py
index d6a25ed..99dd4eb 100644
--- a/example/autoencoder/data.py
+++ b/example/autoencoder/data.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
 import os
 import numpy as np
 from sklearn.datasets import fetch_mldata
 
+
 def get_mnist():
     np.random.seed(1234) # set seed for deterministic ordering
     data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
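
get_mnist (whose body this diff leaves unchanged) relies on scikit-learn's fetch_mldata, the usual MNIST download path in scikit-learn releases of that era (it was deprecated in 0.20 and later removed). The function body is not shown here, so the following is only a rough sketch of such a loader; the shuffling and pixel scaling are illustrative assumptions:

```python
import os
import numpy as np
from sklearn.datasets import fetch_mldata  # pre-0.20 scikit-learn API

def get_mnist():
    # Rough sketch only; the shuffle and scaling details are assumptions.
    np.random.seed(1234)  # set seed for deterministic ordering
    data_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    mnist = fetch_mldata('MNIST original', data_home=data_path)
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p].astype(np.float32) / 255.0  # scale pixels to [0, 1]
    Y = mnist.target[p]
    return X, Y
```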
diff --git a/example/autoencoder/mnist_sae.py b/example/autoencoder/mnist_sae.py
index 8ab4b0b..886f2a1 100644
--- a/example/autoencoder/mnist_sae.py
+++ b/example/autoencoder/mnist_sae.py
@@ -15,31 +15,38 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# pylint: disable=missing-docstring
 from __future__ import print_function
+
 import argparse
+import logging
+
 import mxnet as mx
 import numpy as np
-import logging
 import data
 from autoencoder import AutoEncoderModel
 
-parser = argparse.ArgumentParser(description='Train an auto-encoder model for mnist dataset.')
+parser = argparse.ArgumentParser(description='Train an auto-encoder model for mnist dataset.',
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 parser.add_argument('--print-every', type=int, default=1000,
-                    help='the interval of printing during training.')
+                    help='interval of printing during training.')
 parser.add_argument('--batch-size', type=int, default=256,
-                    help='the batch size used for training.')
+                    help='batch size used for training.')
 parser.add_argument('--pretrain-num-iter', type=int, default=50000,
-                    help='the number of iterations for pretraining.')
+                    help='number of iterations for pretraining.')
 parser.add_argument('--finetune-num-iter', type=int, default=100000,
-                    help='the number of iterations for fine-tuning.')
+                    help='number of iterations for fine-tuning.')
 parser.add_argument('--visualize', action='store_true',
                     help='whether to visualize the original image and the reconstructed one.')
 parser.add_argument('--num-units', type=str, default="784,500,500,2000,10",
-                    help='the number of hidden units for the layers of the encoder.' \
-                         'The decoder layers are created in the reverse order.')
+                    help='number of hidden units for the layers of the encoder.'
+                         'The decoder layers are created in the reverse order. First dimension '
+                         'must be 784 (28x28) to match mnist image dimension.')
+parser.add_argument('--gpu', action='store_true',
+                    help='whether to start training on GPU.')
 
-# set to INFO to see less information during training
-logging.basicConfig(level=logging.DEBUG)
+# set to DEBUG to see more information during training
+logging.basicConfig(level=logging.INFO)
 opt = parser.parse_args()
 logging.info(opt)
 print_every = opt.print_every
@@ -47,21 +54,26 @@ batch_size = opt.batch_size
 pretrain_num_iter = opt.pretrain_num_iter
 finetune_num_iter = opt.finetune_num_iter
 visualize = opt.visualize
+gpu = opt.gpu
 layers = [int(i) for i in opt.num_units.split(',')]
 
+
 if __name__ == '__main__':
-    ae_model = AutoEncoderModel(mx.cpu(0), layers, pt_dropout=0.2,
-        internal_act='relu', output_act='relu')
+    xpu = mx.gpu() if gpu else mx.cpu()
+    print("Training on {}".format("GPU" if gpu else "CPU"))
+
+    ae_model = AutoEncoderModel(xpu, layers, pt_dropout=0.2, internal_act='relu',
+                                output_act='relu')
 
     X, _ = data.get_mnist()
     train_X = X[:60000]
     val_X = X[60000:]
 
     ae_model.layerwise_pretrain(train_X, batch_size, pretrain_num_iter, 'sgd', l_rate=0.1,
-                                decay=0.0, lr_scheduler=mx.misc.FactorScheduler(20000,0.1),
+                                decay=0.0, lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1),
                                 print_every=print_every)
     ae_model.finetune(train_X, batch_size, finetune_num_iter, 'sgd', l_rate=0.1, decay=0.0,
-                      lr_scheduler=mx.misc.FactorScheduler(20000,0.1), print_every=print_every)
+                      lr_scheduler=mx.lr_scheduler.FactorScheduler(20000, 0.1), print_every=print_every)
     ae_model.save('mnist_pt.arg')
     ae_model.load('mnist_pt.arg')
     print("Training error:", ae_model.eval(train_X))
@@ -79,7 +91,7 @@ if __name__ == '__main__':
                                                   ae_model.auxs, data_iter, 1,
                                                   ae_model.xpu).values()[0]
             print("original image")
-            plt.imshow(original_image.reshape((28,28)))
+            plt.imshow(original_image.reshape((28, 28)))
             plt.show()
             print("reconstructed image")
             plt.imshow(reconstructed_image.reshape((28, 28)))
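
The scheduler change above replaces mx.misc.FactorScheduler with mx.lr_scheduler.FactorScheduler, where the class lives as of MXNet 1.0. FactorScheduler multiplies the learning rate by factor every step updates, i.e. lr(t) = base_lr * factor^floor(t / step). A small sketch of its behavior:

```python
import mxnet as mx

# The scheduler is stateful, so query it with increasing update counts.
scheduler = mx.lr_scheduler.FactorScheduler(step=20000, factor=0.1)
scheduler.base_lr = 0.1  # normally set from the optimizer's learning_rate

print(scheduler(1))      # 0.1
print(scheduler(20001))  # 0.01, after the first decay
print(scheduler(40001))  # 0.001, after the second
```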
diff --git a/example/autoencoder/model.py b/example/autoencoder/model.py
index 1aaae1b..c1b7221 100644
--- a/example/autoencoder/model.py
+++ b/example/autoencoder/model.py
@@ -15,15 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: skip-file
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
 import mxnet as mx
 import numpy as np
-import logging
-from solver import Solver, Monitor
 try:
-   import cPickle as pickle
-except:
-   import pickle
+    import cPickle as pickle
+except ImportError:
+    import pickle
 
 
 def extract_feature(sym, args, auxs, data_iter, N, xpu=mx.cpu()):
@@ -31,7 +31,7 @@ def extract_feature(sym, args, auxs, data_iter, N, xpu=mx.cpu()):
     input_names = [k for k, shape in data_iter.provide_data]
     args = dict(args, **dict(zip(input_names, input_buffs)))
     exe = sym.bind(xpu, args=args, aux_states=auxs)
-    outputs = [[] for i in exe.outputs]
+    outputs = [[] for _ in exe.outputs]
     output_buffs = None
 
     data_iter.hard_reset()
@@ -51,6 +51,7 @@ def extract_feature(sym, args, auxs, data_iter, N, xpu=mx.cpu()):
     outputs = [np.concatenate(i, axis=0)[:N] for i in outputs]
     return dict(zip(sym.list_outputs(), outputs))
 
+
 class MXModel(object):
     def __init__(self, xpu=mx.cpu(), *args, **kwargs):
         self.xpu = xpu
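
model.extract_feature binds the given symbol once, streams every batch from the iterator through it, and returns a dict mapping each output name to the first N rows of the concatenated outputs. A self-contained usage sketch follows; the single fully-connected layer with random weights is a stand-in for a trained encoder:

```python
import mxnet as mx
import numpy as np
import model  # the example's model.py

# Stand-in for a trained encoder: one fully-connected layer, random weights.
sym = mx.symbol.FullyConnected(data=mx.symbol.Variable('data'),
                               name='enc', num_hidden=10)
args = {'enc_weight': mx.nd.random.uniform(-0.07, 0.07, (10, 784)),
        'enc_bias': mx.nd.zeros((10,))}

X = np.random.rand(100, 784).astype(np.float32)
data_iter = mx.io.NDArrayIter({'data': X}, batch_size=25, shuffle=False,
                              last_batch_handle='pad')
feats = model.extract_feature(sym, args, {}, data_iter, X.shape[0], mx.cpu())
print(feats['enc_output'].shape)  # (100, 10)
```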
diff --git a/example/autoencoder/solver.py b/example/autoencoder/solver.py
index 69d8836..0c990ce 100644
--- a/example/autoencoder/solver.py
+++ b/example/autoencoder/solver.py
@@ -15,10 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: skip-file
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+import logging
+
 import mxnet as mx
 import numpy as np
-import logging
+
 
 class Monitor(object):
     def __init__(self, interval, level=logging.DEBUG, stat=None):
@@ -32,19 +36,23 @@ class Monitor(object):
             self.stat = stat
 
     def forward_end(self, i, internals):
-        if i%self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
+        if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
             for key in sorted(internals.keys()):
                 arr = internals[key]
-                logging.log(self.level, 'Iter:%d  param:%s\t\tstat(%s):%s'%(i, key, self.stat.__name__, str(self.stat(arr.asnumpy()))))
+                logging.log(self.level, 'Iter:%d  param:%s\t\tstat(%s):%s',
+                            i, key, self.stat.__name__, str(self.stat(arr.asnumpy())))
 
     def backward_end(self, i, weights, grads, metric=None):
-        if i%self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
+        if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
             for key in sorted(grads.keys()):
                 arr = grads[key]
-                logging.log(self.level, 'Iter:%d  param:%s\t\tstat(%s):%s\t\tgrad_stat:%s'%(i, key, self.stat.__name__, str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy()))))
-        if i%self.interval == 0 and metric is not None:
-                logging.log(logging.INFO, 'Iter:%d metric:%f'%(i, metric.get()[1]))
-                metric.reset()
+                logging.log(self.level, 'Iter:%d  param:%s\t\tstat(%s):%s\t\tgrad_stat:%s',
+                            i, key, self.stat.__name__,
+                            str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy())))
+        if i % self.interval == 0 and metric is not None:
+            logging.log(logging.INFO, 'Iter:%d metric:%f', i, metric.get()[1])
+            metric.reset()
+
 
 class Solver(object):
     def __init__(self, optimizer, **kwargs):
@@ -71,7 +79,9 @@ class Solver(object):
         self.iter_start_callback = callback
 
     def solve(self, xpu, sym, args, args_grad, auxs,
-              data_iter, begin_iter, end_iter, args_lrmult={}, debug = False):
+              data_iter, begin_iter, end_iter, args_lrmult=None, debug=False):
+        if args_lrmult is None:
+            args_lrmult = dict()
         input_desc = data_iter.provide_data + data_iter.provide_label
         input_names = [k for k, shape in input_desc]
         input_buffs = [mx.nd.empty(shape, ctx=xpu) for k, shape in input_desc]
@@ -79,20 +89,19 @@ class Solver(object):
 
         output_names = sym.list_outputs()
         if debug:
-            sym = sym.get_internals()
-            blob_names = sym.list_outputs()
             sym_group = []
-            for i in range(len(blob_names)):
-                if blob_names[i] not in args:
-                    x = sym[i]
-                    if blob_names[i] not in output_names:
-                        x = mx.symbol.BlockGrad(x, name=blob_names[i])
+            for x in sym.get_internals():
+                if x.name not in args:
+                    if x.name not in output_names:
+                        x = mx.symbol.BlockGrad(x, name=x.name)
                     sym_group.append(x)
             sym = mx.symbol.Group(sym_group)
         exe = sym.bind(xpu, args=args, args_grad=args_grad, aux_states=auxs)
 
         assert len(sym.list_arguments()) == len(exe.grad_arrays)
-        update_dict = {name: nd for name, nd in zip(sym.list_arguments(), exe.grad_arrays) if nd is not None}
+        update_dict = {
+            name: nd for name, nd in zip(sym.list_arguments(), exe.grad_arrays) if nd is not None
+        }
         batch_size = input_buffs[0].shape[0]
         self.optimizer.rescale_grad = 1.0/batch_size
         self.optimizer.set_lr_mult(args_lrmult)
@@ -114,7 +123,7 @@ class Solver(object):
                     return
             try:
                 batch = data_iter.next()
-            except:
+            except StopIteration:
                 data_iter.reset()
                 batch = data_iter.next()
             for data, buff in zip(batch.data+batch.label, input_buffs):
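
The rewritten debug branch in solve ("Simplified code for extracting debug symbols from graph" in the commit message) leans on two Symbol conveniences: a grouped symbol is directly iterable, and each internal symbol reports its node name. A standalone sketch of the same pattern on a toy two-layer graph:

```python
import mxnet as mx

# Expose every non-argument internal node as an extra output, wrapped in
# BlockGrad so no gradient flows through the added debug outputs.
data = mx.symbol.Variable('data')
fc = mx.symbol.FullyConnected(data=data, name='fc', num_hidden=8)
net = mx.symbol.SoftmaxOutput(data=fc, name='sm')

args = net.list_arguments()        # ['data', 'fc_weight', 'fc_bias', 'sm_label']
output_names = net.list_outputs()  # ['sm_output']
sym_group = []
for x in net.get_internals():      # a grouped Symbol iterates over its outputs
    if x.name not in args:
        if x.name not in output_names:
            x = mx.symbol.BlockGrad(x, name=x.name)
        sym_group.append(x)
debug_sym = mx.symbol.Group(sym_group)
print(debug_sym.list_outputs())    # ['fc_output', 'sm_output']
```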
