Posted to commits@madlib.apache.org by nk...@apache.org on 2018/06/30 00:34:06 UTC

madlib git commit: MLP: Add momentum and nesterov accelerated gradient

Repository: madlib
Updated Branches:
  refs/heads/master 519acce82 -> 8bf15fde0


MLP: Add momentum and nesterov accelerated gradient

JIRA: MADLIB-1210

Momentum methods remember past gradients/model updates, which smooths
out the erratic behaviour of individual gradient updates without slowing
down learning. With momentum updates, the parameter vector builds up
velocity in any direction that has a consistent gradient.

Nesterov Accelerated Gradient (NAG) is a slightly different version of
the momentum update that enjoys stronger theoretical convergence
guarantees for convex functions and, in practice, also works slightly
better than standard momentum.

This commit also includes some refactoring that combines the update
methods for IGD and mini-batch.

Closes #272

Co-authored-by: Rahul Iyer <ri...@apache.org>
Co-authored-by: Jingyi Mei <jm...@pivotal.io>


Project: http://git-wip-us.apache.org/repos/asf/madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/madlib/commit/8bf15fde
Tree: http://git-wip-us.apache.org/repos/asf/madlib/tree/8bf15fde
Diff: http://git-wip-us.apache.org/repos/asf/madlib/diff/8bf15fde

Branch: refs/heads/master
Commit: 8bf15fde06cf534710e8045a4dadff7d38bfa369
Parents: 519acce
Author: Nikhil Kak <nk...@pivotal.io>
Authored: Wed May 2 05:25:48 2018 -0700
Committer: Nikhil Kak <nk...@pivotal.io>
Committed: Fri Jun 29 17:34:00 2018 -0700

----------------------------------------------------------------------
 doc/design/modules/neural-network.tex           |  46 +++++-
 doc/literature.bib                              |  11 ++
 src/modules/convex/mlp_igd.cpp                  |  46 +++---
 src/modules/convex/task/mlp.hpp                 | 114 ++++++++++-----
 src/modules/convex/type/model.hpp               |  87 +++++++++--
 src/modules/convex/type/state.hpp               |  37 ++---
 src/modules/convex/type/tuple.hpp               |   2 +-
 src/ports/postgres/modules/convex/mlp.sql_in    | 143 ++++++++++++-------
 src/ports/postgres/modules/convex/mlp_igd.py_in |  79 +++++++---
 .../postgres/modules/convex/test/mlp.sql_in     | 137 ++++++++++++++++++
 10 files changed, 531 insertions(+), 171 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/doc/design/modules/neural-network.tex
----------------------------------------------------------------------
diff --git a/doc/design/modules/neural-network.tex b/doc/design/modules/neural-network.tex
index b0b601b..6d60485 100644
--- a/doc/design/modules/neural-network.tex
+++ b/doc/design/modules/neural-network.tex
@@ -22,7 +22,12 @@
 \chapter{Neural Network}
 
 \begin{moduleinfo}
-\item[Authors] {Xixuan Feng, Cooper Sloan}
+\item[Authors] {Xixuan Feng, Cooper Sloan, Rahul Iyer, Nikhil Kak}
+\item[History]
+    \begin{modulehistory}
+        \item[v0.1] Initial version
+        \item[v0.2] Added a section for momentum updates. Also updated the mlp-train-iteration algorithm to include momentum calculations.
+    \end{modulehistory}
 \end{moduleinfo}
 
 % Abstract. What is the problem we want to solve?
@@ -87,7 +92,7 @@ For inner (hidden) layers, it is more difficult to compute the partial derivativ
 That said, $\frac{\partial f}{\partial \mathit{net}_N^t} = (o_N^t - y^t) \phi'(\mathit{net}_N^t)$ is easy, where $t = 1,...,n_N$, but $\frac{\partial f}{\partial \mathit{net}_k^j}$ is hard, where $k = 1,...,N-1, j = 1,..,n_k$.
 This hard-to-compute statistic is referred to as \textit{delta error}, and let $\delta_k^j = \frac{\partial f}{\partial \mathit{net}_k^j}$, where $k = 1,...,N-1, j = 1,..,n_k$.
 If this is solved, the gradient can be easily computed as follow
-\[\frac{\partial f}{\partial u_{k-1}^{sj}} = \boxed{\frac{\partial f}{\partial \mathit{net}_k^j}} \cdot \frac{\partial \mathit{net}_k^j}{\partial u_{k-1}^{sj}} = \boxed{\delta_k^j} o_{k-1}^s,\]
+    \[\frac{\partial f}{\partial u_{k-1}^{sj}} = \boxed{\frac{\partial f}{\partial \mathit{net}_k^j}} \cdot \frac{\partial \mathit{net}_k^j}{\partial u_{k-1}^{sj}} = \boxed{\delta_k^j} o_{k-1}^s,\]
 where $k = 1,...,N-1, s = 0,...,n_{k-1}, j = 1,..,n_k$.
 To solve this, we introduce the popular backpropagation below.
 
@@ -117,6 +122,26 @@ To sum up, we need the following equation for error back propagation
 \[\boxed{\delta_{k}^j = \sum_{t=1}^{n_{k+1}} \left( \delta_{k+1}^t \cdot u_{k}^{jt} \right) \cdot \phi'(\mathit{net}_{k}^j)}\]
 where $k = 1,...,N-1$, and $j = 1,...,n_{k}$.
 
+\paragraph{Momentum updates.}
+Momentum\cite{momentum_ilya}\cite{momentum_cs231n} can help accelerate learning and avoid local minima when using gradient descent. We also support Nesterov's accelerated gradient due to its look-ahead characteristics. \\
+Here we need to introduce two new variables, namely velocity and momentum. Momentum must be in the range 0 to 1, where 0 means no momentum.
+The velocity is the same size as the coefficient vector and accumulates in the direction of persistent reduction, which speeds up the optimization. The momentum value is responsible for damping the velocity and is analogous to the coefficient of friction. \\
+In classical momentum we first compute the gradient at the current model, then correct the velocity, and finally update the model with that velocity. In Nesterov momentum we first move the model in the direction of momentum*velocity, then compute the gradient at this look-ahead position, correct the velocity, and finally take the gradient step from the look-ahead position. The main difference is that classical momentum computes the gradient before moving the model, whereas Nesterov momentum first moves the model and then computes the gradient from the updated position.\\
+
+Classic momentum update
+\[\begin{aligned}
+    \mathit{v} = \mu * \mathit{v} - \eta * \frac{\partial f}{\partial u_{k-1}^{sj}} &\quad \text{ (velocity update),} \\
+    \mathit{u} = \mathit{u} + \mathit{v} \\
+\end{aligned}\]
+
+Nesterov momentum update
+\[\begin{aligned}
+    \mathit{ua} = \mathit{u} + \mu * \mathit{v}  &\quad \text{ (Nesterov's initial coefficient update to the model),} \\
+    \mathit{v} = \mu * \mathit{v} -  \eta * \frac{\partial f}{\partial ua_{k-1}^{sj}}  &\quad \text{ (velocity update, use the look-ahead model $ua$ for gradient calculations),} \\
+    \mathit{u} = \mathit{ua} - \eta * \frac{\partial f}{\partial ua_{k-1}^{sj}} \\
+\end{aligned}\]
+where $u$ is the coefficient vector, $v$ is the velocity vector, $\mu$ is the momentum value, $\eta$ is the learning rate and $\frac{\partial f}{\partial ua_{k-1}^{sj}}$ is the gradient calculated at the look-ahead position $ua$.
+
 \subsubsection{The $\mathit{Gradient}$ Function}
 \begin{algorithm}[mlp-gradient$(u, x, y)$] \label{alg:mlp-gradient}
 \alginput{Coefficients $u = \{ u_{k-1}^{sj} \; | \; k = 1,...,N, \: s = 0,...,n_{k-1}, \: j = 1,...,n_k\}$,\\
@@ -196,17 +221,28 @@ derivative of activation unit $\phi' : \mathbb{R} \to \mathbb{R}$}
 \end{algorithmic}
 \end{algorithm}
 
-\begin{algorithm}[mlp-train-iteration$(X, Y, \eta)$] \label{alg:mlp-train-iteration}
+\begin{algorithm}[mlp-train-iteration$(X, Y, \eta, \mu, n)$] \label{alg:mlp-train-iteration}
 \alginput{
 start vectors $X_{i...m} \in \mathbb{R}^{n_0}$,\\
 end vectors $Y_{i...m} \in \mathbb{R}^{n_N}$,\\
-learning rate $\eta$,\\}
+learning rate $\eta$,\\
+momentum $\mu$,\\
+Nesterov flag $n$\\}
 \algoutput{Coefficients $u = \{ u_{k-1}^{sj} \; | \; k = 1,...,N, \: s = 0,...,n_{k-1}, \: j = 1,...,n_k\}$}
 \begin{algorithmic}[1]
     \State \texttt{Randomly initialize u}
+    \State \texttt{Initialize velocity v to 0}
     \For{$i = 1,...,m$}
+        \If{n == True}
+            \State $u \set u + \mu * v$ \text{ (Nesterov's initial coefficient update)}
+        \EndIf
         \State $\nabla f(u) \set \texttt{mlp-gradient}(u,X_i,Y_i)$
-        \State $u \set u - (\eta \nabla f(u) u + \lambda u)$
+        \State $v \set \mu * v - \eta (\nabla f(u) + \lambda u)$ \Comment{(velocity update)}
+        \If{$\mu > 0$ and n == False}
+            \State $u \set u + v$ \Comment{(classic (non-Nesterov) momentum update)}
+        \Else
+            \State $u \set u - \eta (\nabla f(u) + \lambda u)$ \Comment{(no momentum, or Nesterov step from the look-ahead position)}
+        \EndIf
     \EndFor
     \State \Return $u$
 \end{algorithmic}

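As a consistency check on the Nesterov update, substituting the
look-ahead point into the final position update shows that the step
ends at the old position plus the new velocity, which is the same form
as the classic update; only the point at which the gradient is evaluated
differs. This matches model.hpp below, where nesterovUpdatePosition()
adds momentum * velocity and updatePosition() then adds the negative,
step-scaled gradient. A sketch in the design doc's notation, with
nabla f(ua) denoting the gradient at the look-ahead position:

```latex
\begin{aligned}
ua &= u + \mu v, \\
v' &= \mu v - \eta \nabla f(ua), \\
u' &= ua - \eta \nabla f(ua) = u + \mu v - \eta \nabla f(ua) = u + v'.
\end{aligned}
```
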
http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/doc/literature.bib
----------------------------------------------------------------------
diff --git a/doc/literature.bib b/doc/literature.bib
index 8c53813..3c07260 100644
--- a/doc/literature.bib
+++ b/doc/literature.bib
@@ -975,3 +975,14 @@ Applied Survival Analysis},
  acmid = {324140},
  publisher = {ACM}
 }
+
+@misc{momentum_cs231n,
+    Url = {http://cs231n.github.io/neural-networks-3/#sgd},
+    Title = {{CS231n Convolutional Neural Networks for Visual Recognition}},
+}
+
+@misc{momentum_ilya,
+    Url = {http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf},
+    Title = {{TRAINING RECURRENT NEURAL NETWORKS}},
+    Author = {{Ilya Sutskever}}
+}

http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/modules/convex/mlp_igd.cpp
----------------------------------------------------------------------
diff --git a/src/modules/convex/mlp_igd.cpp b/src/modules/convex/mlp_igd.cpp
index 63f278f..0fb17b3 100644
--- a/src/modules/convex/mlp_igd.cpp
+++ b/src/modules/convex/mlp_igd.cpp
@@ -91,7 +91,8 @@ mlp_igd_transition::run(AnyType &args) {
             // args[8] is for weighting the input row, which is populated later.
             state.task.lambda = args[10].getAs<double>();
             MLPTask::lambda = state.task.lambda;
-
+            state.task.model.momentum = args[11].getAs<double>();
+            state.task.model.is_nesterov = static_cast<double>(args[12].getAs<bool>());
             if (!args[9].isNull()){
                 // initial coefficients are provided
                 MappedColumnVector warm_start_coeff = args[9].getAs<MappedColumnVector>();
@@ -119,21 +120,15 @@ mlp_igd_transition::run(AnyType &args) {
         state.reset();
     }
 
-    // tuple
-    ColumnVector indVar;
-    MappedColumnVector depVar;
+    MLPTuple tuple;
     try {
-        indVar = args[1].getAs<MappedColumnVector>();
-        MappedColumnVector y = args[2].getAs<MappedColumnVector>();
-        depVar.rebind(y.memoryHandle(), y.size());
+        tuple.indVar = args[1].getAs<MappedColumnVector>();
+        tuple.depVar = args[2].getAs<MappedColumnVector>();
     } catch (const ArrayWithNullException &e) {
         return args[0];
     }
-    MLPTuple tuple;
-    tuple.indVar = indVar;
-    tuple.depVar.rebind(depVar.memoryHandle(), depVar.size());
-    tuple.weight = args[8].getAs<double>();
 
+    tuple.weight = args[8].getAs<double>();
     MLPIGDAlgorithm::transition(state, tuple);
     // Use the model from the previous iteration to compute the loss (note that
     // it is stored in Task's state, and the Algo's state holds the model from
@@ -211,6 +206,8 @@ mlp_minibatch_transition::run(AnyType &args) {
             state.model.activation = static_cast<double>(args[6].getAs<int>());
             state.model.is_classification = static_cast<double>(args[7].getAs<int>());
             // args[8] is for weighting the input row, which is populated later.
+            state.model.momentum = args[13].getAs<double>();
+            state.model.is_nesterov = static_cast<double>(args[14].getAs<bool>());
             if (!args[9].isNull()){
                 // initial coefficients are provided copy warm start into the model
                 MappedColumnVector warm_start_coeff = args[9].getAs<MappedColumnVector>();
@@ -240,23 +237,19 @@ mlp_minibatch_transition::run(AnyType &args) {
         state.reset();
     }
 
-    // tuple
-    Matrix indVar;
-    Matrix depVar;
+    MiniBatchTuple tuple;
     try {
         // Ideally there should be no NULLs in the pre-processed input data,
         // but keep it in a try block in case the user has modified the
         // pre-processed data in any way.
-        indVar = args[1].getAs<MappedMatrix>();
-        depVar = args[2].getAs<MappedMatrix>();
+        // The matrices are by default read as column-major. We will have to
+        // transpose it to get back the matrix like how it is in the database.
+        tuple.indVar = trans(args[1].getAs<MappedMatrix>());
+        tuple.depVar = trans(args[2].getAs<MappedMatrix>());
     } catch (const ArrayWithNullException &e) {
         return args[0];
     }
-    MiniBatchTuple tuple;
-    // The matrices are by default read as column-major. We will have to
-    // transpose it to get back the matrix like how it is in the database.
-    tuple.indVar = trans(indVar);
-    tuple.depVar = trans(depVar);
+
     tuple.weight = args[8].getAs<double>();
 
     /*
@@ -339,7 +332,7 @@ internal_mlp_igd_result::run(AnyType &args) {
     HandleTraits<ArrayHandle<double> >::ColumnVectorTransparentHandleMap
         flattenU;
     flattenU.rebind(&state.task.model.u[0](0, 0),
-                    state.task.model.arraySize(state.task.numberOfStages,
+                    state.task.model.coeffArraySize(state.task.numberOfStages,
                                                state.task.numbersOfUnits));
     AnyType tuple;
     tuple << flattenU
@@ -355,7 +348,7 @@ internal_mlp_minibatch_result::run(AnyType &args) {
     MLPMiniBatchState<ArrayHandle<double> > state = args[0];
     HandleTraits<ArrayHandle<double> >::ColumnVectorTransparentHandleMap flattenU;
     flattenU.rebind(&state.model.u[0](0, 0),
-                    state.model.arraySize(state.numberOfStages,
+                    state.model.coeffArraySize(state.numberOfStages,
                                           state.numbersOfUnits));
     AnyType tuple;
     tuple << flattenU
@@ -379,7 +372,12 @@ internal_predict_mlp::run(AnyType &args) {
     int is_dep_var_array_for_classification = args[8].getAs<int>();
     bool is_classification_response = is_classification && is_response;
 
-    model.rebind(&is_classification, &activation, &coeff.data()[0],
+    // The model rebind function is called by both predict and train functions.
+    // Since we have to use the same function, we are passing a dummy value for
+    // momentum and nesterov because predict does not care
+    // about the actual values for these params.
+    const double dummy_value = static_cast<double>(-1);
+    model.rebind(&is_classification, &activation, &dummy_value, &dummy_value, &coeff.data()[0],
                  numberOfStages, &layerSizes.data()[0]);
     try {
         indVar = (args[1].getAs<MappedColumnVector>()-x_means).cwiseQuotient(x_stds);

http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/modules/convex/task/mlp.hpp
----------------------------------------------------------------------
diff --git a/src/modules/convex/task/mlp.hpp b/src/modules/convex/task/mlp.hpp
index a13aa84..3915ab1 100644
--- a/src/modules/convex/task/mlp.hpp
+++ b/src/modules/convex/task/mlp.hpp
@@ -58,6 +58,18 @@ public:
             const Matrix                        &y,
             const double                        &stepsize);
 
+    static double getLossAndGradient(
+            model_type                          &model,
+            const Matrix                        &x,
+            const Matrix                        &y,
+            std::vector<Matrix>                 &total_gradient_per_layer,
+            const double                        &stepsize);
+
+    static double getLoss(
+            const ColumnVector                  &y_true,
+            const ColumnVector                  &y_estimated,
+            const bool                          is_classification);
+
     static double loss(
             const model_type                    &model,
             const independent_variables_type    &x,
@@ -120,74 +132,110 @@ double MLP<Model, Tuple>::lambda = 0;
 
 template <class Model, class Tuple>
 double
+MLP<Model, Tuple>::getLoss(const ColumnVector &y_true,
+                           const ColumnVector &y_estimated,
+                           const bool is_classification){
+    if(is_classification){
+        // cross entropy loss function
+        double clip = 1.e-10;
+        ColumnVector y_clipped = y_estimated.cwiseMax(clip).cwiseMin(1. - clip);
+        return -(y_true.array() * y_clipped.array().log() +
+                    (1 - y_true.array()) * (1 - y_clipped.array()).log()
+                ).sum();
+    }
+    else{
+        // squared loss
+        return 0.5 * (y_estimated - y_true).squaredNorm();
+    }
+}
+
+template <class Model, class Tuple>
+double
 MLP<Model, Tuple>::getLossAndUpdateModel(
         model_type           &model,
         const Matrix         &x_batch,
         const Matrix         &y_true_batch,
         const double         &stepsize) {
 
-    uint16_t num_layers = static_cast<uint16_t>(model.u.size()); // assuming nu. of layers >= 1
-    Index num_rows_in_batch = x_batch.rows();
     double total_loss = 0.;
+    // model is updated with the momentum step (i.e. velocity vector)
+    // if Nesterov Accelerated Gradient is enabled
+    model.nesterovUpdatePosition();
 
-    // gradient added over the batch
-    std::vector<Matrix> total_gradient_per_layer(num_layers);
-    for (Index k=0; k < num_layers; ++k)
+    // initialize gradient vector
+    std::vector<Matrix> total_gradient_per_layer(model.num_layers);
+    for (Index k=0; k < model.num_layers; ++k) {
         total_gradient_per_layer[k] = Matrix::Zero(model.u[k].rows(),
                                                    model.u[k].cols());
+    }
 
+    std::vector<ColumnVector> net, o, delta;
+    Index num_rows_in_batch = x_batch.rows();
     for (Index i=0; i < num_rows_in_batch; i++){
+        // gradient and loss added over the batch
         ColumnVector x = x_batch.row(i);
         ColumnVector y_true = y_true_batch.row(i);
 
-        std::vector<ColumnVector> net, o, delta;
         feedForward(model, x, net, o);
         backPropogate(y_true, o.back(), net, model, delta);
 
-        for (Index k=0; k < num_layers; k++){
+        // compute the gradient
+        for (Index k=0; k < model.num_layers; k++){
                 total_gradient_per_layer[k] += o[k] * delta[k].transpose();
         }
 
-        // loss computation
-        ColumnVector y_estimated = o.back();
-        if(model.is_classification){
-            double clip = 1.e-10;
-            y_estimated = y_estimated.cwiseMax(clip).cwiseMin(1.-clip);
-            total_loss += - (y_true.array()*y_estimated.array().log()
-                   + (-y_true.array()+1)*(-y_estimated.array()+1).log()).sum();
-        }
-        else{
-            total_loss += 0.5 * (y_estimated - y_true).squaredNorm();
-        }
+        // compute the loss
+        total_loss += getLoss(y_true, o.back(), model.is_classification);
     }
-    for (Index k=0; k < num_layers; k++){
+
+    // convert gradient to a gradient update vector
+    //  1. normalize to per row update
+    //  2. add regularization (bias row excluded)
+    //  3. discount by stepsize
+    //  4. make negative
+    for (Index k=0; k < model.num_layers; k++){
         Matrix regularization = MLP<Model, Tuple>::lambda * model.u[k];
         regularization.row(0).setZero(); // Do not update bias
-        model.u[k] -=
-            stepsize *
-            (total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch) +
-             regularization);
+        total_gradient_per_layer[k] = -stepsize * (total_gradient_per_layer[k] / static_cast<double>(num_rows_in_batch) +
+                                                  regularization);
+        model.updateVelocity(total_gradient_per_layer[k], k);
+        model.updatePosition(total_gradient_per_layer[k], k);
     }
+
     return total_loss;
+
 }
 
+
 template <class Model, class Tuple>
 void
 MLP<Model, Tuple>::gradientInPlace(
         model_type                          &model,
         const independent_variables_type    &x,
         const dependent_variable_type       &y_true,
-        const double                        &stepsize) {
-    size_t N = model.u.size(); // assuming nu. of layers >= 1
+        const double                        &stepsize)
+{
+    model.nesterovUpdatePosition();
+
     std::vector<ColumnVector> net, o, delta;
 
     feedForward(model, x, net, o);
     backPropogate(y_true, o.back(), net, model, delta);
 
-    for (size_t k=0; k < N; k++){
+    for (Index k=0; k < model.num_layers; k++){
         Matrix regularization = MLP<Model, Tuple>::lambda*model.u[k];
         regularization.row(0).setZero(); // Do not update bias
-        model.u[k] -= stepsize * (o[k] * delta[k].transpose() + regularization);
+        if (model.momentum > 0){
+            Matrix gradient = -stepsize * (o[k] * delta[k].transpose() + regularization);
+            model.updateVelocity(gradient, k);
+            model.updatePosition(gradient, k);
+        }
+        else {
+            // Updating model inline instead of using updatePosition because
+            // we suspect that updatePosition ends up creating a copy of the
+            // gradient even if it is passed by reference and hence making it slower.
+            model.u[k] -= stepsize * (o[k] * delta[k].transpose() + regularization);
+        }
     }
 }
 
@@ -197,21 +245,14 @@ MLP<Model, Tuple>::loss(
         const model_type                    &model,
         const independent_variables_type    &x,
         const dependent_variable_type       &y_true) {
+
     // Here we compute the loss. In the case of regression we use sum of square errors
     // In the case of classification the loss term is cross entropy.
     std::vector<ColumnVector> net, o;
     feedForward(model, x, net, o);
     ColumnVector y_estimated = o.back();
 
-    if(model.is_classification){
-        double clip = 1.e-10;
-        y_estimated = y_estimated.cwiseMax(clip).cwiseMin(1.-clip);
-        return - (y_true.array()*y_estimated.array().log()
-               + (-y_true.array()+1)*(-y_estimated.array()+1).log()).sum();
-    }
-    else{
-        return 0.5 * (y_estimated - y_true).squaredNorm();
-    }
+    return getLoss(y_true, y_estimated, model.is_classification);
 }
 
 template <class Model, class Tuple>
@@ -324,3 +365,4 @@ MLP<Model, Tuple>::backPropogate(
 
 #endif
 
+

http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/modules/convex/type/model.hpp
----------------------------------------------------------------------
diff --git a/src/modules/convex/type/model.hpp b/src/modules/convex/type/model.hpp
index 11013e5..440e384 100644
--- a/src/modules/convex/type/model.hpp
+++ b/src/modules/convex/type/model.hpp
@@ -105,15 +105,18 @@ template <class Handle>
 struct MLPModel {
     typename HandleTraits<Handle>::ReferenceToDouble is_classification;
     typename HandleTraits<Handle>::ReferenceToDouble activation;
+    typename HandleTraits<Handle>::ReferenceToDouble momentum;
+    typename HandleTraits<Handle>::ReferenceToDouble is_nesterov;
+
+    uint16_t num_layers;
+
     // std::vector<Eigen::Map<Matrix > > u;
     std::vector<MutableMappedMatrix> u;
+    std::vector<MutableMappedMatrix> velocity;
 
     /**
-     * @brief Space needed.
+     * @brief Space needed for the whole model
      *
-     * Extra information besides the values in the matrix, like dimension is
-     * necessary for a matrix, so that it can perform operations. These are
-     * stored in the HandleMap.
      */
     static inline uint32_t arraySize(const uint16_t &inNumberOfStages,
                                      const double *inNumbersOfUnits) {
@@ -126,45 +129,96 @@ struct MLPModel {
         for (k = 0; k < N; k ++) {
             size += static_cast<uint32_t>((n[k] + 1) * (n[k+1]));
         }
+        //TODO conditionally assign size based on momentum
+        return size * 2;     // position (u) + velocity
+    }
+
+    /**
+     * @brief Space needed for the coefficients
+     *
+     */
+    static inline uint32_t coeffArraySize(const uint16_t &inNumberOfStages,
+                                          const double *inNumbersOfUnits) {
+        // inNumberOfStages == 0 is not an expected value, but
+        // it won't cause exception -- returning 0
+        uint32_t size = 0;
+        size_t N = inNumberOfStages;
+        const double *n = inNumbersOfUnits;
+        size_t k;
+        for (k = 0; k < N; k ++) {
+            size += static_cast<uint32_t>((n[k] + 1) * (n[k+1]));
+        }
         return size;     // weights (u)
     }
 
     size_t rebind(const double *is_classification_in,
                     const double *activation_in,
+                    const double *momentum_in,
+                    const double *is_nesterov_in,
                     const double *data,
                     const uint16_t &inNumberOfStages,
                     const double *inNumbersOfUnits) {
-        size_t N = inNumberOfStages;
         const double *n = inNumbersOfUnits;
         size_t k;
 
         is_classification.rebind(is_classification_in);
         activation.rebind(activation_in);
+        momentum.rebind(momentum_in);
+        is_nesterov.rebind(is_nesterov_in);
+        num_layers = inNumberOfStages;
 
         size_t sizeOfU = 0;
         u.clear();
-        for (k = 0; k < N; k ++) {
+        for (k = 0; k < num_layers; k ++) {
             u.push_back(MutableMappedMatrix());
             u[k].rebind(const_cast<double *>(data + sizeOfU),
                         static_cast<Index>(n[k] + 1),
                         static_cast<Index>(n[k+1]));
             sizeOfU += static_cast<size_t>((n[k] + 1) * (n[k+1]));
         }
-
+        velocity.clear();
+        for (k = 0; k < num_layers; k ++) {
+            velocity.push_back(MutableMappedMatrix());
+            velocity[k].rebind(const_cast<double *>(data + sizeOfU),
+                               static_cast<Index>(n[k] + 1),
+                               static_cast<Index>(n[k+1]));
+            sizeOfU += static_cast<size_t>((n[k] + 1) * (n[k+1]));
+        }
         return sizeOfU;
     }
 
     void initialize(const uint16_t &inNumberOfStages,
-                    const double *inNumbersOfUnits){
-        size_t N = inNumberOfStages;
-        const double *n = inNumbersOfUnits;
-        size_t k;
-        double span;
-        for (k =0; k < N; ++k){
+                    const double *inNumbersOfUnits) {
+        num_layers = inNumberOfStages;
+
+        for (size_t k =0; k < num_layers; ++k){
             // Initalize according to Glorot and Bengio (2010)
             // See design doc for more info
-            span = 0.5 * sqrt(6.0 / (n[k] + n[k+1]));
+            double span = 0.5 * sqrt(6.0 / (inNumbersOfUnits[k] + inNumbersOfUnits[k+1]));
             u[k] << span * Matrix::Random(u[k].rows(), u[k].cols());
+            velocity[k].setZero();
+        }
+    }
+
+    void updateVelocity(const Matrix &gradient, const Index layer_index){
+        if (momentum > 0.){
+            // if momentum is enabled
+            velocity[layer_index] = momentum * velocity[layer_index] + gradient;
+        }
+    }
+
+    void updatePosition(const Matrix &gradient, const Index layer_index){
+        if (momentum > 0 and not is_nesterov){
+            u[layer_index] += velocity[layer_index];
+        }
+        else {
+            // update is same for non momentum and nesterov
+            u[layer_index] += gradient;
+        }
+    }
+
+    void nesterovUpdatePosition(){
+        if (momentum > 0 and is_nesterov){
+            for (size_t k = 0; k < u.size(); k++){
+                u[k] += momentum * velocity[k];
+            }
         }
     }
 
@@ -181,6 +235,7 @@ struct MLPModel {
         size_t k;
         for (k = 0; k < u.size(); k ++) {
             u[k].setZero();
+            velocity[k].setZero();
         }
     }
 
@@ -223,9 +278,13 @@ struct MLPModel {
         size_t k;
         for (k = 0; k < u.size() && k < inOtherModel.u.size(); k ++) {
             u[k] = inOtherModel.u[k];
+            velocity[k] = inOtherModel.velocity[k];
         }
+        num_layers = inOtherModel.num_layers;
         is_classification = inOtherModel.is_classification;
         activation = inOtherModel.activation;
+        momentum = inOtherModel.momentum;
+        is_nesterov = inOtherModel.is_nesterov;
 
         return *this;
     }

http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/modules/convex/type/state.hpp
----------------------------------------------------------------------
diff --git a/src/modules/convex/type/state.hpp b/src/modules/convex/type/state.hpp
index 6bd5854..c1394b4 100644
--- a/src/modules/convex/type/state.hpp
+++ b/src/modules/convex/type/state.hpp
@@ -632,6 +632,8 @@ public:
             + 1                         // lambda
             + 1                         // is_classification
             + 1                         // activation
+            + 1                         // momentum
+            + 1                         // is_nesterov
             + sizeOfModel               // model
             + sizeOfModel               // incrModel
             + 1                         // numRows
@@ -669,16 +671,13 @@ private:
             reinterpret_cast<dimension_pointer_type>(&mStorage[1]);
         task.stepsize.rebind(&mStorage[N + 2]);
         task.lambda.rebind(&mStorage[N + 3]);
-        size_t sizeOfModel = task.model.rebind(&mStorage[N + 4],
-                                                 &mStorage[N + 5],
-                                                 &mStorage[N + 6],
-                                                 task.numberOfStages,
-                                                 task.numbersOfUnits);
+        size_t sizeOfModel = task.model.rebind(&mStorage[N + 4],
+                                               &mStorage[N + 5],
+                                               &mStorage[N + 6],
+                                               &mStorage[N + 7],
+                                               &mStorage[N + 8],
+                                               task.numberOfStages,
+                                               task.numbersOfUnits);
 
-        algo.incrModel.rebind(&mStorage[N + 4],&mStorage[N + 5],&mStorage[N + 6 + sizeOfModel],
+        algo.incrModel.rebind(&mStorage[N + 4], &mStorage[N + 5], &mStorage[N + 6], &mStorage[N + 7], &mStorage[N + 8 + sizeOfModel],
                 task.numberOfStages, task.numbersOfUnits);
-        algo.numRows.rebind(&mStorage[N + 6 + 2*sizeOfModel]);
-        algo.loss.rebind(&mStorage[N + 7 + 2*sizeOfModel]);
+        algo.numRows.rebind(&mStorage[N + 8 + 2*sizeOfModel]);
+        algo.loss.rebind(&mStorage[N + 9 + 2*sizeOfModel]);
 
     }
 
@@ -796,12 +795,13 @@ public:
             + 1                         // lambda
             + 1                         // is_classification
             + 1                         // activation
+            + 1                         // momentum
+            + 1                         // is_nesterov
             + sizeOfModel               // model
             + 1                         // numRows
             + 1                         // batchSize
             + 1                         // nEpochs
             + 1;                        // loss
-
     }
 
     Handle mStorage;
@@ -834,14 +834,17 @@ private:
         stepsize.rebind(&mStorage[N + 2]);
         lambda.rebind(&mStorage[N + 3]);
         size_t sizeOfModel = model.rebind(&mStorage[N + 4],
-                                               &mStorage[N + 5],
-                                               &mStorage[N + 6],
-                                               numberOfStages,
-                                               numbersOfUnits);
-        numRows.rebind(&mStorage[N + 6 + sizeOfModel]);
-        batchSize.rebind(&mStorage[N + 7 + sizeOfModel]);
-        nEpochs.rebind(&mStorage[N + 8 + sizeOfModel]);
-        loss.rebind(&mStorage[N + 9 + sizeOfModel]);
+                                          &mStorage[N + 5],
+                                          &mStorage[N + 6],
+                                          &mStorage[N + 7],
+                                          &mStorage[N + 8],
+                                          numberOfStages,
+                                          numbersOfUnits);
+
+        numRows.rebind(&mStorage[N + 8 + sizeOfModel]);
+        batchSize.rebind(&mStorage[N + 9 + sizeOfModel]);
+        nEpochs.rebind(&mStorage[N + 10 + sizeOfModel]);
+        loss.rebind(&mStorage[N + 11 + sizeOfModel]);
     }
 
 

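Every rebind offset above comes from one flat, packed state array. The two new scalar slots (momentum and is_nesterov) sit in front of the model coefficients, so every later field moves back by two. For example, numRows moves from N + 6 + sizeOfModel to N + 8 + sizeOfModel.

The sketch below recomputes those offsets relative to N as a sanity check. `state_offsets` is a hypothetical helper, not MADlib code. It makes three assumptions, based on the size comments:

- the scalar slots at N + 4 through N + 7 hold is_classification, activation, momentum and is_nesterov, in that order;
- the IGD state holds two model-sized blocks (model and incrModel);
- the mini-batch state holds one model-sized block.

```python
# Hypothetical helper (not MADlib code): offsets, relative to N, of the fields
# in the flat MLP state array that the rebind() calls above address.

# Assumed order of the scalar slots that start at N + 4 (taken from the
# size comments); the second list is the layout before this commit.
NEW_SCALARS = ["is_classification", "activation", "momentum", "is_nesterov"]
OLD_SCALARS = ["is_classification", "activation"]

IGD_TRAILING = ["numRows", "loss"]
MINIBATCH_TRAILING = ["numRows", "batchSize", "nEpochs", "loss"]

FIRST_SCALAR = 4  # both states bind their first model scalar at N + 4


def state_offsets(size_of_model, n_model_blocks, trailing, scalars=NEW_SCALARS):
    """Map each field name to its offset from N.

    Scalar slots come first, then n_model_blocks coefficient blocks of
    size_of_model doubles each, then the trailing one-double fields.
    """
    offsets = {name: FIRST_SCALAR + i for i, name in enumerate(scalars)}
    first_block = FIRST_SCALAR + len(scalars)
    offsets["model_blocks"] = [first_block + k * size_of_model
                               for k in range(n_model_blocks)]
    pos = first_block + n_model_blocks * size_of_model
    for name in trailing:
        offsets[name] = pos
        pos += 1
    return offsets


if __name__ == "__main__":
    # IGD state: model + incrModel; mini-batch state: a single model.
    print(state_offsets(7, 2, IGD_TRAILING))
    print(state_offsets(7, 1, MINIBATCH_TRAILING))
```

Only the relative positions matter here. Passing the pre-commit scalar list reproduces the removed offsets: numRows at N + 6 + sizeOfModel for mini-batch and N + 6 + 2*sizeOfModel for IGD.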
http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/modules/convex/type/tuple.hpp
----------------------------------------------------------------------
diff --git a/src/modules/convex/type/tuple.hpp b/src/modules/convex/type/tuple.hpp
index 7354bb9..e7c1187 100644
--- a/src/modules/convex/type/tuple.hpp
+++ b/src/modules/convex/type/tuple.hpp
@@ -64,7 +64,7 @@ typedef ExampleTuple<MappedColumnVector, double> GLMTuple;
 // madlib::modules::convex::MatrixIndex
 typedef ExampleTuple<MatrixIndex, double> LMFTuple;
 
-typedef ExampleTuple<ColumnVector, MappedColumnVector> MLPTuple;
+typedef ExampleTuple<ColumnVector, ColumnVector> MLPTuple;
 typedef ExampleTuple<Matrix, Matrix> MiniBatchTuple;
 
 } // namespace convex

http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/ports/postgres/modules/convex/mlp.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/mlp.sql_in b/src/ports/postgres/modules/convex/mlp.sql_in
index 64ed62d..c4f8271 100644
--- a/src/ports/postgres/modules/convex/mlp.sql_in
+++ b/src/ports/postgres/modules/convex/mlp.sql_in
@@ -53,18 +53,18 @@ neural net can be interpreted as the probability that a given input
 feature belongs to a specific class.
 
 MLP can be used with or without mini-batching.
-The advantage of using mini-batching is that it 
+The advantage of using mini-batching is that it
 can perform better than stochastic gradient descent
 (default MADlib optimizer)
-because it uses more than one training example at a time, 
+because it uses more than one training example at a time,
 typically resulting faster and smoother convergence [3].
 
 @note
-In order to use mini-batching, you must first run 
-the <a href="group__grp__minibatch__preprocessing.html">Mini-Batch Preprocessor</a>, 
-which is a utility that prepares input data for 
+In order to use mini-batching, you must first run
+the <a href="group__grp__minibatch__preprocessing.html">Mini-Batch Preprocessor</a>,
+which is a utility that prepares input data for
 use by models that support mini-batch as an optimization option,
-such as MLP.  This is a one-time operation and you would only 
+such as MLP.  This is a one-time operation and you would only
 need to re-run the preprocessor if your input data has changed,
 or if you change the grouping parameter.
 
@@ -110,7 +110,7 @@ mlp_classification(
   containing the packed independent variables.
 
   @note
-  If you are not using mini-batching, 
+  If you are not using mini-batching,
   please note that an intercept variable should not be included as part
   of this expression - this is different from other MADlib modules.  Also
   please note that independent variables should be encoded properly.
@@ -123,7 +123,7 @@ mlp_classification(
   <dt>dependent_varname</dt>
   <dd> TEXT. Name of the dependent variable column. For classification, supported types are:
   text, varchar, character varying, char, character
-  integer, smallint, bigint, and boolean.  
+  integer, smallint, bigint, and boolean.
   If you are using mini-batching, set this parameter to 'dependent_varname'
   which is the hardcoded name of the column from the mini-batch preprocessor
   containing the packed dependent variables.</dd>
@@ -152,45 +152,45 @@ mlp_classification(
   <DT>weights (optional)</DT>
   <DD>TEXT, default: 1.
     Weights for input rows. Column name which specifies the weight for each input row.
-    This weight will be incorporated into the update during stochastic gradient descent (SGD),
-    but will not be used
-    for loss calculations. If not specified, weight for each row will default to 1 (equal
-    weights).  Column should be a numeric type.
+    This weight will be incorporated into the update during stochastic gradient
+    descent (SGD), but will not be used for loss calculations. If not specified,
+    weight for each row will default to 1 (equal weights).  Column should be a
+    numeric type.
     @note
     The 'weights' parameter is not currently for mini-batching.
   </DD>
 
   <DT>warm_start (optional)</DT>
   <DD>BOOLEAN, default: FALSE.
-    Initalize weights with the coefficients from the last call of the training function.
-    If set to true, weights will be initialized from the output_table generated by the
-    previous run. Note that all parameters
-    other than optimizer_params and verbose must remain constant
-    between calls when warm_start is used.
+    Initialize weights with the coefficients from the last call of the training
+    function. If set to true, weights will be initialized from the output_table
+    generated by the previous run. Note that all parameters other than
+    optimizer_params and verbose must remain constant between calls when
+    warm_start is used.
 
     @note
     The warm start feature works based on the name of the output_table.
     When using warm start, do not drop the output table or the output table summary
-    before calling the training function, since these are needed to obtain the weights
-    from the previous run.
+    before calling the training function, since these are needed to obtain the
+    weights from the previous run.
     If you are not using warm start, the output table and the output table
     summary must be dropped in the usual way before calling the training function.
+
   </DD>
 
   <DT>verbose (optional)</DT>
-  <DD>BOOLEAN, default: FALSE. Provides verbose output of the results of training, including
-  the value of loss at each iteration.</DD>
+  <DD>BOOLEAN, default: FALSE. Provides verbose output of the results of training,
+  including the value of loss at each iteration.</DD>
 
   <DT>grouping_col (optional)</DT>
   <DD>TEXT, default: NULL.
-    A single column or a list of comma-separated
-    columns that divides the input data into discrete groups, resulting in one
-    model per group. When this value is NULL, no grouping is used and
-    a single model is generated for all data.  If you are using mini-batching, 
-    you must have run the mini-batch preprocessor with exactly the same
-    groups that you specify here for MLP training.  If you change the 
-    groups, or remove the groups, then you must re-run the mini-batch 
-    preprocessor.</dd>
+    A single column or a list of comma-separated columns that divides the input
+    data into discrete groups, resulting in one model per group. When this value
+    is NULL, no grouping is used and a single model is generated for all data.
+    If you are using mini-batching, you must have run the mini-batch
+    preprocessor with exactly the same groups that you specify here for MLP
+    training.  If you change the groups, or remove the groups, then you must
+    re-run the mini-batch preprocessor.</dd>
   </DD>
 </dl>
 
@@ -247,6 +247,14 @@ A summary table named \<output_table\>_summary is also created, which has the fo
         <td>The learning rate policy as given in optimizer_params.</td>
     </tr>
     <tr>
+        <th>momentum</th>
+        <td>Momentum value as given in optimizer_params.</td>
+    </tr>
+    <tr>
+        <th>nesterov</th>
+        <td>Nesterov value as given in optimizer_params.</td>
+    </tr>
+    <tr>
         <th>n_iterations</th>
         <td>The number of iterations run.</td>
     </tr>
@@ -277,14 +285,14 @@ A summary table named \<output_table\>_summary is also created, which has the fo
     </tr>
     <tr>
         <th>grouping_col</th>
-        <td>NULL if no grouping_col was specified during training, and a comma separated
-        list of grouping column names if not.</td>
+        <td>NULL if no grouping_col was specified during training, and a
+        comma-separated list of grouping column names if not.</td>
     </tr>
 
    </table>
 
-A standardization table named \<output_table\>_standardization is also create, that has the
-following columns:
+A standardization table named \<output_table\>_standardization is also created,
+which has the following columns:
   <table class="output">
     <tr>
         <th>mean</th>
@@ -296,8 +304,8 @@ following columns:
     </tr>
     <tr>
         <th>grouping columns</th>
-        <td>If grouping_col is specified during training, a column for each grouping column
-        is created.</td>
+        <td>If grouping_col is specified during training, a column for each
+        grouping column is created.</td>
     </tr>
   </table>
 
@@ -323,8 +331,8 @@ mlp_regression(
 \b Arguments
 
 Parameters for regression are largely the same as for classification. In the
-model table, the loss refers to mean square error instead of cross entropy loss. In the
-summary table, there is no classes column. The following
+model table, the loss refers to mean square error instead of cross entropy loss.
+In the summary table, there is no classes column. The following
 arguments have specifications which differ from mlp_classification:
 <DL class="arglist">
 <DT>dependent_varname</DT>
@@ -353,7 +361,9 @@ the parameter is ignored.
    lambda = &lt;value>,
    tolerance = &lt;value>,
    batch_size = &lt;value>,
-   n_epochs = &lt;value>'
+   n_epochs = &lt;value>,
+   momentum = &lt;value>,
+   nesterov = &lt;value>'
 </pre>
 \b Optimizer \b Parameters
 <DL class="arglist">
@@ -425,6 +435,25 @@ This parameter is only used in the case of mini-batching.
 each batch is used by the optimizer.  This parameter
 is only used in the case of mini-batching.
 </DD>
+
+<DT>momentum</DT>
+<DD>Default: 0.9. Momentum can help accelerate learning and
+avoid local minima when using gradient descent. Value must be in the
+range 0 to 1, where 0 means no momentum.
+</DD>
+
+<DT>nesterov</DT>
+<DD>Default: TRUE. Nesterov momentum can provide better results than using
+classical momentum alone, due to its look-ahead characteristics. In classical
+momentum we update the velocity and then update the model with that velocity,
+whereas in the Nesterov Accelerated Gradient method we first move the model in
+the direction of the velocity, compute the gradient using this updated model,
+and then add this gradient back into the model. The main difference is that
+classical momentum computes the gradient before updating the model, whereas
+Nesterov momentum first updates the model and then computes the gradient from
+the updated position.
+</DD>
+
 </DL>
 
 @anchor predict
@@ -608,6 +637,8 @@ dependent_vartype    | character varying
 tolerance            | 0
 learning_rate_init   | 0.003
 learning_rate_policy | constant
+momentum             | 0.9
+nesterov             | t
 n_iterations         | 500
 n_tries              | 1
 layer_sizes          | {4,5,2}
@@ -771,7 +802,7 @@ SELECT madlib.mlp_predict(
 SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;
 </pre>
 <pre class="result">
- id | estimated_class_text |    attributes     |   class_text    | class |   state   
+ id | estimated_class_text |    attributes     |   class_text    | class |   state
 ----+----------------------+-------------------+-----------------+-------+-----------
   1 | Iris_setosa          | {5.0,3.2,1.2,0.2} | Iris_setosa     |     1 | Alaska
   2 | Iris_setosa          | {5.5,3.5,1.3,0.2} | Iris_setosa     |     1 | Alaska
@@ -842,7 +873,7 @@ WHERE mlp_prediction.estimated_class_text != iris_data.class_text;
 
 -# Now, use the n_tries optimizer parameter to learn and choose the best model
 among n_tries number of models learnt by the algorithm. Run only for 50 iterations
-and choose the best model from this short run. Note we are not using mini-batching 
+and choose the best model from this short run. Note we are not using mini-batching
 here.
 <pre class="example">
 DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;
@@ -1143,7 +1174,7 @@ JOIN mlp_regress_prediction USING (id);
 
 <h4>Regression with Mini-Batching</h4>
 
--# Call min-batch preprocessor using 
+-# Call mini-batch preprocessor using
 the same data set as above:
 <pre class="example">
 DROP TABLE IF EXISTS lin_housing_packed, lin_housing_packed_summary, lin_housing_packed_standardization;
@@ -1198,7 +1229,7 @@ SELECT madlib.mlp_predict(
 SELECT *, ABS(y-estimated_y) as abs_diff FROM lin_housing JOIN mlp_regress_prediction USING (id) ORDER BY id;
 </pre>
 <pre class="result">
- id |                                        x                                         | zipcode |  y   | zipcode |   estimated_y    |      abs_diff      
+ id |                                        x                                         | zipcode |  y   | zipcode |   estimated_y    |      abs_diff
 ----+----------------------------------------------------------------------------------+---------+------+---------+------------------+--------------------
   1 | {1,0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98}   |   94016 |   24 |   94016 | 23.9714991250013 | 0.0285008749987092
   2 | {1,0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14}    |   94016 | 21.6 |   94016 | 22.3655180133895 |  0.765518013389535
@@ -1224,11 +1255,11 @@ SELECT *, ABS(y-estimated_y) as abs_diff FROM lin_housing JOIN mlp_regress_predi
 </pre>
 RMS error:
 <pre class="example">
-SELECT SQRT(SUM(ABS(y-estimated_y))/COUNT(y)) as rms_error FROM lin_housing 
+SELECT SQRT(SUM(ABS(y-estimated_y))/COUNT(y)) as rms_error FROM lin_housing
 JOIN mlp_regress_prediction USING (id);
 </pre>
 <pre class="result">
-     rms_error     
+     rms_error
 -------------------+
  0.912158035902468
 (1 row)
@@ -1303,7 +1334,7 @@ SELECT madlib.mlp_predict(
 SELECT * FROM lin_housing JOIN mlp_regress_prediction USING (zipcode, id) ORDER BY zipcode, id;
 </pre>
 <pre class="result">
- zipcode | id |                                        x                                         |  y   |   estimated_y    
+ zipcode | id |                                        x                                         |  y   |   estimated_y
 ---------+----+----------------------------------------------------------------------------------+------+------------------
    20001 | 12 | {1,0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27}  | 18.9 | 19.2272848285357
    20001 | 13 | {1,0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71}  | 21.7 | 21.3979318641202
@@ -1353,7 +1384,7 @@ For details on backpropogation, see [2].
 University of Wisconsin Madison: Computer-Aided Engineering. Web. 12 July 2017,
 http://homepages.cae.wisc.edu/~ece539/videocourse/notes/pdf/lec%2011%20MLP%20(3)%20BP.pdf
 
-[3] "Neural Networks for Machine Learning", Lectures 6a and 6b on mini-batch gradient descent,  
+[3] "Neural Networks for Machine Learning", Lectures 6a and 6b on mini-batch gradient descent,
 Geoffrey Hinton with Nitish Srivastava and Kevin Swersky,
 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
 
@@ -1383,7 +1414,9 @@ CREATE FUNCTION MADLIB_SCHEMA.mlp_igd_transition(
         is_classification  INTEGER,
         weight             DOUBLE PRECISION,
         warm_start_coeff   DOUBLE PRECISION[],
-        lambda             DOUBLE PRECISION
+        lambda             DOUBLE PRECISION,
+        momentum           DOUBLE PRECISION,
+        is_nesterov        BOOLEAN
     )
 RETURNS DOUBLE PRECISION[]
 AS 'MODULE_PATHNAME'
@@ -1402,7 +1435,9 @@ CREATE FUNCTION MADLIB_SCHEMA.mlp_minibatch_transition(
         warm_start_coeff   DOUBLE PRECISION[],
         lambda             DOUBLE PRECISION,
         batch_size         INTEGER,
-        n_epochs           INTEGER
+        n_epochs           INTEGER,
+        momentum           DOUBLE PRECISION,
+        is_nesterov        BOOLEAN
     )
 RETURNS DOUBLE PRECISION[]
 AS 'MODULE_PATHNAME'
@@ -1448,13 +1483,15 @@ CREATE AGGREGATE MADLIB_SCHEMA.mlp_igd_step(
         /* is_classification */   INTEGER,
         /* weight */              DOUBLE PRECISION,
         /* warm_start_coeff */    DOUBLE PRECISION[],
-        /* lambda */              DOUBLE PRECISION
+        /* lambda */              DOUBLE PRECISION,
+        /* momentum */            DOUBLE PRECISION,
+        /* is_nesterov */         BOOLEAN
         )(
     STYPE=DOUBLE PRECISION[],
     SFUNC=MADLIB_SCHEMA.mlp_igd_transition,
     m4_ifdef(`__POSTGRESQL__', `', `prefunc=MADLIB_SCHEMA.mlp_igd_merge,')
     FINALFUNC=MADLIB_SCHEMA.mlp_igd_final,
-    INITCOND='{0,0,0,0,0,0,0,0}'
+    INITCOND='{0,0,0,0,0,0,0,0,0,0,0,0}'
 );
 -------------------------------------------------------------------------
 
@@ -1474,13 +1511,15 @@ CREATE AGGREGATE MADLIB_SCHEMA.mlp_minibatch_step(
         /* warm_start_coeff */    DOUBLE PRECISION[],
         /* lambda */              DOUBLE PRECISION,
         /* batch_size */          INTEGER,
-        /* n_epochs */            INTEGER
+        /* n_epochs */            INTEGER,
+        /* momentum */            DOUBLE PRECISION,
+        /* is_nesterov */         BOOLEAN
         )(
     STYPE=DOUBLE PRECISION[],
     SFUNC=MADLIB_SCHEMA.mlp_minibatch_transition,
     m4_ifdef(`__POSTGRESQL__', `', `prefunc=MADLIB_SCHEMA.mlp_minibatch_merge,')
     FINALFUNC=MADLIB_SCHEMA.mlp_minibatch_final,
-    INITCOND='{0,0,0,0,0,0,0,0,0,0,0,0}'
+    INITCOND='{0,0,0,0,0,0,0,0,0,0,0,0,0,0}'
 );
 -------------------------------------------------------------------------
 

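The `momentum` and `nesterov` descriptions above differ only in where the gradient is evaluated. The following minimal sketch is plain Python, not MADlib's C++ implementation. It uses the common formulation `v = mu * v - lr * grad(p)`, `w = w + v`:

- classical momentum takes `p = w`;
- Nesterov takes the look-ahead point `p = w + mu * v`.

The quadratic loss, learning rate (0.01) and step count (100) are illustrative assumptions.

```python
import math

# Illustrative ill-conditioned quadratic: f(w) = 0.5 * sum(c_i * w_i ** 2),
# minimized at the origin.
CURVATURE = [1.0, 10.0]


def quad_grad(w):
    """Gradient of the illustrative quadratic at point w."""
    return [c * wi for c, wi in zip(CURVATURE, w)]


def momentum_step(w, v, grad_fn, lr, mu, nesterov):
    """Apply one momentum update and return (new_w, new_v)."""
    if nesterov:
        # Nesterov: first look ahead along the velocity, then take the gradient.
        point = [wi + mu * vi for wi, vi in zip(w, v)]
    else:
        # Classical momentum: gradient at the current model.
        point = w
    g = grad_fn(point)
    # The velocity keeps a decaying memory (factor mu) of past updates.
    v = [mu * vi - lr * gi for vi, gi in zip(v, g)]
    w = [wi + vi for wi, vi in zip(w, v)]
    return w, v


def train(nesterov, mu, lr=0.01, steps=100):
    """Run momentum gradient descent from w = (1, 1) and return the final w."""
    w, v = [1.0, 1.0], [0.0, 0.0]
    for _ in range(steps):
        w, v = momentum_step(w, v, quad_grad, lr, mu, nesterov)
    return w


def distance_to_optimum(w):
    """Euclidean distance from w to the minimizer at the origin."""
    return math.sqrt(sum(wi * wi for wi in w))


if __name__ == "__main__":
    for mu, nesterov in [(0.0, False), (0.9, False), (0.9, True)]:
        dist = distance_to_optimum(train(nesterov, mu))
        print("momentum={}, nesterov={}: distance={:.2e}".format(mu, nesterov, dist))
```

On this toy quadratic, both momentum variants end much closer to the optimum than plain gradient descent (momentum = 0). How classical and Nesterov momentum rank against each other depends on the problem and the learning rate. With momentum = 0 the two branches give identical updates, consistent with 0 meaning "no momentum".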
http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/ports/postgres/modules/convex/mlp_igd.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/mlp_igd.py_in b/src/ports/postgres/modules/convex/mlp_igd.py_in
index 2cfa12f..1ea80a5 100644
--- a/src/ports/postgres/modules/convex/mlp_igd.py_in
+++ b/src/ports/postgres/modules/convex/mlp_igd.py_in
@@ -75,31 +75,34 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
     """
     warm_start = bool(warm_start)
     optimizer_params = _get_optimizer_params(optimizer_param_str or "")
+    summary_table = add_postfix(output_table, "_summary")
+    standardization_table = add_postfix(output_table, "_standardization")
+    hidden_layer_sizes = hidden_layer_sizes or []
+
+    _validate_args(source_table, output_table, summary_table,
+                   standardization_table, independent_varname,
+                   dependent_varname, hidden_layer_sizes, optimizer_params,
+                   warm_start, activation, grouping_col)
 
-    tolerance = optimizer_params["tolerance"]
-    n_iterations = optimizer_params["n_iterations"]
-    step_size_init = optimizer_params["learning_rate_init"]
-    iterations_per_step = optimizer_params["iterations_per_step"]
-    power = optimizer_params["power"]
-    gamma = optimizer_params["gamma"]
+    tolerance = optimizer_params['tolerance']
+    n_iterations = optimizer_params['n_iterations']
+    step_size_init = optimizer_params['learning_rate_init']
+    iterations_per_step = optimizer_params['iterations_per_step']
+    power = optimizer_params['power']
+    gamma = optimizer_params['gamma']
     step_size = step_size_init
-    n_tries = optimizer_params["n_tries"]
+    n_tries = optimizer_params['n_tries']
     # lambda is a reserved word in python
-    lambda_ = optimizer_params["lambda"]
+    lambda_ = optimizer_params['lambda']
     batch_size = optimizer_params['batch_size']
     n_epochs = optimizer_params['n_epochs']
-
-    summary_table = add_postfix(output_table, "_summary")
-    standardization_table = add_postfix(output_table, "_standardization")
-    hidden_layer_sizes = hidden_layer_sizes or []
+    momentum = optimizer_params['momentum']
+    is_nesterov = optimizer_params['nesterov']
 
     # Note that we don't support weights with mini batching yet, so validate
     # this based on is_minibatch_enabled.
     weights = '1' if not weights or not weights.strip() else weights.strip()
-    _validate_args(source_table, output_table, summary_table,
-                   standardization_table, independent_varname,
-                   dependent_varname, hidden_layer_sizes, optimizer_params,
-                   warm_start, activation, grouping_col)
+
     is_minibatch_enabled = check_if_minibatch_enabled(source_table, independent_varname)
     _validate_params_based_on_minibatch(source_table, independent_varname,
                                         dependent_varname, weights,
@@ -107,7 +110,7 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
                                         is_minibatch_enabled)
     activation = _get_activation_function_name(activation)
     learning_rate_policy = _get_learning_rate_policy_name(
-                                optimizer_params["learning_rate_policy"])
+        optimizer_params["learning_rate_policy"])
     activation_index = _get_activation_index(activation)
 
     # The original dependent_varname is required later if warm start is
@@ -251,7 +254,9 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
         "x_mean_table": x_mean_table,
         "batch_size": batch_size,
         "n_epochs": n_epochs,
-        "start_coeff": start_coeff
+        "start_coeff": start_coeff,
+        "momentum": momentum,
+        "is_nesterov": is_nesterov
     }
     # variables to be used by GroupIterationController
     it_args.update({
@@ -305,7 +310,9 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
                             ({start_coeff})::DOUBLE PRECISION[],
                             {lambda_},
                             {batch_size}::integer,
-                            {n_epochs}::integer
+                            {n_epochs}::integer,
+                            {momentum}::FLOAT8,
+                            {is_nesterov}::boolean
                         )
                         """
                 else:
@@ -320,7 +327,9 @@ def mlp(schema_madlib, source_table, output_table, independent_varname,
                             {is_classification},
                             ({weights})::DOUBLE PRECISION,
                             ({start_coeff})::DOUBLE PRECISION[],
-                            {lambda_}
+                            {lambda_},
+                            {momentum}::FLOAT8,
+                            {is_nesterov}::boolean
                         )
                         """
                 it.update(train_sql)
@@ -498,6 +507,8 @@ def _create_summary_table(args):
             tolerance FLOAT,
             learning_rate_init FLOAT,
             learning_rate_policy TEXT,
+            momentum FLOAT,
+            nesterov BOOLEAN,
             n_iterations INTEGER,
             n_tries INTEGER,
             layer_sizes INTEGER[],
@@ -520,6 +531,8 @@ def _create_summary_table(args):
             {tolerance},
             {step_size_init},
             '{learning_rate_policy}',
+            {momentum},
+            {is_nesterov},
             {n_iterations},
             {n_tries},
             {layer_sizes_str},
@@ -616,7 +629,9 @@ def _get_optimizer_params(param_str):
         "power": (0.5, float),
         "lambda": (0, float),
         "n_epochs": (1, int),
-        "batch_size": (1, int)
+        "batch_size": (1, int),
+        "momentum": (0.9, float),
+        "nesterov": (True, bool)
     }
     param_defaults = dict([(k, v[0]) for k, v in params_defaults.items()])
     param_types = dict([(k, v[1]) for k, v in params_defaults.items()])
@@ -794,6 +809,8 @@ def _validate_args(source_table, output_table, summary_table,
             "MLP error: batch_size should be greater than 0.")
     _assert(optimizer_params["n_epochs"] > 0,
             "MLP error: n_epochs should be greater than 0.")
+    _assert(0 <= optimizer_params["momentum"] <= 1,
+            "MLP error: momentum should be in the range [0, 1].")
 
     if grouping_col:
         cols_in_tbl_valid(source_table,
@@ -801,6 +818,7 @@ def _validate_args(source_table, output_table, summary_table,
                           'MLP',
                           invalid_names=[independent_varname, dependent_varname])
 
+
 def _get_learning_rate_policy_name(learning_rate_policy):
     if not learning_rate_policy:
         learning_rate_policy = 'constant'
@@ -1199,6 +1217,8 @@ def mlp_help(schema_madlib, message, is_classification):
     tolerance            -- The tolerance as given in optimizer_params.
     learning_rate_init   -- The initial learning rate as given in optimizer_params.
     learning_rate_policy -- The learning rate policy as given in optimizer_params.
+    momentum             -- Momentum value as given in optimizer_params.
+    nesterov             -- Nesterov value as given in optimizer_params.
     n_iterations         -- The number of iterations run.
     n_tries              -- The number of tries as given in optimizer_params.
     layer_sizes          -- The number of units in each layer including the input
@@ -1571,6 +1591,20 @@ def mlp_help(schema_madlib, message, is_classification):
                                             uses mini-batch gradient descent. During gradient
                                             descent, n_epochs represents the number of times
                                             all batches in a buffer are iterated over.
+    momentum                            --  Default: 0.9. Momentum can help accelerate
+                                            learning and avoid local minima when
+                                            using gradient descent. Value must be in the
+                                            range 0 to 1, where 0 means no momentum.
+    nesterov                            --  Default: TRUE. Nesterov momentum can provide
+                                            better results than using classical momentum alone,
+                                            due to its look-ahead characteristics. In Nesterov
+                                            momentum, we first move the model in the direction of
+                                            the velocity and use the updated model to calculate
+                                            the gradient. The main difference is that in
+                                            classical momentum, we compute the gradient before
+                                            updating the model, whereas in Nesterov we first update
+                                            the model and then compute the gradient from the
+                                            updated position.
     """.format(**args)
 
     if not message:
@@ -1753,6 +1787,7 @@ class MLPMinibatchPreProcessor:
     """
     # summary table columns names
     DEPENDENT_VARNAME = "dependent_varname"
+    DEPENDENT_VARTYPE = "dependent_vartype"
     INDEPENDENT_VARNAME = "independent_varname"
     GROUPING_COL = "grouping_cols"
     CLASS_VALUES = "class_values"
@@ -1781,7 +1816,7 @@ class MLPMinibatchPreProcessor:
             summary_table_columns = summary_table_columns[0]
 
         required_columns = (self.DEPENDENT_VARNAME, self.INDEPENDENT_VARNAME,
-                            self.CLASS_VALUES, self.GROUPING_COL)
+                            self.CLASS_VALUES, self.GROUPING_COL, self.DEPENDENT_VARTYPE)
         if set(required_columns) <= set(summary_table_columns):
             self.preprocessed_summary_dict = summary_table_columns
         else:

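This commit adds two entries to `_get_optimizer_params`: `momentum` with default 0.9 and `nesterov` with default True. `_validate_args` then rejects any momentum outside [0, 1].

Below is a hypothetical sketch of that parse-then-validate flow for just these two keys. It does not reproduce MADlib's parser, and `parse_momentum_params` and `_to_bool` are illustrative names.

One detail is worth handling explicitly. In Python, `bool("False")` is `True`, so a setting such as `nesterov = False` needs an explicit string-to-boolean conversion rather than a plain `bool(...)` cast.

```python
# Hypothetical sketch: parse and validate only the momentum-related keys of an
# optimizer_params string. Defaults and the [0, 1] range mirror the diff;
# the parsing rules themselves are illustrative.

# key -> (default value, target type)
MOMENTUM_DEFAULTS = {"momentum": (0.9, float), "nesterov": (True, bool)}


def _to_bool(text):
    """Convert 'true'/'false' style text to bool (bool('False') would be True)."""
    lowered = text.strip().lower()
    if lowered in ("true", "t"):
        return True
    if lowered in ("false", "f"):
        return False
    raise ValueError("not a boolean value: %r" % text)


def parse_momentum_params(param_str):
    """Return {'momentum': float, 'nesterov': bool} with defaults applied."""
    params = {key: default for key, (default, _) in MOMENTUM_DEFAULTS.items()}
    for item in param_str.split(","):
        key, sep, value = item.partition("=")
        key = key.strip().lower()
        if not sep or key not in MOMENTUM_DEFAULTS:
            # Empty items and other optimizer parameters are ignored here.
            continue
        target_type = MOMENTUM_DEFAULTS[key][1]
        if target_type is bool:
            params[key] = _to_bool(value)
        else:
            params[key] = target_type(value.strip())
    if not 0 <= params["momentum"] <= 1:
        raise ValueError("MLP error: momentum should be in the range [0, 1].")
    return params


if __name__ == "__main__":
    print(parse_momentum_params("learning_rate_init=0.015,\n momentum = 0.5, nesterov = False"))
```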
http://git-wip-us.apache.org/repos/asf/madlib/blob/8bf15fde/src/ports/postgres/modules/convex/test/mlp.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/convex/test/mlp.sql_in b/src/ports/postgres/modules/convex/test/mlp.sql_in
index 16d1637..24e3feb 100644
--- a/src/ports/postgres/modules/convex/test/mlp.sql_in
+++ b/src/ports/postgres/modules/convex/test/mlp.sql_in
@@ -1033,6 +1033,143 @@ SELECT mlp_predict(
     'mlp_prediction_regress_batch',
     'output');
 
+
+------------------------------------------------ Momentum ------------------------------------------------------------
+
+-- regression momentum without nesterov momentum
+DROP TABLE IF EXISTS mlp_regress, mlp_regress_summary, mlp_regress_standardization;
+SELECT mlp_regression(
+    'lin_housing_wi',           -- Source table
+    'mlp_regress',              -- Destination table
+    'x',                        -- Input features
+    'y',                        -- Dependent variable
+    ARRAY[40],                 -- Number of units per layer
+    'learning_rate_init=0.015,
+    momentum = 0.5, nesterov = False,
+    learning_rate_policy=inv,
+    n_iterations=5, tolerance=0', 'sigmoid',
+    '',
+    False,
+    False,
+    'grp');
+DROP TABLE IF EXISTS mlp_prediction_regress;
+SELECT mlp_predict(
+    'mlp_regress',
+    'lin_housing_wi',
+    'id',
+    'mlp_prediction_regress',
+    'output');
+
+-- assert that the summary table has momentum and nesterov
+SELECT assert
+        (
+        source_table         = 'lin_housing_wi' AND
+        independent_varname  = 'x' AND
+        dependent_varname    = 'y' AND
+        dependent_vartype    = 'double precision[]'  AND
+        tolerance            = 0  AND
+        learning_rate_init   = 0.015  AND
+        learning_rate_policy = 'inv'  AND
+        momentum             = 0.5 AND
+        nesterov             = False AND
+        n_iterations         = 5  AND
+        n_tries              = 1  AND
+        layer_sizes          = '{14,40,1}'  AND
+        activation           = 'sigmoid'  AND
+        is_classification    = False  AND
+        classes              = '{}'  AND
+        weights              = '1'  AND
+        grouping_col         = 'grp',
+        'Summary Validation failed. Actual:' || __to_char(summary)
+        ) from (select * from mlp_regress_summary) summary;
+
+-- regression momentum with nesterov momentum
+DROP TABLE IF EXISTS mlp_regress_batch, mlp_regress_batch_summary, mlp_regress_batch_standardization;
+SELECT mlp_regression(
+    'lin_housing_wi_batch',           -- Source table
+    'mlp_regress_batch',              -- Destination table
+    'independent_varname',                        -- Input features
+    'dependent_varname',                        -- Dependent variable
+    ARRAY[10],                 -- Number of units per layer
+    'learning_rate_init=0.025,
+    momentum = 0.5, nesterov = True,
+    learning_rate_policy=step,
+    lambda=0.001,
+    n_iterations=5,
+    tolerance=0,
+    batch_size=25, n_epochs=10',
+    'sigmoid',
+    '',
+    False,
+    TRUE);
+DROP TABLE IF EXISTS mlp_prediction_regress_batch;
+SELECT mlp_predict(
+    'mlp_regress_batch',
+    'lin_housing_wi',
+    'id',
+    'mlp_prediction_regress_batch',
+    'output');
+
+-- classification momentum without nesterov momentum
+DROP TABLE IF EXISTS mlp_class_batch, mlp_class_batch_summary, mlp_class_batch_standardization;
+SELECT mlp_classification(
+    'iris_data_batch',    -- Source table
+    'mlp_class_batch',    -- Destination table
+    'independent_varname',   -- Input features
+    'dependent_varname',        -- Label
+    ARRAY[5],   -- Number of units per layer
+    'learning_rate_init=0.1,
+    momentum = 0.1, nesterov = False,
+    learning_rate_policy=constant,
+    n_iterations=5,
+    n_tries=3,
+    tolerance=0,
+    n_epochs=20',
+    'sigmoid',
+    '',
+    False,
+    False
+);
+DROP TABLE IF EXISTS mlp_prediction_batch_output, mlp_prediction_output;
+-- See prediction accuracy for training data
+SELECT mlp_predict(
+    'mlp_class_batch',
+    'iris_data',
+    'id',
+    'mlp_prediction_batch_output',
+    'output');
+
+-- classification momentum with nesterov momentum
+DROP TABLE IF EXISTS mlp_class_batch, mlp_class_batch_summary, mlp_class_batch_standardization;
+SELECT mlp_classification(
+    'iris_data_batch_grp',    -- Source table
+    'mlp_class_batch',    -- Destination table
+    'independent_varname',   -- Input features
+    'dependent_varname',        -- Label
+    ARRAY[5],   -- Number of units per layer
+    'learning_rate_init=0.1,
+    momentum = 0.1, nesterov = True,
+    learning_rate_policy=constant,
+    n_iterations=5,
+    n_tries=3,
+    tolerance=0,
+    n_epochs=20',
+    'sigmoid',
+    '',
+    False,
+    False,
+    'grp'
+);
+
+DROP TABLE IF EXISTS mlp_prediction_batch_output, mlp_prediction_output;
+-- See prediction accuracy for training data
+SELECT mlp_predict(
+    'mlp_class_batch',
+    'iris_data',
+    'id',
+    'mlp_prediction_batch_output',
+    'output');
+
 -- Assert of all input tables still exist, to make sure we have not dropped
 -- anything in the code.
 -- Classification minibatch tables