Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/10 18:44:18 UTC

[GitHub] sandeep-krishnamurthy closed pull request #12376: [MXNET-854] SVRG Optimization in Python Module API

sandeep-krishnamurthy closed pull request #12376: [MXNET-854] SVRG Optimization in Python Module API
URL: https://github.com/apache/incubator-mxnet/pull/12376
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/docs/api/python/contrib/svrg_optimization.md b/docs/api/python/contrib/svrg_optimization.md
new file mode 100644
index 00000000000..e6e1c3e23ee
--- /dev/null
+++ b/docs/api/python/contrib/svrg_optimization.md
@@ -0,0 +1,86 @@
+# SVRG Optimization in Python Module API
+
+## Overview
+SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique first introduced in the 
+paper _Accelerating Stochastic Gradient Descent using Predictive Variance Reduction_ (2013). It complements SGD 
+(Stochastic Gradient Descent), which scales well to large optimization problems but converges slowly asymptotically 
+because of its inherent variance. SGD approximates the full gradient using a small batch of data or a single data 
+sample. This introduces variance and therefore requires a small learning rate to ensure convergence. SVRG remedies 
+the problem by keeping a snapshot of the weights that is close to the optimal parameter values and maintaining the 
+average of the full gradients over a full pass of the data. This average is calculated with respect to the snapshot 
+weights, which are refreshed every m epochs during training. SVRG uses a different update rule: the gradient w.r.t. 
+the current parameter values, minus the gradient w.r.t. the snapshot parameters, plus the average of the full 
+gradients over all data. 
+  
+Key Characteristics of SVRG:
+* Employs explicit variance reduction through an update rule that differs from that of SGD.
+* Allows a relatively large learning rate, which leads to faster convergence than SGD.
+* Guarantees fast convergence for smooth and strongly convex functions.
+
+SVRG optimization is implemented as an SVRGModule in `mxnet.contrib.svrg_optimization`, which extends the existing 
+`mxnet.module.Module` API and encapsulates the SVRG optimization logic in several new functions. For end users, the 
+SVRGModule API differs only minimally from the Module API. 
+
+In distributed training, each worker receives the same snapshot weights (refreshed every m epochs) and calculates the 
+full gradients with respect to its own shard of data. Standard SVRG optimization requires a global full gradient, 
+which is calculated by aggregating the full gradients from all workers and averaging over the number of workers. 
+To support this, the implementation keeps an additional set of keys in the KVStore that map to the full gradients. 
+The `_SVRGOptimizer` wraps two optimizers: an `_AssignmentOptimizer`, which accumulates the full gradients in the 
+KVStore, and a regular optimizer, which applies the actual update rule to the parameters. 
+The `_SVRGOptimizer` and `_AssignmentOptimizer` are designed to be used only within `SVRGModule`.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+``` 
+
+This document lists the SVRGModule APIs in the MXNet contrib package:
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    mxnet.contrib.svrg_optimization.svrg_module
+```
+
+### Intermediate Level API for SVRGModule
+
+Compared with a Module, the only extra step when using an SVRGModule is to check whether the full gradients over all 
+data should be updated in the current epoch. The code snippet below demonstrates the suggested usage of SVRGModule 
+with the intermediate-level API.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+>>> mod.init_params()
+>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), kvstore='local')
+>>> for epoch in range(num_epochs):
+...     if epoch % mod.update_freq == 0:
+...         mod.update_full_grads(di)
+...     di.reset()
+...     for batch in di:
+...         mod.forward_backward(data_batch=batch)
+...         mod.update()
+```
+
+### High Level API for SVRGModule
+
+The high-level API usage of SVRGModule is exactly the same as that of the Module API. The code snippet below gives an 
+example of the suggested high-level API usage.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.fit(di, num_epoch=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
+```
+
+## API reference
+
+<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+
+.. automodule:: mxnet.contrib.svrg_optimization.svrg_module
+.. autoclass:: mxnet.contrib.svrg_optimization.svrg_module.SVRGModule
+    :members: init_optimizer, bind, forward, backward, reshape, update, update_full_grads, fit, prepare
+ 
+```
+<script>auto_index("api-reference");</script>
\ No newline at end of file
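The update rule described in the overview of `svrg_optimization.md` can be written as `w <- w - lr * (grad_i(w) - grad_i(w_snap) + mu)`, where `w_snap` is the snapshot refreshed every `update_freq` epochs and `mu` is the average of the full gradients at `w_snap`. The NumPy sketch below applies that rule to a small least-squares problem with a constant learning rate. It is an illustration under stated assumptions (one sample per inner step, no bias term, synthetic noisy data), not the SVRGModule code, and `svrg_least_squares` is a hypothetical name.

```python
import numpy as np


def svrg_least_squares(X, y, lr=0.025, num_epochs=50, update_freq=2, seed=0):
    """Minimal SVRG sketch for f(w) = mean_i 0.5 * (x_i . w - y_i) ** 2.

    Plain NumPy with one sample per inner step; an illustration of the
    update rule, not the SVRGModule implementation.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)

    def grad_i(w_, i):
        # Gradient of the i-th sample's loss 0.5 * (x_i . w - y_i) ** 2.
        return (X[i] @ w_ - y[i]) * X[i]

    for epoch in range(num_epochs):
        if epoch % update_freq == 0:
            # Refresh the snapshot weights and the average of the full gradients.
            w_snap = w.copy()
            mu = X.T @ (X @ w_snap - y) / n
        for i in rng.permutation(n):
            # Current gradient - snapshot gradient + full-gradient average.
            w = w - lr * (grad_i(w, i) - grad_i(w_snap, i) + mu)
    return w


# Synthetic noisy data (illustrative only): y = X . [1, 2] + noise.
data_rng = np.random.default_rng(42)
X = data_rng.normal(size=(200, 2))
y = X @ np.array([1.0, 2.0]) + 0.1 * data_rng.normal(size=200)

w_svrg = svrg_least_squares(X, y)
w_exact = np.linalg.lstsq(X, y, rcond=None)[0]  # closed-form least-squares optimum
print("SVRG:", w_svrg, "least squares:", w_exact)
```

With noisy labels, plain SGD at the same constant rate would keep fluctuating around the optimum, whereas the variance-reduced update settles on the least-squares solution; this is the property that lets SVRG keep a relatively large learning rate.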
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 42c4af9e46b..15d1045a93e 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -52,6 +52,7 @@ Code examples are placed throughout the API documentation and these can be run a
    contrib/contrib.md
    contrib/text.md
    contrib/onnx.md
+   contrib/svrg_optimization.md
 ```
 
 ## Gluon API
@@ -176,4 +177,4 @@ Code examples are placed throughout the API documentation and these can be run a
    :maxdepth: 1
 
    symbol_in_pictures/symbol_in_pictures.md
-```
+```
\ No newline at end of file
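The distributed-training paragraph of `svrg_optimization.md` says the global full gradient is the average of the per-worker full gradients, each computed on the worker's own shard at the shared snapshot weights. The sketch below simulates that with four in-memory shards in NumPy; `full_grad` and `global_full_grad` are hypothetical helpers, and no KVStore is involved.

```python
import numpy as np


def full_grad(X, y, w):
    # Average gradient of 0.5 * (x_i . w - y_i) ** 2 over the rows of X.
    return X.T @ (X @ w - y) / len(y)


def global_full_grad(shards, w_snap):
    # Each (X_k, y_k) pair plays the role of one worker's data shard. Every
    # worker computes its full gradient at the shared snapshot weights, and the
    # results are averaged over the number of workers.
    per_worker = [full_grad(X_k, y_k, w_snap) for X_k, y_k in shards]
    return np.mean(per_worker, axis=0)


rng = np.random.default_rng(0)
X = rng.normal(size=(400, 3))
y = rng.normal(size=400)
w_snap = rng.normal(size=3)

# Four simulated workers with equal-sized shards of 100 rows each.
shards = list(zip(np.split(X, 4), np.split(y, 4)))
g_global = global_full_grad(shards, w_snap)
g_single = full_grad(X, y, w_snap)  # full gradient computed on one machine
print(np.allclose(g_global, g_single))  # True for equal-sized shards
```

The plain average matches the single-machine full gradient only because the shards have equal sizes; with unequal shards, each worker's gradient would need to be weighted by its shard size.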
diff --git a/docs/api/python/module/module.md b/docs/api/python/module/module.md
index 86ed74db6c1..5a874ac6df0 100644
--- a/docs/api/python/module/module.md
+++ b/docs/api/python/module/module.md
@@ -207,4 +207,4 @@ additional functionality. We summarize them in this section.
     :members:
 ```
 
-<script>auto_index("api-reference");</script>
+<script>auto_index("api-reference");</script>
\ No newline at end of file
diff --git a/example/svrg_module/README.md b/example/svrg_module/README.md
new file mode 100644
index 00000000000..63e7ba2f2bf
--- /dev/null
+++ b/example/svrg_module/README.md
@@ -0,0 +1,33 @@
+## SVRGModule Example
+SVRGModule is an extension to the Module API that implements SVRG (Stochastic Variance Reduced Gradient)
+optimization. SVRG is an optimization technique that complements SGD and has several key
+properties: 
+* Employs explicit variance reduction through an update rule that differs from that of SGD.
+* Allows a relatively large learning rate, which leads to faster convergence than SGD.
+* Guarantees fast convergence for smooth and strongly convex functions.
+
+#### API Usage Example
+SVRGModule provides both high-level and intermediate-level APIs, with minimal changes relative to the Module API.
+* example_api_train.py: demonstrates the suggested usage of the SVRGModule high-level and intermediate-level APIs.
+* example_inference.py: demonstrates the usage of SVRGModule for inference.
+
+#### Linear Regression 
+This example trains a linear regression model using SVRGModule on a real dataset, YearPredictionMSD. 
+Training logs are written to experiments.log, which is generated automatically when the training script runs.
+
+##### Dataset
+YearPredictionMSD: used to predict the release year of a song from its audio features. It has over 
+400,000 samples with 90 features. To download the data, uncomment the data-downloading code in data_reader.py. 
+
+#### Benchmarks:
+An initial set of benchmarks has been run on the YearPredictionMSD dataset with a linear regression model. A Jupyter 
+notebook under `/benchmarks` demonstrates the training process and plots two graphs for benchmarking.
+
+* benchmark1: SVRG and SGD are both trained with an lr_scheduler, which returns a new learning rate based on the number of updates that have been performed. 
+
+* benchmark2: One drawback of SGD is that its learning rate has to decay to zero for it to converge, so SGD either
+needs a decay schedule or has to start with a small learning rate. The learning rate does not need to decay to zero
+for SVRG, so a relatively larger learning rate can be used. SGD with learning rates of 0.001, 0.0025, and 0.005 and
+SVRG with a learning rate of 0.025 are benchmarked. Even though SVRG starts with a relatively large learning rate,
+it converges much faster than SGD in all cases.  
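benchmark1 above relies on a learning-rate scheduler; the notebook passes `mx.lr_scheduler.FactorScheduler(step=10, factor=0.99)` to both SVRG and SGD. The hypothetical `step_decay_lr` helper below sketches the idea of such a factor schedule in plain Python: the rate is multiplied by `factor` once every `step` updates. It is not MXNet's implementation, its behavior exactly at multiples of `step` may differ, and the base rate of 0.01 is an assumption.

```python
def step_decay_lr(base_lr, factor, step, num_update):
    """Hypothetical helper: multiply base_lr by factor once per `step` updates.

    Sketches the idea behind a factor scheduler; not MXNet's FactorScheduler.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return base_lr * factor ** (num_update // step)


# With step=10 and factor=0.99 (the notebook's settings), the rate shrinks by
# 1% every 10 updates, starting from an assumed base rate of 0.01.
for n in (0, 10, 100, 1000):
    print(n, step_decay_lr(0.01, 0.99, 10, n))
```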
diff --git a/example/svrg_module/api_usage_example/example_api_train.py b/example/svrg_module/api_usage_example/example_api_train.py
new file mode 100644
index 00000000000..f6cd1b2e592
--- /dev/null
+++ b/example/svrg_module/api_usage_example/example_api_train.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import numpy as np
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def test_svrg_intermediate_level_api(args):
+    """Demonstrates the intermediate-level SVRGModule API, where the training loop
+    needs to be defined explicitly. The KVStore is created explicitly.
+
+    Parameters
+    ----------
+    args: args
+        Command line arguments
+    """
+    num_epoch = args.epochs
+    batch_size = args.batch_size
+    update_freq = args.update_freq
+
+    di, mod = create_network(batch_size, update_freq)
+
+    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+    kv = mx.kv.create("local")
+    mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
+    metrics = mx.metric.create("mse")
+    for e in range(num_epoch):
+        metrics.reset()
+        if e % mod.update_freq == 0:
+            mod.update_full_grads(di)
+        di.reset()
+        for batch in di:
+            mod.forward_backward(data_batch=batch)
+            mod.update()
+            mod.update_metric(metrics, batch.label)
+        mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1])
+
+
+def test_svrg_high_level_api(args):
+    """Demonstrates suggested usage of the high-level SVRGModule API. The KVStore is created implicitly from 'local'.
+
+    Parameters
+    ----------
+    args: args
+        Command line arguments
+    """
+    num_epoch = args.epochs
+    batch_size = args.batch_size
+    update_freq = args.update_freq
+
+    di, mod = create_network(batch_size, update_freq)
+    mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
+            kvstore='local')
+
+
+def create_network(batch_size, update_freq):
+    """Create a linear regression network for performing SVRG optimization.
+    Parameters
+    ----------
+    batch_size: int
+        Size of data split
+    update_freq: int
+        Update Frequency for calculating full gradients
+
+    Returns
+    ----------
+    di: mx.io.NDArrayIter
+        Data iterator
+    mod: SVRGModule
+        An instance of SVRGModule for performing SVRG optimization
+    """
+    import logging
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+
+    train_data = np.random.randint(1, 5, [1000, 2])
+    weights = np.array([1.0, 2.0])
+    train_label = train_data.dot(weights)
+
+    di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+    X = mx.sym.Variable('data')
+    Y = mx.symbol.Variable('lin_reg_label')
+    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+    mod = SVRGModule(
+        symbol=lro,
+        data_names=['data'],
+        label_names=['lin_reg_label'], update_freq=update_freq, logger=logging
+    )
+
+    return di, mod
+
+# run as a script
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-e', dest='epochs', default=100, type=int)
+    parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+    parser.add_argument('-f', dest="update_freq", default=2, type=int)
+    args = parser.parse_args()
+
+    print("========================== Intermediate Level API ==========================")
+    test_svrg_intermediate_level_api(args)
+    print("========================== High Level API ==========================")
+    test_svrg_high_level_api(args)
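Because `create_network()` in `example_api_train.py` builds noiseless labels `train_label = train_data.dot([1.0, 2.0])`, the least-squares optimum is known in advance. The sketch below solves the same problem in closed form with NumPy as a reference point for what `fc1` should approach after SVRG training: weights near [1, 2] and a bias near 0. It is a sanity check under that assumption, not part of the example script; the fixed seed is added for repeatability.

```python
import numpy as np

# Same synthetic data as create_network(): integer features in [1, 5) and
# noiseless labels y = X . [1.0, 2.0]. The seed is an addition for repeatability.
np.random.seed(0)
train_data = np.random.randint(1, 5, [1000, 2]).astype(np.float64)
weights = np.array([1.0, 2.0])
train_label = train_data.dot(weights)

# Append a column of ones so the solve also estimates a bias, as fc1 has one.
design = np.hstack([train_data, np.ones((train_data.shape[0], 1))])
coef = np.linalg.lstsq(design, train_label, rcond=None)[0]
w_fit, b_fit = coef[:2], coef[2]
print("weights:", w_fit, "bias:", b_fit)
```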
diff --git a/example/svrg_module/api_usage_example/example_inference.py b/example/svrg_module/api_usage_example/example_inference.py
new file mode 100644
index 00000000000..312f9796074
--- /dev/null
+++ b/example/svrg_module/api_usage_example/example_inference.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import numpy as np
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def test_svrg_inference(args):
+    epoch = args.epochs
+    batch_size = args.batch_size
+    update_freq = args.update_freq
+
+    train_iter, val_iter, mod = create_network(batch_size, update_freq)
+    mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd',
+            optimizer_params=(('learning_rate', 0.025),),
+            num_epoch=epoch)
+
+
+def get_validation_score(args):
+    epoch = args.epochs
+    batch_size = args.batch_size
+    update_freq = args.update_freq
+
+    train_iter, val_iter,  mod = create_network(batch_size, update_freq)
+    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
+    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+    mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
+    metrics = mx.metric.create("mse")
+    for e in range(epoch):
+        metrics.reset()
+        if e % mod.update_freq == 0:
+            mod.update_full_grads(train_iter)
+        train_iter.reset()
+        for batch in train_iter:
+            mod.forward_backward(data_batch=batch)
+            mod.update()
+            mod.update_metric(metrics, batch.label)
+
+    y = mod.predict(val_iter)
+
+    # test-train data split, 20% test data out of 1000 data samples
+    assert y.shape == (200, 1)
+    score = mod.score(val_iter, ['mse'])
+    print("Mean squared error on the validation set is {}".format(score[0][1]))
+
+
+def create_network(batch_size, update_freq):
+    """Create a linear regression network for performing SVRG optimization.
+    :return: a training iterator (mx.io.NDArrayIter), a validation iterator (mx.io.NDArrayIter),
+        and an instance of SVRGModule for performing SVRG optimization
+    """
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+    data = np.random.randint(1, 5, [1000, 2])
+
+    # Train/test data split: 80% training, 20% validation
+    n_train = int(data.shape[0] * 0.8)
+    weights = np.array([1.0, 2.0])
+    label = data.dot(weights)
+
+    di = mx.io.NDArrayIter(data[:n_train, :], label[:n_train], batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+    val_iter = mx.io.NDArrayIter(data[n_train:, :], label[n_train:], batch_size=batch_size, label_name='lin_reg_label')
+
+    X = mx.sym.Variable('data')
+    Y = mx.symbol.Variable('lin_reg_label')
+    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+    mod = SVRGModule(
+        symbol=lro,
+        data_names=['data'],
+        label_names=['lin_reg_label'], update_freq=update_freq, logger=logging)
+
+    return di, val_iter, mod
+
+
+# run as a script
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-e', dest='epochs', default=100, type=int)
+    parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+    parser.add_argument('-f', dest="update_freq", default=2, type=int)
+    args = parser.parse_args()
+
+    print("========================== SVRG Module Inference ==========================")
+    test_svrg_inference(args)
+    print("========================SVRG Module Score ============================")
+    get_validation_score(args)
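The `assert y.shape == (200, 1)` in `get_validation_score()` depends on the 80/20 split leaving 200 validation samples, which is not a multiple of the default `batch_size` of 32. Assuming `NDArrayIter`'s default `last_batch_handle='pad'` pads the final batch and `predict` trims that padding from its output, the batch arithmetic works out as sketched below; `batch_layout` is a hypothetical helper.

```python
import math


def batch_layout(num_samples, batch_size):
    """Hypothetical helper: number of batches and padded samples for a 'pad' iterator."""
    num_batches = math.ceil(num_samples / batch_size)
    pad = num_batches * batch_size - num_samples
    return num_batches, pad


n_total = 1000
n_train = int(n_total * 0.8)  # same 80/20 split as create_network()
n_val = n_total - n_train
print(batch_layout(n_train, 32), batch_layout(n_val, 32))
```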
diff --git a/example/svrg_module/benchmarks/svrg_benchmark.ipynb b/example/svrg_module/benchmarks/svrg_benchmark.ipynb
new file mode 100644
index 00000000000..db02938af46
--- /dev/null
+++ b/example/svrg_module/benchmarks/svrg_benchmark.ipynb
@@ -0,0 +1,379 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Linear Regression Using SVRGModule on YearPredictionMSD Dataset"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In this notebook, a linear regression model will be fit on the YearPredictionMSD dataset, which is used to predict the release year of a song from its audio features. The dataset has 90 features and over 400,000 samples; it is downsampled to 5,000 samples in this experiment."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import mxnet as mx\n",
+    "from sklearn.datasets import load_svmlight_file\n",
+    "import numpy as np\n",
+    "import json\n",
+    "import tempfile\n",
+    "import os\n",
+    "from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Read Data\n",
+    "The first step is to get the training features and labels and normalize the data. In this example, we will use 5000 data samples.  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Download data file\n",
+    "# from subprocess import call\n",
+    "# YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd\n",
+    "# call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])\n",
+    "# call(['bzip2', '-d', 'YearPredictionMSD.bz2'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Reading data from disk...\n"
+     ]
+    }
+   ],
+   "source": [
+    "feature_dim = 90\n",
+    "print(\"Reading data from disk...\")\n",
+    "train_features, train_labels = load_svmlight_file('YearPredictionMSD', n_features=feature_dim, dtype=np.float32)\n",
+    "train_features = train_features.todense()\n",
+    "\n",
+    "# normalize the data: subtract means and divide by standard deviations\n",
+    "label_mean = train_labels.mean()\n",
+    "label_std = np.sqrt(np.square(train_labels - label_mean).mean())\n",
+    "feature_means = train_features.mean(axis=0)\n",
+    "feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0))\n",
+    "\n",
+    "train_features = (train_features - feature_means) / feature_stds\n",
+    "train_labels = (train_labels - label_mean) / label_std\n",
+    "\n",
+    "train_features = train_features[-5000:]\n",
+    "train_labels = train_labels[-5000:]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Create Linear Regression Network"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_lin_reg_network(batch_size=100):\n",
+    "    train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True,\n",
+    "                               data_name='data', label_name='label')\n",
+    "    data = mx.sym.Variable(\"data\")\n",
+    "    label = mx.sym.Variable(\"label\")\n",
+    "    weight = mx.sym.Variable(\"fc_weight\", shape=(1, 90))\n",
+    "    net = mx.sym.dot(data, weight.transpose())\n",
+    "    bias = mx.sym.Variable(\"fc_bias\", shape=(1,), wd_mult=0.0, lr_mult=10.0)\n",
+    "    net = mx.sym.broadcast_plus(net, bias)\n",
+    "    net = mx.sym.LinearRegressionOutput(data=net, label=label)\n",
+    "    \n",
+    "    return train_iter, net"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### SVRGModule with SVRG Optimization\n",
+    "In this example, we will use the intermediate-level API for SVRGModule and dump the MSE per epoch to a JSON file for plotting graphs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_svrg_lin_reg(num_epoch=100, batch_size=100, update_freq=2, output='svrg_lr.json', \n",
+    "                       optimizer_params=None):\n",
+    "\n",
+    "    di, net = create_lin_reg_network(batch_size=batch_size)\n",
+    "    \n",
+    "    #Create a SVRGModule\n",
+    "    mod = SVRGModule(symbol=net, context=mx.cpu(0), data_names=['data'], label_names=['label'], update_freq=update_freq)\n",
+    "    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)\n",
+    "    mod.init_params(initializer=mx.init.Zero(), allow_missing=False, force_init=False, allow_extra=False)\n",
+    "    mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=optimizer_params)\n",
+    "    metrics = mx.metric.create(\"mse\")\n",
+    "    \n",
+    "    results = {}\n",
+    "    for e in range(num_epoch):\n",
+    "        results[e] = {}\n",
+    "        metrics.reset()\n",
+    "        if e % mod.update_freq == 0:\n",
+    "            mod.update_full_grads(di)\n",
+    "        di.reset()\n",
+    "        for batch in di:\n",
+    "            mod.forward_backward(data_batch=batch)\n",
+    "            mod.update()\n",
+    "            mod.update_metric(metrics, batch.label)\n",
+    "        results[e][\"mse\"] = metrics.get()[1]\n",
+    "   \n",
+    "    f = open(output, 'w+')\n",
+    "    f.write(json.dumps(results, indent=4, sort_keys=True))\n",
+    "    f.close()\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Module with SGD Optimization "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train_sgd_lin_reg(num_epoch=100, batch_size=100, update_freq=2, output='sgd_lr.json', \n",
+    "                       optimizer_params=None):\n",
+    "    \n",
+    "    di, net = create_lin_reg_network(batch_size=batch_size)\n",
+    "    \n",
+    "    #Create a standard module\n",
+    "    mod = mx.mod.Module(symbol=net, context=mx.cpu(0), data_names=['data'], label_names=['label'])\n",
+    "    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)\n",
+    "    mod.init_params(initializer=mx.init.Zero(), allow_missing=False, force_init=False, allow_extra=False)\n",
+    "    mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=optimizer_params)\n",
+    "    metrics = mx.metric.create(\"mse\")\n",
+    "    \n",
+    "    results = {}\n",
+    "    for e in range(num_epoch):\n",
+    "        results[e] = {}\n",
+    "        metrics.reset()\n",
+    "        di.reset()\n",
+    "        for batch in di:\n",
+    "            mod.forward_backward(data_batch=batch)\n",
+    "            mod.update()\n",
+    "            mod.update_metric(metrics, batch.label)\n",
+    "        results[e][\"mse\"] = metrics.get()[1]\n",
+    "    f = open(output, 'w+')\n",
+    "    f.write(json.dumps(results, indent=4, sort_keys=True))\n",
+    "    f.close()\n",
+    "  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import seaborn as sns\n",
+    "import matplotlib.pyplot as plt\n",
+    "import matplotlib.patches as mpatches\n",
+    "import pandas as pd"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training Loss over 100 Epochs Using lr_scheduler\n",
+    "When a large learning rate is used with SGD, the training loss drops fast but oscillates above the minimum and never converges. With a small learning rate, it eventually reaches the minimum after many iterations. A common practice is to use learning rate scheduling: start with a large learning rate and gradually decrease it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_svrg_lin_reg(optimizer_params={'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10, factor=0.99)})\n",
+    "train_sgd_lin_reg(optimizer_params={'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10, factor=0.99)})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5,0,'Epochs')"
+      ]
+     },
+     "execution_count": 32,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 1440x864 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# plot graph\n",
+    "#Plot training loss over Epochs:\n",
+    "color = sns.color_palette()\n",
+    "# Collect the per-epoch MSE for SVRG and SGD\n",
+    "dataplot3 = {\"svrg_mse\": [], \"sgd_mse\": []}\n",
+    "with open('sgd_lr.json') as sgd_data, open('svrg_lr.json') as svrg_data:\n",
+    "    sgd = json.load(sgd_data)\n",
+    "    svrg = json.load(svrg_data)\n",
+    "    for epoch in range(100):\n",
+    "        dataplot3[\"svrg_mse\"].append(svrg[str(epoch)][\"mse\"])\n",
+    "        dataplot3[\"sgd_mse\"].append(sgd[str(epoch)][\"mse\"])\n",
+    "\n",
+    "x3 = list(range(100))\n",
+    "plt.figure(figsize=(20, 12))\n",
+    "plt.title(\"Training Loss Over Epochs\")\n",
+    "sns.pointplot(x3, dataplot3['svrg_mse'], color=color[9])\n",
+    "sns.pointplot(x3, dataplot3['sgd_mse'], color=color[8])\n",
+    "color_patch1 = mpatches.Patch(color=color[9], label=\"svrg_mse\")\n",
+    "color_patch2 = mpatches.Patch(color=color[8], label=\"sgd_mse\")\n",
+    "plt.legend(handles=[color_patch1, color_patch2])\n",
+    "plt.ylabel('Training Loss', fontsize=12)\n",
+    "plt.xlabel('Epochs', fontsize=12)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Training Loss Comparison with SGD with fixed learning rates\n",
+    "Choosing learning rates (0.0025, 0.001, 0.005) for SGD and a relatively large learning rate of 0.025 for SVRG, we can see that the SVRG training loss decreases more smoothly and faster than that of SGD. The learning rate for SVRG does not need to decay to zero, which means we can start with a larger learning rate."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_svrg_lin_reg(output=\"svrg_0.025.json\", optimizer_params=(('learning_rate', 0.025),))\n",
+    "train_sgd_lin_reg(output=\"sgd_0.001.json\", optimizer_params=((\"learning_rate\", 0.001),))\n",
+    "train_sgd_lin_reg(output=\"sgd_0.0025.json\", optimizer_params=((\"learning_rate\", 0.0025),))\n",
+    "train_sgd_lin_reg(output=\"sgd_0.005.json\", optimizer_params=((\"learning_rate\", 0.005),))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5,0,'Epochs')"
+      ]
+     },
+     "execution_count": 34,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": "\n",
+      "text/plain": [
+       "<Figure size 1440x864 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "#Plot training loss over Epochs:\n",
+    "color = sns.color_palette()\n",
+    "# Collect the per-epoch MSE for SVRG and SGD\n",
+    "dataplot3 = {\"svrg_mse\": [], \"sgd_mse_lr_0.001\": [], \"sgd_mse_lr_0.0025\": [], \"sgd_mse_lr_0.005\":[]}\n",
+    "with open('sgd_0.001.json') as sgd_data, open('svrg_0.025.json') as svrg_data, open('sgd_0.0025.json') as sgd_data_2, open('sgd_0.005.json') as sgd_data_3:\n",
+    "    sgd = json.load(sgd_data)\n",
+    "    svrg = json.load(svrg_data)\n",
+    "    sgd_lr = json.load(sgd_data_2)\n",
+    "    sgd_lr_2 = json.load(sgd_data_3)\n",
+    "    for epoch in range(100):\n",
+    "        dataplot3[\"svrg_mse\"].append(svrg[str(epoch)][\"mse\"])\n",
+    "        dataplot3[\"sgd_mse_lr_0.001\"].append(sgd[str(epoch)][\"mse\"])\n",
+    "        dataplot3[\"sgd_mse_lr_0.0025\"].append(sgd_lr[str(epoch)][\"mse\"])\n",
+    "        dataplot3[\"sgd_mse_lr_0.005\"].append(sgd_lr_2[str(epoch)][\"mse\"])\n",
+    "\n",
+    "x3 = list(range(100))\n",
+    "plt.figure(figsize=(20, 12))\n",
+    "plt.title(\"Training Loss Over Epochs\")\n",
+    "sns.pointplot(x3, dataplot3['svrg_mse'], color=color[9])\n",
+    "sns.pointplot(x3, dataplot3['sgd_mse_lr_0.001'], color=color[8])\n",
+    "sns.pointplot(x3, dataplot3['sgd_mse_lr_0.0025'], color=color[3])\n",
+    "sns.pointplot(x3, dataplot3['sgd_mse_lr_0.005'], color=color[7])\n",
+    "color_patch1 = mpatches.Patch(color=color[9], label=\"svrg_mse_0.025\")\n",
+    "color_patch2 = mpatches.Patch(color=color[8], label=\"sgd_mse_lr_0.001\")\n",
+    "color_patch3 = mpatches.Patch(color=color[3], label=\"sgd_mse_lr_0.0025\")\n",
+    "color_patch4 = mpatches.Patch(color=color[7], label=\"sgd_mse_lr_0.005\")\n",
+    "plt.legend(handles=[color_patch1, color_patch2, color_patch3, color_patch4])\n",
+    "plt.ylabel('Training Loss', fontsize=12)\n",
+    "plt.xlabel('Epochs', fontsize=12)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 2",
+   "language": "python",
+   "name": "python2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
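The plotting cell in the notebook above reads JSON files such as `sgd_0.001.json` and `svrg_0.025.json` and looks values up as `data[str(epoch)]["mse"]`, so each file is evidently a JSON object keyed by the epoch number as a string. The diff does not show how those files are produced. Assuming that layout, the sketch below shows one way to write and read it with the standard `json` module; `save_epoch_metrics` and `load_metric_series` are hypothetical helpers, not part of this PR, and the loss values are made up.

```python
import json
import os
import tempfile


def save_epoch_metrics(path, mse_per_epoch):
    # Hypothetical helper: writes {"0": {"mse": ...}, "1": {"mse": ...}, ...}
    log = {str(epoch): {"mse": float(mse)} for epoch, mse in enumerate(mse_per_epoch)}
    with open(path, "w") as f:
        json.dump(log, f)


def load_metric_series(path, num_epochs, metric="mse"):
    # Mirrors the plotting cell: look up each epoch by its string key, in order.
    with open(path) as f:
        log = json.load(f)
    return [log[str(epoch)][metric] for epoch in range(num_epochs)]


# Round-trip a short, made-up loss curve through a temporary file.
path = os.path.join(tempfile.mkdtemp(), "svrg_example.json")
save_epoch_metrics(path, [0.9, 0.5, 0.25])
series = load_metric_series(path, 3)
```

Keying by `str(epoch)` matches the lookups in the notebook, because JSON object keys are always strings.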
diff --git a/example/svrg_module/linear_regression/common.py b/example/svrg_module/linear_regression/common.py
new file mode 100644
index 00000000000..14a144f40ce
--- /dev/null
+++ b/example/svrg_module/linear_regression/common.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def create_lin_reg_network(train_features, train_labels, feature_dim, batch_size, update_freq, ctx, logger):
+    # fit a linear regression model with mxnet SVRGModule
+    print("Fitting linear regression with mxnet")
+    train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True,
+                                   data_name='data', label_name='label')
+    data = mx.sym.Variable("data")
+    label = mx.sym.Variable("label")
+    weight = mx.sym.Variable("fc_weight", shape=(1, feature_dim))
+    net = mx.sym.dot(data, weight.transpose())
+    bias = mx.sym.Variable("fc_bias", shape=(1,), wd_mult=0.0, lr_mult=10.0)
+    net = mx.sym.broadcast_plus(net, bias)
+    net = mx.sym.LinearRegressionOutput(data=net, label=label)
+    mod = SVRGModule(symbol=net, context=ctx, data_names=['data'], label_names=['label'], logger=logger,
+                     update_freq=update_freq)
+    return train_iter, mod
+
+
+def create_metrics(metrics):
+    metric = mx.metric.create(metrics)
+    return metric
+
+
+def create_logger():
+    logger = logging.getLogger('sgd_svrg')
+    logger.setLevel(logging.INFO)
+    formatter = logging.Formatter('%(asctime)s - %(message)s')
+    fh = logging.FileHandler('experiments.log')
+    fh.setFormatter(formatter)
+    logger.addHandler(fh)
+    return logger
+
+
+################################################################################
+# The functions below are for benchmarking: they calculate the expectation and
+# variance of the gradients per epoch for each parameter. These statistics help
+# when benchmarking SVRG optimization against other optimization techniques,
+# such as SGD. Currently they only handle a single context, but they can be
+# extended to multiple contexts in later iterations.
+################################################################################
+
+def accumulate_grad(grad_dict, mod):
+    param_names = mod._exec_group.param_names
+
+    for index, name in enumerate(param_names):
+        if name not in grad_dict:
+            grad_dict[name] = mod._exec_group.grad_arrays[index][0].copy()
+        else:
+            grad_dict[name] = mx.ndarray.concat(grad_dict[name], mod._exec_group.grad_arrays[index][0], dim=0)
+
+
+def calc_expectation(grad_dict, num_batches):
+    """Calculates the expectation of the gradients per epoch for each parameter w.r.t number of batches
+
+    Parameters
+    ----------
+    grad_dict: dict
+        dictionary that maps parameter name to gradients in the mod executor group
+    num_batches: int
+        number of batches
+
+    Returns
+    ----------
+    grad_dict: dict
+        dictionary with new keys mapping to gradients expectations
+
+    """
+    for key in list(grad_dict.keys()):
+        grad_dict[key + "_expectation"] = mx.ndarray.sum(grad_dict[key], axis=0) / num_batches
+
+    return grad_dict
+
+
+def calc_variance(grad_dict, num_batches, param_names):
+    """Calculates the variance of the gradients per epoch for each parameter w.r.t number of batches
+
+    Parameters
+    ----------
+    grad_dict: dict
+        dictionary that maps parameter name to gradients in the mod executor group
+    num_batches: int
+        number of batches
+    param_names: list of str
+        names of the parameters in the module
+
+    Returns
+    ----------
+    grad_dict: dict
+        dictionary with new keys mapping to gradients variance
+
+    """
+    for name in param_names:
+        diff_sqr = mx.nd.square(mx.nd.subtract(grad_dict[name], grad_dict[name + "_expectation"]))
+        grad_dict[name + "_variance"] = mx.nd.sum(diff_sqr, axis=0) / num_batches
+    return grad_dict
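`accumulate_grad` stacks one gradient row per batch along axis 0, and `calc_expectation` and `calc_variance` then reduce that stack along axis 0 and divide by `num_batches`, which gives an element-wise mean and population variance across batches. As a rough, dependency-free illustration of the same arithmetic for one parameter, assuming each batch contributes exactly one gradient vector (true for the `(1, feature_dim)` weight in this example); `grad_expectation_and_variance` is a hypothetical name, not part of `common.py`:

```python
def grad_expectation_and_variance(per_batch_grads):
    # per_batch_grads: one equal-length gradient vector (list of floats) per batch.
    num_batches = len(per_batch_grads)
    dim = len(per_batch_grads[0])
    # Expectation: element-wise sum over batches divided by num_batches.
    expectation = [sum(g[j] for g in per_batch_grads) / num_batches for j in range(dim)]
    # Variance: element-wise mean squared deviation from the expectation (population variance).
    variance = [sum((g[j] - expectation[j]) ** 2 for g in per_batch_grads) / num_batches
                for j in range(dim)]
    return expectation, variance


# Two batches, two-element gradient: the first coordinate varies, the second does not.
expectation, variance = grad_expectation_and_variance([[1.0, 2.0], [3.0, 2.0]])
```

A lower per-coordinate variance across batches is what the benchmark is meant to reveal when comparing SVRG against plain SGD.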
diff --git a/example/svrg_module/linear_regression/data_reader.py b/example/svrg_module/linear_regression/data_reader.py
new file mode 100644
index 00000000000..d56ae03a5f4
--- /dev/null
+++ b/example/svrg_module/linear_regression/data_reader.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import numpy as np
+from sklearn.datasets import load_svmlight_file
+
+# Download data file
+# from subprocess import call
+# YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd
+# call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
+# call(['bzip2', '-d', 'YearPredictionMSD.bz2'])
+
+
+def read_year_prediction_data(fileName):
+    feature_dim = 90
+    print("Reading data from disk...")
+    train_features, train_labels = load_svmlight_file(fileName, n_features=feature_dim, dtype=np.float32)
+    train_features = train_features.todense()
+
+    # normalize the data: subtract means and divide by standard deviations
+    label_mean = train_labels.mean()
+    label_std = np.sqrt(np.square(train_labels - label_mean).mean())
+    feature_means = train_features.mean(axis=0)
+    feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0))
+
+    train_features = (train_features - feature_means) / feature_stds
+    train_labels = (train_labels - label_mean) / label_std
+
+    return feature_dim, train_features, train_labels
+
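`read_year_prediction_data` standardizes every feature column and the labels: it subtracts the mean, then divides by the population standard deviation (`np.sqrt(np.square(x - mean).mean())`, i.e. dividing by n rather than n - 1). A plain-Python sketch of that transformation follows; the helper names are hypothetical, and like the original it assumes no column is constant, since a zero standard deviation makes the division break down.

```python
import math


def standardize(values):
    # (values - mean) / std, with std the population standard deviation.
    n = len(values)
    mean = sum(values) / n
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / n)
    return [(v - mean) / std for v in values]


def standardize_columns(rows):
    # Standardize each column independently, like the axis=0 statistics in the original.
    columns = [standardize(list(col)) for col in zip(*rows)]
    return [list(row) for row in zip(*columns)]


# Two samples with two features each; both columns map to [-1.0, 1.0].
scaled = standardize_columns([[1.0, 10.0], [3.0, 30.0]])
```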
diff --git a/example/svrg_module/linear_regression/train.py b/example/svrg_module/linear_regression/train.py
new file mode 100644
index 00000000000..b3d942973f1
--- /dev/null
+++ b/example/svrg_module/linear_regression/train.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import argparse
+import mxnet as mx
+from common import create_lin_reg_network, create_logger
+from data_reader import read_year_prediction_data
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-e', dest='epochs', help='number of epochs for training phase', type=int, default=100)
+parser.add_argument('-f', dest="updateFreq", help="how often, in epochs, SVRGModule recomputes full gradients", type=int, default=2)
+parser.add_argument('-b', dest="batch_size", help="define the batch size for training", type=int,
+                    default=100, required=False)
+parser.add_argument('-m', dest='metrics', help="create eval metric", type=str, default='mse')
+parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
+parser.add_argument('--kv-store', type=str, default='local', help='key-value store type')
+
+args = parser.parse_args()
+# devices for training
+ctx = mx.cpu() if args.gpus is None or args.gpus == "" else [mx.gpu(int(i)) for i in args.gpus.split(',')]
+
+logger = create_logger()
+kv = mx.kvstore.create(args.kv_store)
+
+feature_dim, train_features, train_labels = read_year_prediction_data('YearPredictionMSD')
+train_iter, mod = create_lin_reg_network(train_features, train_labels, feature_dim, args.batch_size, args.updateFreq,
+                                         ctx, logger)
+
+mod.fit(train_iter, eval_metric=args.metrics, optimizer='sgd',
+        optimizer_params=(('learning_rate', 0.025), ), num_epoch=args.epochs, kvstore=kv)
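`train.py` fits a linear regression on YearPredictionMSD with `SVRGModule`, recomputing the full gradient every `update_freq` epochs (`-f`, default 2). The toy below is a dependency-free sketch of the same algorithm on a one-feature, noise-free problem. It keeps a snapshot of the weights (the role `_mod_aux` plays in `SVRGModule`), refreshes the snapshot and its full-data average gradient every `update_freq` epochs, and steps with "current gradient minus snapshot gradient plus full gradient". The function name, data, and learning rate are illustrative assumptions, and it uses single-sample steps rather than the mini-batches, KVStore, and per-device buffers of the MXNet module.

```python
import random


def svrg_linear_regression(xs, ys, lr=0.05, num_epoch=100, update_freq=2, seed=0):
    # Fit y ~ w * x + b by minimizing the mean of 0.5 * (w * x + b - y) ** 2 with SVRG.
    rng = random.Random(seed)
    n = len(xs)
    w, b = 0.0, 0.0
    w_snap, b_snap = w, b   # snapshot weights
    mu_w, mu_b = 0.0, 0.0   # full-data average gradient at the snapshot

    def grad(wi, bi, x, y):
        # Gradient of 0.5 * (wi * x + bi - y) ** 2 with respect to (wi, bi).
        r = wi * x + bi - y
        return r * x, r

    for epoch in range(num_epoch):
        if epoch % update_freq == 0:
            # Refresh the snapshot and recompute the average gradient over all data.
            w_snap, b_snap = w, b
            grads = [grad(w_snap, b_snap, x, y) for x, y in zip(xs, ys)]
            mu_w = sum(g[0] for g in grads) / n
            mu_b = sum(g[1] for g in grads) / n
        order = list(range(n))
        rng.shuffle(order)
        for i in order:
            gw, gb = grad(w, b, xs[i], ys[i])
            sw, sb = grad(w_snap, b_snap, xs[i], ys[i])
            # SVRG step: current gradient - snapshot gradient + full gradient.
            w -= lr * (gw - sw + mu_w)
            b -= lr * (gb - sb + mu_b)
    return w, b


# Noise-free data from y = 3x + 1 on [-1, 1].
xs = [i / 10.0 for i in range(-10, 11)]
ys = [3.0 * x + 1.0 for x in xs]
w, b = svrg_linear_regression(xs, ys)
```

Because the correction term `gw - sw + mu_w` has the full gradient as its expectation while its variance shrinks as `w` approaches the snapshot, SVRG can keep a constant learning rate instead of decaying it as plain SGD usually must.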
diff --git a/python/mxnet/contrib/svrg_optimization/__init__.py b/python/mxnet/contrib/svrg_optimization/__init__.py
new file mode 100644
index 00000000000..6e70009983c
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SVRGModule, SVRGOptimization import.
+"""
+
+
+from . import svrg_module
+from . import svrg_optimizer
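In `svrg_module.py`, `_svrg_grads_update_rule` overwrites each gradient with `g_curr_batch_curr_weight - g_curr_batch_special_weight + g_special_weight_all_batch`. Because the last two terms have the same expectation over batches, the corrected gradient remains an unbiased estimate of the full gradient at the current weights. A small, dependency-free check of that property for scalar least squares follows; the function names and data are illustrative only, and `mu` stands in for the per-parameter averages that `update_full_grads` stores in `_param_dict`.

```python
def svrg_gradient(g_curr_weight, g_snapshot_weight, full_grad_at_snapshot):
    # Element-wise SVRG rule: g(current w) - g(snapshot w) + mean full-data g(snapshot w).
    return [g - s + m for g, s, m in zip(g_curr_weight, g_snapshot_weight, full_grad_at_snapshot)]


def per_sample_grad(w, x, y):
    # Gradient of 0.5 * (w * x - y) ** 2 with respect to the scalar weight w, as a 1-element list.
    return [(w * x - y) * x]


# Made-up data and weights, for illustration only.
data = [(1.0, 2.0), (2.0, 3.0), (-1.0, 0.5)]
w_curr, w_snap = 1.5, 0.5
n = len(data)

# Average gradient over all data at the snapshot weight.
mu = [sum(per_sample_grad(w_snap, x, y)[0] for x, y in data) / n]

# SVRG-corrected gradient for each sample, then averaged over the data set.
svrg_grads = [svrg_gradient(per_sample_grad(w_curr, x, y), per_sample_grad(w_snap, x, y), mu)[0]
              for x, y in data]
avg_svrg = sum(svrg_grads) / n

# Plain full gradient at the current weight, for comparison: it equals avg_svrg.
full_curr = sum(per_sample_grad(w_curr, x, y)[0] for x, y in data) / n
```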
diff --git a/python/mxnet/contrib/svrg_optimization/svrg_module.py b/python/mxnet/contrib/svrg_optimization/svrg_module.py
new file mode 100644
index 00000000000..5d6b5dd5720
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/svrg_module.py
@@ -0,0 +1,578 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""An `SVRGModule` implements the `Module` API by wrapping an auxiliary module to perform
+SVRG optimization logic.
+"""
+
+import time
+import logging
+import mxnet as mx
+from mxnet.module import Module
+from .svrg_optimizer import _SVRGOptimizer
+
+
+class SVRGModule(Module):
+    """SVRGModule is a module that encapsulates two Modules to accommodate the SVRG optimization technique.
+    It exposes the same API as Module, but computes parameter updates with the SVRG update rule.
+
+    Parameters
+    ----------
+    symbol : Symbol
+    data_names : list of str
+        Defaults to `('data',)` for a typical model used in image classification.
+    label_names : list of str
+        Defaults to `('softmax_label',)` for a typical model used in image
+        classification.
+    logger : Logger
+        Defaults to `logging`.
+    context : Context or list of Context
+        Defaults to ``mx.cpu()``.
+    work_load_list : list of number
+        Default ``None``, indicating uniform workload.
+    fixed_param_names: list of str
+        Default ``None``, indicating no network parameters are fixed.
+    state_names : list of str
+        states are similar to data and label, but not provided by data iterator.
+        Instead they are initialized to 0 and can be set by `set_states()`.
+    group2ctxs : dict of str to context or list of context, or list of dict of str to context
+        Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+    compression_params : dict
+        Specifies type of gradient compression and additional arguments depending
+        on the type of compression being used. For example, 2bit compression requires a threshold.
+        Arguments would then be {'type':'2bit', 'threshold':0.5}
+        See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
+    update_freq: int
+        Positive integer (required) specifying how often, in epochs, the full gradients used in the SVRG optimization
+        are recomputed. For instance, update_freq = 2 recomputes the gradients over all data every two epochs.
+
+    Examples
+    --------
+    >>> # An example of declaring and using SVRGModule.
+    >>> mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
+    >>> mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
+    ...         num_epoch=num_epoch, kvstore='local')
+    """
+
+    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
+                 logger=logging, context=mx.cpu(), work_load_list=None,
+                 fixed_param_names=None, state_names=None, group2ctxs=None,
+                 compression_params=None, update_freq=None):
+        super(SVRGModule, self).__init__(symbol, data_names=data_names, label_names=label_names, logger=logger,
+                                         context=context, work_load_list=work_load_list,
+                                         fixed_param_names=fixed_param_names, state_names=state_names,
+                                         group2ctxs=group2ctxs, compression_params=compression_params)
+
+        # Type check update_freq: it must be a positive integer
+        if isinstance(update_freq, int):
+            if update_freq <= 0:
+                raise ValueError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
+                                 "calculating full gradients")
+            self.update_freq = update_freq
+        else:
+            raise TypeError("update_freq in SVRGModule must be an integer to represent the frequency for "
+                            "calculating full gradients")
+
+        self._mod_aux = mx.mod.Module(symbol, data_names, label_names, logger, context, work_load_list,
+                                      fixed_param_names, state_names, group2ctxs, compression_params)
+
+        self._param_dict = None
+        self._ctx_len = len(self._context)
+
+    def _reset_bind(self):
+        """Internal function to reset binded state for both modules."""
+        super(SVRGModule, self)._reset_bind()
+        self._mod_aux._reset_bind()
+
+    def reshape(self, data_shapes, label_shapes=None):
+        """Reshapes both modules for new input shapes.
+
+        Parameters
+        ----------
+        data_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_data``.
+        label_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_label``.
+        """
+        super(SVRGModule, self).reshape(data_shapes, label_shapes=label_shapes)
+        self._mod_aux.reshape(data_shapes, label_shapes=label_shapes)
+
+    def init_optimizer(self, kvstore='local', optimizer='sgd',
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
+        """Installs and initializes SVRGOptimizer. The SVRGOptimizer is a wrapper class for a regular optimizer that is
+        passed in and a special AssignmentOptimizer to accumulate the full gradients.  If KVStore is 'local' or None,
+        the full gradients will be accumulated locally without pushing to the KVStore. Otherwise, additional keys will
+        be pushed to accumulate the full gradients in the KVStore.
+
+        Parameters
+        ----------
+        kvstore : str or KVStore
+            Default `'local'`.
+        optimizer : str or Optimizer
+            Default `'sgd'`
+        optimizer_params : dict
+            Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
+            just to avoid pylint warning of dangerous default values.
+        force_init : bool
+            Default ``False``, indicating whether we should force re-initializing the
+            optimizer in the case an optimizer is already installed.
+        """
+
+        # Init dict for storing average of full gradients for each device
+        self._param_dict = [{key: mx.nd.zeros(shape=value.shape, ctx=self._context[i])
+                             for key, value in self.get_params()[0].items()} for i in range(self._ctx_len)]
+
+        svrg_optimizer = self._create_optimizer(_SVRGOptimizer.__name__, default_opt=optimizer,
+                                                kvstore=kvstore, optimizer_params=optimizer_params)
+
+        super(SVRGModule, self).init_optimizer(kvstore=kvstore, optimizer=svrg_optimizer,
+                                               optimizer_params=optimizer_params, force_init=force_init)
+
+        # Init additional keys for accumulating full grads in KVStore
+        if self._kvstore:
+            for idx, param_on_devs in enumerate(self._exec_group.param_arrays):
+                name = self._exec_group.param_names[idx]
+                self._kvstore.init(name + "_full", mx.nd.zeros(shape=self._arg_params[name].shape))
+                if self._update_on_kvstore:
+                    self._kvstore.pull(name + "_full", param_on_devs, priority=-idx)
+
+    def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params):
+        """Helper function to create a svrg optimizer. SVRG optimizer encapsulates two optimizers and
+        will redirect update() to the correct optimizer based on the key.
+
+        Parameters
+        ----------
+        kvstore : str or KVStore
+            Default `'local'`.
+        optimizer : str
+            Name for SVRGOptimizer.
+        default_opt : str or Optimizer
+            The optimizer that was passed in; SVRGOptimizer wraps it.
+        optimizer_params : dict
+            Optimizer parameters that were passed in.
+        """
+
+        # code partially copied from mxnet module.init_optimizer() to accommodate svrg_optimizer
+        batch_size = self._exec_group.batch_size
+
+        (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, self._ctx_len, self._arg_params)
+        if kv_store and 'dist' in kv_store.type and '_sync' in kv_store.type:
+            batch_size *= kv_store.num_workers
+        rescale_grad = 1.0 / batch_size
+
+        idx2name = {}
+        if update_on_kvstore:
+            idx2name.update(enumerate(self._exec_group.param_names))
+        else:
+            for k in range(self._ctx_len):
+                idx2name.update({i * self._ctx_len + k: n
+                                 for i, n in enumerate(self._exec_group.param_names)})
+
+        # update idx2name to include new keys
+        for key in self._param_dict[0].keys():
+            max_key = max(list(idx2name.keys())) + 1
+            idx2name[max_key] = key + "_full"
+
+        optimizer_params = dict(optimizer_params)
+        if 'rescale_grad' not in optimizer_params:
+            optimizer_params['rescale_grad'] = rescale_grad
+        optimizer_params["default_optimizer"] = default_opt
+        optimizer_params["param_idx2name"] = idx2name
+        optimizer = mx.optimizer.create(optimizer, **optimizer_params)
+
+        return optimizer
+
+    def bind(self, data_shapes, label_shapes=None, for_training=True,
+             inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
+        """Binds the symbols to construct executors for both modules. This is necessary before one
+        can perform computation with the SVRGModule.
+
+        Parameters
+        ----------
+        data_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_data``.
+        label_shapes : list of (str, tuple)
+            Typically is ``data_iter.provide_label``.
+        for_training : bool
+            Default is ``True``. Whether the executors should be bound for training.
+        inputs_need_grad : bool
+            Default is ``False``. Whether the gradients to the input data need to be computed.
+            Typically this is not needed. But this might be needed when implementing composition
+            of modules.
+        force_rebind : bool
+            Default is ``False``. This function does nothing if the executors are already
+            bound. But with this ``True``, the executors will be forced to rebind.
+        shared_module : Module
+            Default is ``None``. This is used in bucketing. When not ``None``, the shared module
+            essentially corresponds to a different bucket -- a module with different symbol
+            but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
+        """
+        # force rebinding is typically used when one wants to switch from
+        # training to prediction phase.
+        super(SVRGModule, self).bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind,
+                                     shared_module, grad_req)
+
+        if for_training:
+            self._mod_aux.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module,
+                               grad_req)
+
+    def forward(self, data_batch, is_train=None):
+        """Forward computation for both modules. It supports data batches with different shapes, such as
+        different batch sizes or different image sizes.
+        If reshaping of data batch relates to modification of symbol or module, such as
+        changing image layout ordering or switching from training to predicting, module
+        rebinding is required.
+
+        See Also
+        ----------
+        :meth:`BaseModule.forward`.
+
+        Parameters
+        ----------
+        data_batch : DataBatch
+            Could be anything with similar API implemented.
+        is_train : bool
+            Default is ``None``, which means ``is_train`` takes the value of ``self.for_training``.
+        """
+        super(SVRGModule, self).forward(data_batch, is_train)
+
+        if is_train:
+            self._mod_aux.forward(data_batch, is_train)
+
+    def backward(self, out_grads=None):
+        """Backward computation.
+
+        See Also
+        ----------
+        :meth:`BaseModule.backward`.
+
+        Parameters
+        ----------
+        out_grads : NDArray or list of NDArray, optional
+            Gradient on the outputs to be propagated back.
+            This parameter is only needed when bind is called
+            on outputs that are not a loss function.
+        """
+        super(SVRGModule, self).backward(out_grads)
+
+        if self._mod_aux.binded:
+            self._mod_aux.backward(out_grads)
+
+    def update(self):
+        """Updates parameters according to the installed optimizer and the gradients computed
+        in the previous forward-backward batch. The gradients in the _exec_group will be overwritten
+        using the gradients calculated by the SVRG update rule.
+
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        this function does update the copy of parameters in KVStore, but doesn't broadcast the
+        updated parameters to all devices / machines. Please call `prepare` to broadcast
+        `row_sparse` parameters with the next batch of data.
+
+        See Also
+        ----------
+        :meth:`BaseModule.update`.
+        """
+        self._update_svrg_gradients()
+        super(SVRGModule, self).update()
+
+    def update_full_grads(self, train_data):
+        """Computes the gradients over all data w.r.t weights of past
+        m epochs. For distributed env, it will accumulate full grads in the kvstore.
+
+        Parameters
+        ----------
+        train_data: DataIter
+            Train data iterator
+        """
+        param_names = self._exec_group.param_names
+        arg, aux = self.get_params()
+        self._mod_aux.set_params(arg_params=arg, aux_params=aux)
+        train_data.reset()
+        nbatch = 0
+        padding = 0
+        for batch in train_data:
+            self._mod_aux.forward(batch, is_train=True)
+            self._mod_aux.backward()
+            nbatch += 1
+            for ctx in range(self._ctx_len):
+                for index, name in enumerate(param_names):
+                    grads = self._mod_aux._exec_group.grad_arrays[index][ctx]
+                    self._param_dict[ctx][name] = mx.nd.broadcast_add(self._param_dict[ctx][name], grads, axis=0)
+            padding = batch.pad
+
+        true_num_batch = nbatch - float(padding) / train_data.batch_size
+        for name in param_names:
+            grad_list = []
+            for i in range(self._ctx_len):
+                self._param_dict[i][name] /= true_num_batch
+                grad_list.append(self._param_dict[i][name])
+            if self._kvstore:
+                # If in distributed mode, push a list of gradients from each worker/device to the KVStore
+                self._accumulate_kvstore(name, grad_list)
+
+    def _accumulate_kvstore(self, key, value):
+        """Accumulate gradients over all data in the KVStore. In distributed setting, each worker sees a portion of
+        data. The full gradients will be aggregated from each worker in the KVStore.
+
+        Parameters
+        ----------
+
+        key: int or str
+            Key in the KVStore.
+        value: NDArray, RowSparseNDArray
+            Average of the full gradients.
+        """
+        # Accumulate full gradients for current epochs
+        self._kvstore.push(key + "_full", value)
+        self._kvstore._barrier()
+        self._kvstore.pull(key + "_full", value)
+
+        self._allocate_gradients(key, value)
+
+    def _allocate_gradients(self, key, value):
+        """Allocate average of full gradients accumulated in the KVStore to each device.
+
+        Parameters
+        ----------
+
+        key: int or str
+            Key in the kvstore.
+        value: List of NDArray, List of RowSparseNDArray
+            A list of average of the full gradients in the KVStore.
+        """
+        for i in range(self._ctx_len):
+            self._param_dict[i][key] = value[i] / self._ctx_len
+
+    def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight,
+                                g_special_weight_all_batch):
+        """Calculates the gradient based on the SVRG update rule.
+        Parameters
+        ----------
+        g_curr_batch_curr_weight : NDArray
+            gradients of the current weight of this module w.r.t the current batch of data
+        g_curr_batch_special_weight: NDArray
+            gradients of the weight of past m epochs of self._mod_aux w.r.t the current batch of data
+        g_special_weight_all_batch: NDArray
+            average of full gradients over full pass of data
+
+        Returns
+        ----------
+        Gradients calculated using SVRG update rule:
+        grads = g_curr_batch_curr_weight - g_curr_batch_special_weight + g_special_weight_all_batch
+        """
+        for index, grad in enumerate(g_curr_batch_curr_weight):
+            grad -= g_curr_batch_special_weight[index]
+            grad += g_special_weight_all_batch[index]
+        return g_curr_batch_curr_weight
+
+    def _update_svrg_gradients(self):
+        """Calculates gradients based on the SVRG update rule.
+        """
+        param_names = self._exec_group.param_names
+        for ctx in range(self._ctx_len):
+            for index, name in enumerate(param_names):
+                g_curr_batch_reg = self._exec_group.grad_arrays[index][ctx]
+                g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[index][ctx]
+                g_special_weight_all_batch = self._param_dict[ctx][name]
+                g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special,
+                                                      g_special_weight_all_batch)
+                self._exec_group.grad_arrays[index][ctx] = g_svrg
+
+    def fit(self, train_data, eval_data=None, eval_metric='acc',
+            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
+            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+            eval_end_callback=None,
+            eval_batch_end_callback=None, initializer=mx.init.Uniform(0.01),
+            arg_params=None, aux_params=None, allow_missing=False,
+            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
+            validation_metric=None, monitor=None, sparse_row_id_fn=None):
+        """Trains the module parameters.
+        Parameters
+        ----------
+        train_data : DataIter
+            Train DataIter.
+        eval_data : DataIter
+            If not ``None``, will be used as validation set and the performance
+            after each epoch will be evaluated.
+        eval_metric : str or EvalMetric
+            Defaults to 'accuracy'. The performance measure used to display during training.
+            Other possible predefined metrics are:
+            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
+        epoch_end_callback : function or list of functions
+            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
+            and `aux_params`.
+        batch_end_callback : function or list of function
+            Each callback will be called with a `BatchEndParam`.
+        kvstore : str or KVStore
+            Defaults to 'local'.
+        optimizer : str or Optimizer
+            Defaults to 'sgd'.
+        optimizer_params : dict
+            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
+            the optimizer constructor.
+            The default value is not a dict, just to avoid pylint warning on dangerous
+            default values.
+        eval_end_callback : function or list of functions
+            These will be called at the end of each full evaluation, with the metrics over
+            the entire evaluation set.
+        eval_batch_end_callback : function or list of functions
+            These will be called at the end of each mini-batch during evaluation.
+        initializer : Initializer
+            The initializer is called to initialize the module parameters when they are
+            not already initialized.
+        arg_params : dict
+            Defaults to ``None``. If not ``None``, should be existing parameters from a trained
+            model or loaded from a checkpoint (previously saved model). In this case,
+            the value here will be used to initialize the module parameters, unless they
+            are already initialized by the user via a call to `init_params` or `fit`.
+            `arg_params` has a higher priority than `initializer`.
+        aux_params : dict
+            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
+        allow_missing : bool
+            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
+            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
+            will be initialized via the `initializer`.
+        force_rebind : bool
+            Defaults to ``False``. Whether to force rebinding the executors if already bound.
+        force_init : bool
+            Defaults to ``False``. Indicates whether to force initialization even if the
+            parameters are already initialized.
+        begin_epoch : int
+            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
+            checkpoint saved at a previous training phase at epoch N, then this value should be
+            N+1.
+        num_epoch : int
+            Number of epochs for training.
+        validation_metric : str or EvalMetric
+            The performance measure displayed during validation.
+            Defaults to ``eval_metric`` when ``None``.
+        monitor : Monitor
+            Defaults to ``None``. If not ``None``, the monitor is installed on the module
+            and ticked/printed around each training batch.
+        sparse_row_id_fn : A callback function
+            The function takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
+        """
+        assert num_epoch is not None, 'please specify number of epochs'
+
+        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
+                  for_training=True, force_rebind=force_rebind)
+        if monitor is not None:
+            self.install_monitor(monitor)
+        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
+                         allow_missing=allow_missing, force_init=force_init)
+        self.init_optimizer(kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params)
+
+        if validation_metric is None:
+            validation_metric = eval_metric
+        if not isinstance(eval_metric, mx.metric.EvalMetric):
+            eval_metric = mx.metric.create(eval_metric)
+
+        ################################################################################
+        # training loop
+        ################################################################################
+        for epoch in range(begin_epoch, num_epoch):
+            eval_metric.reset()
+            tic = time.time()
+            if epoch % self.update_freq == 0:
+                self.update_full_grads(train_data)
+
+            train_data.reset()
+            data_iter = iter(train_data)
+            end_of_batch = False
+            nbatch = 0
+            next_data_batch = next(data_iter)
+
+            while not end_of_batch:
+                data_batch = next_data_batch
+                if monitor is not None:
+                    monitor.tic()
+
+                self.forward_backward(data_batch)
+                self.update()
+
+                if isinstance(data_batch, list):
+                    self.update_metric(eval_metric, [db.label for db in data_batch], pre_sliced=True)
+                else:
+                    self.update_metric(eval_metric, data_batch.label)
+
+                try:
+                    # pre fetch next batch
+                    next_data_batch = next(data_iter)
+                    self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
+                except StopIteration:
+                    end_of_batch = True
+
+                if monitor is not None:
+                    monitor.toc_print()
+
+                if end_of_batch:
+                    eval_name_vals = eval_metric.get_name_value()
+
+                if batch_end_callback is not None:
+                    batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+                                                              eval_metric=eval_metric, locals=locals())
+                    for callback in mx.base._as_list(batch_end_callback):
+                        callback(batch_end_params)
+
+                nbatch += 1
+            for name, val in eval_name_vals:
+                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
+            toc = time.time()
+            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
+
+            # sync aux params across devices
+            arg_params, aux_params = self.get_params()
+            self.set_params(arg_params, aux_params)
+
+            if epoch_end_callback is not None:
+                for callback in mx.base._as_list(epoch_end_callback):
+                    callback(epoch, self.symbol, arg_params, aux_params)
+
+            # ----------------------------------------
+            # evaluation on validation set
+            if eval_data:
+                res = self.score(eval_data, validation_metric,
+                                 score_end_callback=eval_end_callback,
+                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
+                for name, val in res:
+                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
+
+    def prepare(self, data_batch, sparse_row_id_fn=None):
+        """Prepares both the main module and its auxiliary module for processing a data batch.
+
+        Usually involves switching bucket and reshaping.
+        For modules that contain `row_sparse` parameters in KVStore,
+        it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+        When KVStore is used to update parameters for multi-device or multi-machine training,
+        a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+        the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+        the updated parameters to all devices / machines. The `prepare` function is used to
+        broadcast `row_sparse` parameters with the next batch of data.
+
+        Parameters
+        ----------
+        data_batch : DataBatch
+            The current batch of data for forward computation.
+
+        sparse_row_id_fn : A callback function
+            The function takes `data_batch` as an input and returns a dict of
+            str -> NDArray. The resulting dict is used for pulling row_sparse
+            parameters from the kvstore, where the str key is the name of the param,
+            and the value is the row id of the param to pull.
+        """
+        super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
+        self._mod_aux.prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
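For readers who want to see the update that `fit()` drives without MXNet, here is a minimal, framework-free sketch of SVRG on a toy least-squares problem. It assumes plain SGD as the inner step and uses hypothetical helpers (`mse_grad`, `svrg_fit`) and made-up noiseless data; it is not SVRGModule's implementation. As in `fit()`, the full gradient is recomputed every `update_freq` epochs, and each mini-batch step then uses the standard SVRG variance-reduced gradient g(w) - g(w_snapshot) + mu.

```python
import random


def mse_grad(w, batch):
    """Gradient of 0.5 * (w . x - y)**2, averaged over the samples in `batch`."""
    g = [0.0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            g[i] += err * xi
    return [gi / len(batch) for gi in g]


def svrg_fit(data, w, lr=0.5, num_epoch=30, update_freq=2, batch_size=10, seed=0):
    """Hypothetical helper: SVRG with a plain SGD step on linear least squares."""
    rng = random.Random(seed)
    w = list(w)
    w_snap, mu = None, None
    for epoch in range(num_epoch):
        if epoch % update_freq == 0:
            # Analogous to `update_full_grads`: snapshot the weights and
            # average the gradient over one full pass of the data.
            w_snap = list(w)
            mu = mse_grad(w_snap, data)
        order = list(data)
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            g_cur = mse_grad(w, batch)
            g_snap = mse_grad(w_snap, batch)
            # Variance-reduced gradient: g(w) - g(w_snap) + mu
            g = [gc - gs + m for gc, gs, m in zip(g_cur, g_snap, mu)]
            w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


# Hypothetical noiseless data: y = 1.0 * x1 + 2.0 * x2
gen = random.Random(42)
xs = [(gen.uniform(-1, 1), gen.uniform(-1, 1)) for _ in range(100)]
data = [(x, 1.0 * x[0] + 2.0 * x[1]) for x in xs]
w = svrg_fit(data, [0.0, 0.0])
print([round(v, 4) for v in w])
```

Because the toy data is noiseless, plain SGD would also converge here; the sketch only shows where the snapshot weights and the full-gradient average enter each update.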
diff --git a/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
new file mode 100644
index 00000000000..0f695a1b2ff
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A `_SVRGOptimizer` encapsulates two optimizers to support SVRGModule in single machine and distributed settings.
+Both `_AssignmentOptimizer` and `_SVRGOptimizer` are designed to be used with SVRGModule only.
+"""
+
+
+import mxnet as mx
+
+
+@mx.optimizer.register
+class _AssignmentOptimizer(mx.optimizer.Optimizer):
+    """_AssignmentOptimizer assigns gradients to weights for SVRGModule's full gradients
+    accumulation in the KVStore. It is a helper optimizer that is designed to be used with SVRGModule only.
+    """
+    def update(self, index, weight, grad, state):
+        """Assigns the gradient to the weight so that full gradients can be accumulated in the KVStore across all devices and workers.
+
+        Parameters
+        ----------
+        index : int
+            The unique index of the parameter into the individual learning
+            rates and weight decays. Learning rates and weight decay
+            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
+        weight : NDArray
+            The parameter to be updated.
+        grad : NDArray
+            The gradient of the objective with respect to this parameter.
+        state : any obj
+            Unused; `_AssignmentOptimizer` does not keep any optimizer state.
+        """
+
+        weight[:] = grad
+
+
+@mx.optimizer.register
+class _SVRGOptimizer(mx.optimizer.Optimizer):
+    """_SVRGOptimizer is a wrapper class for two optimizers: _AssignmentOptimizer for accumulating full gradients in the
+    KVStore and a default optimizer that is passed in as a parameter in `mod.init_optimizer()`.
+    The _SVRGOptimizer is designed to be used with SVRGModule only.
+
+    This optimizer accepts the following parameters in addition to those accepted by :class:`.Optimizer`.
+
+    Parameters
+    ----------
+    default_optimizer : str or Optimizer
+        The optimizer passed to `init_optimizer()` when it is invoked on an SVRGModule.
+    """
+
+    def __init__(self, default_optimizer, **kwargs):
+        # Reconstruct kwargs to identify additional params for default optimizer
+        base_param = self._check_params(**kwargs)
+        super(_SVRGOptimizer, self).__init__(**base_param)
+        if isinstance(default_optimizer, str):
+            self.default_opt = mx.optimizer.create(default_optimizer, **kwargs)
+        else:
+            self.default_opt = default_optimizer
+        self.aux_opt = mx.optimizer.create(_AssignmentOptimizer.__name__)
+
+    @staticmethod
+    def _check_params(**kwargs):
+        """Filters kwargs down to the parameters accepted by the base class Optimizer, so that optimizer-specific
+        arguments (e.g. `momentum` for SGD) are not passed to the base class constructor. `base_params` lists the
+        parameter names of the base class Optimizer.
+
+        Parameters
+        ----------
+        kwargs : dict
+            Parameters for the default optimizer.
+
+        Returns
+        -------
+        default_params : dict
+            The subset of `kwargs` that is defined in the base class Optimizer.
+        """
+
+        base_params = ['rescale_grad', 'param_idx2name', 'wd', 'clip_gradient', 'learning_rate', 'lr_scheduler', 'sym',
+                       'begin_num_update', 'multi_precision', 'param_dict']
+
+        # Keep only the kwargs that the base Optimizer constructor accepts
+        return {key: value for key, value in kwargs.items() if key in base_params}
+
+    def update(self, index, weight, grad, state):
+        """Updates the given parameter using the corresponding gradient and state. If the parameter name contains
+        'full', `_AssignmentOptimizer` is used; otherwise, the default optimizer is used.
+
+        Parameters
+        ----------
+        index : int
+            The unique index of the parameter into the individual learning
+            rates and weight decays. Learning rates and weight decay
+            may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
+        weight : NDArray
+            The parameter to be updated.
+        grad : NDArray
+            The gradient of the objective with respect to this parameter.
+        state : any obj
+            The state returned by `create_state()`.
+        """
+
+        name = self._check_index(index)
+
+        if "full" in name:
+            self.aux_opt.update(index, weight, grad, state)
+        else:
+            # use the default optimizer
+            self.default_opt.update(index, weight, grad, state)
+
+    def create_state(self, index, weight):
+        """Creates auxiliary state for a given weight.
+        Some optimizers require additional states, e.g. momentum, in addition
+        to gradients in order to update weights. This function creates state
+        for a given weight which will be used in `update`. This function is
+        called only once for each weight.
+
+        Parameters
+        ----------
+        index : int
+            A unique index to identify the weight.
+        weight : NDArray
+            The weight.
+
+        Returns
+        -------
+        state : any obj
+            The state associated with the weight.
+        """
+
+        name = self._check_index(index)
+        if "full" in name:
+            return self.aux_opt.create_state(index, weight)
+        else:
+            # use the default optimizer
+            return self.default_opt.create_state(index, weight)
+
+    def _check_index(self, index):
+        """Looks up `index` in idx2name to get the corresponding parameter name.
+
+        Parameters
+        ----------
+        index : int or str
+            A unique index to identify the weight.
+
+        Returns
+        -------
+        name : str
+            Name of the Module parameter
+        """
+
+        if index in self.idx2name.values():
+            # index is a str
+            name = index
+        else:
+            # index is an int
+            name = self.idx2name[index]
+        return name
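The name-based routing in `_SVRGOptimizer.update()` and the kwargs filtering in `_check_params()` can be illustrated without MXNet. The sketch below uses hypothetical toy classes (`ToySGD`, `ToyAssignment`, `ToySVRGOptimizer`) that operate on Python lists; they only mirror the dispatch logic and are not the MXNet optimizer API. With `learning_rate=1.0` it reproduces the values that `test_kvstore_init_aux_keys` expects: SGD moves `weight` from 0 to -1, while the `_full` key is simply overwritten with its gradient.

```python
BASE_PARAMS = {'rescale_grad', 'param_idx2name', 'wd', 'clip_gradient', 'learning_rate',
               'lr_scheduler', 'sym', 'begin_num_update', 'multi_precision', 'param_dict'}


def split_base_params(**kwargs):
    """Keep only the kwargs a base-class constructor would accept."""
    return {k: v for k, v in kwargs.items() if k in BASE_PARAMS}


class ToySGD:
    """Plain SGD step; this toy accepts but ignores extra kwargs such as momentum."""
    def __init__(self, learning_rate=0.01, **_ignored):
        self.lr = learning_rate

    def update(self, weight, grad):
        for i, g in enumerate(grad):
            weight[i] -= self.lr * g


class ToyAssignment:
    def update(self, weight, grad):
        # Overwrite the weight with the gradient, mirroring `weight[:] = grad`.
        weight[:] = grad


class ToySVRGOptimizer:
    def __init__(self, idx2name, **kwargs):
        self.base_params = split_base_params(**kwargs)
        self.idx2name = idx2name
        self.default_opt = ToySGD(**kwargs)
        self.aux_opt = ToyAssignment()

    def _name(self, index):
        # Accept either a parameter name or an integer index.
        return index if index in self.idx2name.values() else self.idx2name[index]

    def update(self, index, weight, grad):
        if "full" in self._name(index):
            self.aux_opt.update(weight, grad)
        else:
            self.default_opt.update(weight, grad)


opt = ToySVRGOptimizer({0: "weight", 1: "weight_full"}, learning_rate=1.0, momentum=0.9)
weight, weight_full = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
opt.update(0, weight, [1.0, 1.0, 1.0])
opt.update(1, weight_full, [2.0, 2.0, 2.0])
print(weight, weight_full)  # [-1.0, -1.0, -1.0] [2.0, 2.0, 2.0]
```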
diff --git a/tests/python/unittest/test_contrib_svrg_module.py b/tests/python/unittest/test_contrib_svrg_module.py
new file mode 100644
index 00000000000..d9e0abaebb2
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_module.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+from common import with_seed, assertRaises
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+from mxnet.test_utils import *
+
+
+def setup():
+    train_data = np.random.randint(1, 5, [1000, 2])
+    weights = np.array([1.0, 2.0])
+    train_label = train_data.dot(weights)
+
+    di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
+    X = mx.sym.Variable('data')
+    Y = mx.symbol.Variable('lin_reg_label')
+    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+    mod = SVRGModule(
+        symbol=lro,
+        data_names=['data'],
+        label_names=['lin_reg_label'], update_freq=2)
+    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+
+    return di, mod
+
+
+def test_bind_module():
+    _, mod = setup()
+    assert mod.binded
+    assert mod._mod_aux.binded
+
+
+def test_module_init():
+    _, mod = setup()
+    assert mod._mod_aux is not None
+
+
+def test_module_initializer():
+    def regression_model(m):
+        x = mx.symbol.var("data", stype='csr')
+        v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
+                          stype='row_sparse')
+        model = mx.symbol.dot(lhs=x, rhs=v)
+        y = mx.symbol.Variable("label")
+        model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
+        return model
+
+    # shape of the data
+    n, m = 128, 100
+    model = regression_model(m)
+
+    data = mx.nd.zeros(shape=(n, m), stype='csr')
+    label = mx.nd.zeros((n, 1))
+    iterator = mx.io.NDArrayIter(data=data, label={'label': label},
+                                 batch_size=n, last_batch_handle='discard')
+
+    # create module
+    mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2)
+    mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
+    mod.init_params()
+    v = mod._arg_params['v']
+    assert v.stype == 'row_sparse'
+    assert np.sum(v.asnumpy()) != 0
+
+
+def test_module_bind():
+    x = mx.sym.Variable("data")
+    net = mx.sym.FullyConnected(x, num_hidden=1)
+
+    mod = SVRGModule(symbol=net, data_names=['data'], label_names=None, update_freq=2)
+    assertRaises(TypeError, mod.bind, data_shapes=['data', mx.nd.zeros(shape=(2, 1))])
+
+    mod.bind(data_shapes=[('data', (2, 1))])
+    assert mod.binded
+    assert mod._mod_aux.binded
+
+
+@with_seed()
+def test_module_save_load():
+    import tempfile
+    import os
+
+    x = mx.sym.Variable("data")
+    y = mx.sym.Variable("softmax_label")
+    net = mx.sym.FullyConnected(x, y, num_hidden=1)
+
+    mod = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=2)
+    mod.bind(data_shapes=[('data', (1, 1))])
+    mod.init_params()
+    mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})
+    mod.update()
+
+    # Create tempfile
+    tmp = tempfile.mkdtemp()
+    tmp_file = os.path.join(tmp, 'svrg_test_output')
+    mod.save_checkpoint(tmp_file, 0, save_optimizer_states=True)
+
+    mod2 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
+    mod2.bind(data_shapes=[('data', (1, 1))])
+    mod2.init_optimizer(optimizer_params={'learning_rate': 0.1})
+    assert mod._symbol.tojson() == mod2._symbol.tojson()
+
+    # Multi-device
+    mod3 = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=3,
+                      context=[mx.cpu(0), mx.cpu(1)])
+    mod3.bind(data_shapes=[('data', (10, 10))])
+    mod3.init_params()
+    mod3.init_optimizer(optimizer_params={'learning_rate': 1.0})
+    mod3.update()
+    mod3.save_checkpoint(tmp_file, 0, save_optimizer_states=True)
+
+    mod4 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
+    mod4.bind(data_shapes=[('data', (10, 10))])
+    mod4.init_optimizer(optimizer_params={'learning_rate': 1.0})
+    assert mod3._symbol.tojson() == mod4._symbol.tojson()
+
+
+@with_seed()
+def test_svrgmodule_reshape():
+    data = mx.sym.Variable("data")
+    sym = mx.sym.FullyConnected(data=data, num_hidden=4, name='fc')
+
+    dshape = (3, 4)
+    mod = SVRGModule(sym, data_names=["data"], label_names=None, context=[mx.cpu(0), mx.cpu(1)], update_freq=2)
+    mod.bind(data_shapes=[('data', dshape)])
+    mod.init_params()
+    mod._mod_aux.init_params()
+    mod.init_optimizer(optimizer_params={"learning_rate": 1.0})
+
+    data_batch = mx.io.DataBatch(data=[mx.nd.ones(dshape)], label=None)
+    mod.forward(data_batch)
+    mod.backward([mx.nd.ones(dshape)])
+    mod.update()
+    assert mod.get_outputs()[0].shape == dshape
+
+    dshape = (2, 4)
+    mod.reshape(data_shapes=[('data', dshape)])
+    mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)],
+                                label=None))
+    mod.backward([mx.nd.ones(dshape)])
+    mod.update()
+    assert mod.get_outputs()[0].shape == dshape
+
+
+@with_seed()
+def test_update_full_grad():
+    def create_network():
+        train_data = np.random.randint(1, 5, [10, 2])
+        weights = np.array([1.0, 2.0])
+        train_label = train_data.dot(weights)
+
+        di = mx.io.NDArrayIter(train_data, train_label, batch_size=5, shuffle=True, label_name='lin_reg_label')
+        X = mx.sym.Variable('data')
+        Y = mx.symbol.Variable('lin_reg_label')
+        fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+        lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+        mod = SVRGModule(
+            symbol=lro,
+            data_names=['data'],
+            label_names=['lin_reg_label'], update_freq=2)
+        mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+        mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+        mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+                           force_init=False)
+        return di, mod
+
+    di, svrg_mod = create_network()
+
+    # Calculate the average of the full gradients over the number of batches
+    full_grads_weights = mx.nd.zeros(shape=svrg_mod.get_params()[0]['fc1_weight'].shape)
+    arg, aux = svrg_mod.get_params()
+    svrg_mod._mod_aux.set_params(arg_params=arg, aux_params=aux)
+    num_batch = 2
+
+    for batch in di:
+        svrg_mod.forward(batch)
+        svrg_mod.backward()
+        full_grads_weights = mx.nd.broadcast_add(svrg_mod._exec_group.grad_arrays[0][0], full_grads_weights, axis=0)
+    full_grads_weights /= num_batch
+
+    di.reset()
+    svrg_mod.update_full_grads(di)
+    assert same(full_grads_weights, svrg_mod._param_dict[0]['fc1_weight'])
+
+
+@with_seed()
+def test_svrg_with_sgd():
+    def create_module_with_sgd():
+        train_data = np.random.randint(1, 5, [100, 2])
+        weights = np.array([1.0, 2.0])
+        train_label = train_data.dot(weights)
+
+        di = mx.io.NDArrayIter(train_data, train_label, batch_size=10, shuffle=True, label_name='lin_reg_label')
+        X = mx.sym.Variable('data')
+        Y = mx.symbol.Variable('lin_reg_label')
+        fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+        lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+        reg_mod = mx.mod.Module(
+            symbol=lro,
+            data_names=['data'],
+            label_names=['lin_reg_label'])
+        reg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+        reg_mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+        reg_mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),))
+
+        svrg_mod = SVRGModule(symbol=lro,
+                              data_names=['data'],
+                              label_names=['lin_reg_label'],
+                              update_freq=2)
+        svrg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+        svrg_mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+        svrg_mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),))
+
+        return di, reg_mod, svrg_mod
+
+    di, reg_mod, svrg_mod = create_module_with_sgd()
+    num_epoch = 10
+
+    # Use metric MSE
+    metrics = mx.metric.create("mse")
+
+    # Train with SVRGModule
+    for e in range(num_epoch):
+        metrics.reset()
+        if e % svrg_mod.update_freq == 0:
+            svrg_mod.update_full_grads(di)
+        di.reset()
+        for batch in di:
+            svrg_mod.forward_backward(data_batch=batch)
+            svrg_mod.update()
+            svrg_mod.update_metric(metrics, batch.label)
+    svrg_mse = metrics.get()[1]
+
+    # Train with SGD standard Module
+    di.reset()
+    for e in range(num_epoch):
+        metrics.reset()
+        di.reset()
+        for batch in di:
+            reg_mod.forward_backward(data_batch=batch)
+            reg_mod.update()
+            reg_mod.update_metric(metrics, batch.label)
+    sgd_mse = metrics.get()[1]
+
+    assert svrg_mse < sgd_mse
+
+
+@with_seed()
+def test_accumulate_kvstore():
+    # Test KVStore behavior when pushing a list of values to a single key
+    kv = mx.kv.create('local')
+    kv.init("fc1_weight", mx.nd.zeros(shape=(1, 2)))
+    kv.init("fc1_weight_full", mx.nd.zeros(shape=(1, 2)))
+    b = [mx.nd.ones(shape=(1, 2)) for _ in range(4)]
+    a = mx.nd.zeros(shape=(1, 2))
+    kv.push("fc1_weight_full", b)
+    kv.pull("fc1_weight_full", out=a)
+    assert same(a, [mx.nd.array([4, 4])])
+    assert kv.num_workers == 1
+
+    # Test accumulate in KVStore and allocate gradients
+    kv_test = mx.kv.create('local')
+    _, svrg_mod = setup()
+    svrg_mod.init_optimizer(kvstore=kv_test, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+                            force_init=False)
+    svrg_mod._accumulate_kvstore("fc1_weight", b)
+    assert len(svrg_mod._param_dict) == svrg_mod._ctx_len
+    assert same(svrg_mod._param_dict[0]["fc1_weight"], b[0])
+
+
+@with_seed()
+def test_fit():
+    di, mod = setup()
+    num_epoch = 100
+    metric = mx.metric.create("mse")
+    mod.fit(di, eval_metric=metric, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
+            kvstore='local')
+
+    # Estimated MSE threshold for SGD with lr = 0.025; the SVRG MSE should be smaller
+    estimated_mse = 1e-5
+    assert metric.get()[1] < estimated_mse
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()
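The expected values in `test_accumulate_kvstore` and `test_update_full_grad` rest on two small pieces of arithmetic: a local KVStore sums a list of values pushed to one key (four arrays of ones pull back as fours), and the reference full gradient is the sum of the per-batch gradients divided by `num_batch`. The plain-Python sketch below restates that arithmetic with hypothetical helpers (`kvstore_sum`, `average_full_grad`); it does not call MXNet and is not how SVRGModule implements accumulation.

```python
def kvstore_sum(pushed):
    """Hypothetical helper: element-wise sum of equally shaped vectors,
    like a local KVStore reducing a list pushed to a single key."""
    return [sum(column) for column in zip(*pushed)]


def average_full_grad(batch_grads):
    """Hypothetical helper: average per-batch gradients over one full pass."""
    total = kvstore_sum(batch_grads)
    return [t / len(batch_grads) for t in total]


pushed = [[1.0, 1.0] for _ in range(4)]
print(kvstore_sum(pushed))                             # [4.0, 4.0]
print(average_full_grad([[1.0, 2.0], [3.0, 4.0]]))     # [2.0, 3.0]
```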
diff --git a/tests/python/unittest/test_contrib_svrg_optimizer.py b/tests/python/unittest/test_contrib_svrg_optimizer.py
new file mode 100644
index 00000000000..f7d90d12872
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_optimizer.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from mxnet.test_utils import same
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+from mxnet.contrib.svrg_optimization.svrg_optimizer import _SVRGOptimizer
+
+
+def create_network():
+
+    train_data = np.random.randint(1, 5, [1000, 2])
+    weights = np.array([1.0, 2.0])
+    train_label = train_data.dot(weights)
+
+    batch_size = 32
+
+    di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+    X = mx.sym.Variable('data')
+    Y = mx.symbol.Variable('lin_reg_label')
+    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+    mod = SVRGModule(
+        symbol=lro,
+        data_names=['data'],
+        label_names=['lin_reg_label'], update_freq=2
+    )
+
+    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
+                    force_init=False, allow_extra=False)
+
+    return di, mod
+
+
+def test_init_svrg_optimizer():
+    _, mod = create_network()
+
+    kv = mx.kv.create('local')
+    mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+                       force_init=False)
+
+    assert type(mod._optimizer).__name__ == _SVRGOptimizer.__name__
+
+
+def test_svrg_optimizer_constructor():
+    kv = mx.kv.create('local')
+    svrg_optimizer = _SVRGOptimizer(default_optimizer='sgd', learning_rate=-1.0)
+    kv.set_optimizer(svrg_optimizer)
+
+    assert svrg_optimizer.default_opt.lr == -1.0
+
+
+def test_kvstore_init_aux_keys():
+    param_idx2name = {0: "weight", 1: "weight_full"}
+
+    svrg_optimizer = _SVRGOptimizer(default_optimizer='sgd', param_idx2name=param_idx2name, learning_rate=1.0)
+    kv = mx.kv.create('local')
+    kv.set_optimizer(svrg_optimizer)
+
+    # Use default sgd optimizer
+    param_weight_init = mx.nd.array([0, 0, 0])
+    param_weight_update = mx.nd.array([1, 1, 1])
+
+    kv.init(0, param_weight_init)
+    kv.push(0, param_weight_update)
+    kv.pull(0, param_weight_init)
+
+    param_weight_full_init = mx.nd.array([1, 1, 1])
+    param_weight_full_update = mx.nd.array([2, 2, 2])
+
+    # Use AssignmentOptimizer
+    kv.init(1, param_weight_full_init)
+    kv.push(1, param_weight_full_update)
+    kv.pull(1, param_weight_full_init)
+
+    # updated weights using default sgd optimizer
+    assert same(param_weight_init.asnumpy(), np.array([-1, -1, -1]))
+    # updated with AssignmentOptimizer
+    assert same(param_weight_full_init.asnumpy(), np.array([2, 2, 2]))
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()


 
