Posted to commits@flink.apache.org by tr...@apache.org on 2015/05/08 14:15:04 UTC

flink git commit: [FLINK-1807] [ml] Adds optimization framework and SGD solver.

Repository: flink
Updated Branches:
  refs/heads/master f0b45c450 -> 2939fba3f


[FLINK-1807] [ml] Adds optimization framework and SGD solver.

Added Stochastic Gradient Descent initial version and some tests.

Added L1, L2 regularization.

Added tests for regularization, fixed parameter setting.

Added documentation.

Added option to provide UDF for the prediction function, moved SGD regularization to update step.

Added prediction function class to allow non-linear optimization in the future.

Small refactoring to allow calculation of the regularized loss separately from the regularized gradient.

This closes #613.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/2939fba3
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/2939fba3
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/2939fba3

Branch: refs/heads/master
Commit: 2939fba3fbe594bbee6babdb7c223a93c45eef64
Parents: f0b45c4
Author: Theodore Vasiloudis <tv...@sics.se>
Authored: Tue Apr 21 10:59:34 2015 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Fri May 8 12:04:50 2015 +0200

----------------------------------------------------------------------
 docs/libs/ml/index.md                           |   2 +
 docs/libs/ml/optimization.md                    | 233 +++++++++++++++
 flink-staging/flink-ml/pom.xml                  |   6 +-
 .../apache/flink/ml/common/WeightVector.scala   |  32 ++
 .../scala/org/apache/flink/ml/math/BLAS.scala   | 291 +++++++++++++++++++
 .../flink/ml/optimization/GradientDescent.scala | 243 ++++++++++++++++
 .../flink/ml/optimization/LossFunction.scala    | 119 ++++++++
 .../ml/optimization/PredictionFunction.scala    |  38 +++
 .../flink/ml/optimization/Regularization.scala  | 198 +++++++++++++
 .../apache/flink/ml/optimization/Solver.scala   | 135 +++++++++
 .../optimization/GradientDescentITSuite.scala   | 210 +++++++++++++
 .../ml/optimization/LossFunctionITSuite.scala   |  55 ++++
 .../PredictionFunctionITSuite.scala             |  62 ++++
 .../ml/optimization/RegularizationITSuite.scala | 115 ++++++++
 .../flink/ml/regression/RegressionData.scala    |  62 ++++
 15 files changed, 1798 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/docs/libs/ml/index.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/index.md b/docs/libs/ml/index.md
index 0754045..d36ce20 100644
--- a/docs/libs/ml/index.md
+++ b/docs/libs/ml/index.md
@@ -37,6 +37,8 @@ under the License.
 * [Multiple linear regression](multiple_linear_regression.html)
 * [Polynomial Base Feature Mapper](polynomial_base_feature_mapper.html)
 * [Standard Scaler](standard_scaler.html)
+* [Optimization Framework](optimization.html)
+
 
 ## Metrics
 

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/docs/libs/ml/optimization.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/optimization.md b/docs/libs/ml/optimization.md
new file mode 100644
index 0000000..5d1f3a7
--- /dev/null
+++ b/docs/libs/ml/optimization.md
@@ -0,0 +1,233 @@
+---
+mathjax: include
+title: "ML - Optimization"
+displayTitle: <a href="index.md">ML</a> - Optimization
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* Table of contents
+{:toc}
+
+$$
+\newcommand{\R}{\mathbb{R}}
+\newcommand{\E}{\mathbb{E}} 
+\newcommand{\x}{\mathbf{x}}
+\newcommand{\y}{\mathbf{y}}
+\newcommand{\wv}{\mathbf{w}}
+\newcommand{\av}{\mathbf{\alpha}}
+\newcommand{\bv}{\mathbf{b}}
+\newcommand{\N}{\mathbb{N}}
+\newcommand{\id}{\mathbf{I}}
+\newcommand{\ind}{\mathbf{1}} 
+\newcommand{\0}{\mathbf{0}} 
+\newcommand{\unit}{\mathbf{e}} 
+\newcommand{\one}{\mathbf{1}} 
+\newcommand{\zero}{\mathbf{0}}
+$$
+
+## Mathematical Formulation
+
+The optimization framework in Flink is a developer-oriented package that can be used to solve
+[optimization](https://en.wikipedia.org/wiki/Mathematical_optimization) 
+problems common in Machine Learning (ML) tasks. In the supervised learning context, this usually
+involves finding a model, defined by a set of parameters $\wv$, that minimizes a function $f(\wv)$
+given a set of $(\x, y)$ examples, where $\x$ is a feature vector and $y$ is the target, which can
+be a real value in the regression case or a class label in the classification case. In supervised
+learning, the function to be minimized is usually of the form:
+
+$$
+\begin{equation}
+    f(\wv) := 
+    \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) +
+    \lambda\, R(\wv)
+    \label{eq:objectiveFunc}
+    \ .
+\end{equation}
+$$
+
+where $L$ is the loss function and $R(\wv)$ is the regularization penalty. We use $L$ to measure how
+well the model fits the observed data, and we use $R$ in order to impose a complexity cost on the
+model, with $\lambda > 0$ being the regularization parameter.
+
+### Loss Functions
+
+In supervised learning, we use loss functions in order to measure the model fit, by
+penalizing errors in the predictions $p$ made by the model compared to the true $y$ for each
+example. Different loss functions can be used for regression (e.g. Squared Loss) and classification
+(e.g. Hinge Loss).
+
+Some common loss functions are:
+ 
+* Squared Loss: $ \frac{1}{2} (\wv^T \x - y)^2, \quad y \in \R $ 
+* Hinge Loss: $ \max (0, 1 - y ~ \wv^T \x), \quad y \in \{-1, +1\} $
+* Logistic Loss: $ \log(1+\exp( -y ~ \wv^T \x)), \quad y \in \{-1, +1\} $
+
+Currently, only the Squared Loss function is implemented in Flink.
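+
+As a concrete example, with weights $\wv = (1, 2)$, features $\x = (3, 1)$ and target $y = 4$, the
+prediction is $\wv^T \x = 5$ and the Squared Loss is $\frac{1}{2}(5 - 4)^2 = 0.5$.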
+
+### Regularization Types
+
+[Regularization](https://en.wikipedia.org/wiki/Regularization_(mathematics)) in machine learning
+imposes penalties on the estimated models, in order to reduce overfitting. The most common penalties
+are the $L_1$ and $L_2$ penalties, defined as:
+
+* $L_1$: $R(\wv) = \|\wv\|_1$
+* $L_2$: $R(\wv) = \frac{1}{2}\|\wv\|_2^2$
+
+The $L_2$ penalty penalizes large weights, favoring solutions with more small weights rather than
+few large ones.
+The $L_1$ penalty can be used to drive a number of the solution coefficients to 0, thereby
+producing sparse solutions.
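+
+For example, for a weight vector $\wv = (3, -1, 0)$ the $L_1$ penalty is
+$\|\wv\|_1 = 3 + 1 + 0 = 4$, while the $L_2$ penalty is
+$\frac{1}{2}\|\wv\|_2^2 = \frac{1}{2}(9 + 1 + 0) = 5$.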
+The optimization framework in Flink supports the $L_1$ and $L_2$ penalties, as well as no
+regularization. The regularization parameter $\lambda$ in $\eqref{eq:objectiveFunc}$ determines the
+amount of regularization applied to the model, and is usually determined through model
+cross-validation.
+
+## Stochastic Gradient Descent
+
+In order to find a (local) minimum of a function, Gradient Descent methods take steps in the
+direction opposite to the gradient of the function $\eqref{eq:objectiveFunc}$ taken with
+respect to the current parameters (weights).
+In order to compute the exact gradient we need to perform one pass through all the points in
+a dataset, making the process computationally expensive.
+An alternative is Stochastic Gradient Descent (SGD), where at each iteration we sample one point
+from the complete dataset and update the parameters using that single point, in an online manner.
+
+In mini-batch SGD we instead sample random subsets of the dataset, and compute the gradient
+over each batch. At each iteration of the algorithm we update the weights once, based on
+the average of the gradients computed from each mini-batch.
+
+An important parameter is the learning rate $\eta$, or step size, which is currently determined as
+$\eta = \frac{\eta_0}{\sqrt{j}}$, where $\eta_0$ is the initial step size and $j$ is the iteration 
+number. The setting of the initial step size can significantly affect the performance of the 
+algorithm. For some practical tips on tuning SGD see Léon Bottou's
+"[Stochastic Gradient Descent Tricks](http://research.microsoft.com/pubs/192769/tricks-2012.pdf)".
+
+The current implementation of SGD uses the whole partition, making it
+effectively a batch gradient descent. Once a sampling operator has been introduced in Flink, true
+mini-batch SGD will be performed.
+
+
+### Parameters
+
+  The stochastic gradient descent implementation can be controlled by the following parameters:
+  
+   <table class="table table-bordered">
+    <thead>
+      <tr>
+        <th class="text-left" style="width: 20%">Parameter</th>
+        <th class="text-center">Description</th>
+      </tr>
+    </thead>
+    <tbody>
+      <tr>
+        <td><strong>Loss Function</strong></td>
+        <td>
+          <p>
+            The class of the loss function to be used. (Default value: 
+            <strong>SquaredLoss</strong>, used for regression tasks)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>RegularizationType</strong></td>
+        <td>
+          <p>
+            The type of regularization penalty to apply. (Default value: 
+            <strong>NoRegularization</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>RegularizationParameter</strong></td>
+        <td>
+          <p>
+            The amount of regularization to apply. (Default value: <strong>0</strong>)
+          </p>
+        </td>
+      </tr>     
+      <tr>
+        <td><strong>Iterations</strong></td>
+        <td>
+          <p>
+            The maximum number of iterations. (Default value: <strong>10</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Stepsize</strong></td>
+        <td>
+          <p>
+            Initial step size for the gradient descent method.
+            This value controls how far the gradient descent method moves in the opposite direction 
+            of the gradient.
+            (Default value: <strong>0.1</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Prediction Function</strong></td>
+        <td>
+          <p>
+            Class that provides the prediction function, used to calculate $\hat{y}$ based on the 
+            weights and the example features, and the prediction gradient.
+            (Default value: <strong>LinearPrediction</strong>)
+          </p>
+        </td>
+      </tr>
+    </tbody>
+  </table>
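+
+Besides the setter methods shown in the example below, these parameters can also be supplied as a
+`ParameterMap` when constructing the solver; this is how the test suites in this change configure
+`GradientDescent`. A minimal sketch:
+
+{% highlight scala %}
+val parameters = ParameterMap()
+
+parameters.add(IterativeSolver.Stepsize, 0.01)
+parameters.add(IterativeSolver.Iterations, 2000)
+parameters.add(Solver.LossFunction, new SquaredLoss)
+parameters.add(Solver.RegularizationType, new L1Regularization)
+parameters.add(Solver.RegularizationParameter, 0.3)
+
+val sgd = GradientDescent(parameters)
+{% endhighlight %}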
+
+### Examples
+
+In the Flink implementation of SGD, given a set of examples in a `DataSet[LabeledVector]` and
+optionally some initial weights, we can use `GradientDescent.optimize()` in order to optimize
+the weights for the given data.
+
+The user can provide an initial `DataSet[WeightVector]`,
+which contains one `WeightVector` element, or use the default weights which are all set to 0.
+A `WeightVector` is a container class for the weights, which separates the intercept from the
+weight vector. This allows us to avoid applying regularization to the intercept.
+
+
+
+{% highlight scala %}
+// Create the stochastic gradient descent solver
+val sgd = GradientDescent()
+  .setIterations(100)
+  .setStepsize(0.01)
+  .setLossFunction(new SquaredLoss)
+  .setRegularizationType(new L1Regularization)
+  .setRegularizationParameter(0.2)
+
+// Obtain data
+val trainingDS: DataSet[LabeledVector] = ...
+
+// Fit the solver to the provided data, using initial weights set to 0.0
+val weightDS = sgd.optimize(trainingDS, None)
+
+// Retrieve the optimized weights (weightDS contains a single WeightVector element)
+val weightVector = weightDS.collect().head
+
+// We can now use the weightVector to make predictions
+
+{% endhighlight %}
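+
+As a sketch of how the optimized weights could then be used for prediction with the
+`LinearPrediction` function added in this change (the feature values below are made up):
+
+{% highlight scala %}
+// The prediction is the inner product of the weights with the features, plus the intercept
+val predictionFunction = new LinearPrediction
+val newFeatures = DenseVector(Array(1.0, 2.0, 3.0))
+
+val prediction = predictionFunction.predict(newFeatures, weightVector)
+{% endhighlight %}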

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/pom.xml
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/pom.xml b/flink-staging/flink-ml/pom.xml
index db3cf48..bc61e2e 100644
--- a/flink-staging/flink-ml/pom.xml
+++ b/flink-staging/flink-ml/pom.xml
@@ -27,7 +27,7 @@
 		<version>0.9-SNAPSHOT</version>
 		<relativePath>..</relativePath>
 	</parent>
-	
+
 	<artifactId>flink-ml</artifactId>
 	<name>flink-ml</name>
 
@@ -42,8 +42,8 @@
 
 		<dependency>
 			<groupId>org.scalanlp</groupId>
-			<artifactId>breeze_2.10</artifactId>
-			<version>0.11.1</version>
+			<artifactId>breeze_${scala.binary.version}</artifactId>
+			<version>0.11.2</version>
 		</dependency>
 
 		<dependency>

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
new file mode 100644
index 0000000..247d92e
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common
+
+import org.apache.flink.ml.math.Vector
+
+// TODO(tvas): This provides an abstraction for the weights
+// but at the same time it leads to the creation of many objects as we have to pack and unpack
+// the weights and the intercept often during SGD.
+
+/** This class represents a weight vector with an intercept, as it is required for many supervised
+  * learning tasks
+  * @param weights The vector of weights
+  * @param intercept The intercept (bias) weight
+  */
+case class WeightVector(weights: Vector, var intercept: Double) extends Serializable {}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
new file mode 100644
index 0000000..8ea3b65
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
+import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
+
+/**
+ * BLAS routines for vectors and matrices.
+ *
+ * Original code from the Apache Spark project:
+ * http://git.io/vfZUe
+ */
+object BLAS extends Serializable {
+
+  @transient private var _f2jBLAS: NetlibBLAS = _
+  @transient private var _nativeBLAS: NetlibBLAS = _
+
+  // For level-1 routines, we use Java implementation.
+  private def f2jBLAS: NetlibBLAS = {
+    if (_f2jBLAS == null) {
+      _f2jBLAS = new F2jBLAS
+    }
+    _f2jBLAS
+  }
+
+  /**
+   * y += a * x
+   */
+  def axpy(a: Double, x: Vector, y: Vector): Unit = {
+    require(x.size == y.size)
+    y match {
+      case dy: DenseVector =>
+        x match {
+          case sx: SparseVector =>
+            axpy(a, sx, dy)
+          case dx: DenseVector =>
+            axpy(a, dx, dy)
+          case _ =>
+            throw new UnsupportedOperationException(
+              s"axpy doesn't support x type ${x.getClass}.")
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
+    }
+  }
+
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = {
+    val n = x.size
+    f2jBLAS.daxpy(n, a, x.data, 1, y.data, 1)
+  }
+
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
+    val xValues = x.data
+    val xIndices = x.indices
+    val yValues = y.data
+    val nnz = xIndices.size
+
+    if (a == 1.0) {
+      var k = 0
+      while (k < nnz) {
+        yValues(xIndices(k)) += xValues(k)
+        k += 1
+      }
+    } else {
+      var k = 0
+      while (k < nnz) {
+        yValues(xIndices(k)) += a * xValues(k)
+        k += 1
+      }
+    }
+  }
+
+  /**
+   * dot(x, y)
+   */
+  def dot(x: Vector, y: Vector): Double = {
+    require(x.size == y.size,
+      "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
+        " x.size = " + x.size + ", y.size = " + y.size)
+    (x, y) match {
+      case (dx: DenseVector, dy: DenseVector) =>
+        dot(dx, dy)
+      case (sx: SparseVector, dy: DenseVector) =>
+        dot(sx, dy)
+      case (dx: DenseVector, sy: SparseVector) =>
+        dot(sy, dx)
+      case (sx: SparseVector, sy: SparseVector) =>
+        dot(sx, sy)
+      case _ =>
+        throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")
+    }
+  }
+
+  /**
+   * dot(x, y)
+   */
+  private def dot(x: DenseVector, y: DenseVector): Double = {
+    val n = x.size
+    f2jBLAS.ddot(n, x.data, 1, y.data, 1)
+  }
+
+  /**
+   * dot(x, y)
+   */
+  private def dot(x: SparseVector, y: DenseVector): Double = {
+    val xValues = x.data
+    val xIndices = x.indices
+    val yValues = y.data
+    val nnz = xIndices.size
+
+    var sum = 0.0
+    var k = 0
+    while (k < nnz) {
+      sum += xValues(k) * yValues(xIndices(k))
+      k += 1
+    }
+    sum
+  }
+
+  /**
+   * dot(x, y)
+   */
+  private def dot(x: SparseVector, y: SparseVector): Double = {
+    val xValues = x.data
+    val xIndices = x.indices
+    val yValues = y.data
+    val yIndices = y.indices
+    val nnzx = xIndices.size
+    val nnzy = yIndices.size
+
+    var kx = 0
+    var ky = 0
+    var sum = 0.0
+    // y catching x
+    while (kx < nnzx && ky < nnzy) {
+      val ix = xIndices(kx)
+      while (ky < nnzy && yIndices(ky) < ix) {
+        ky += 1
+      }
+      if (ky < nnzy && yIndices(ky) == ix) {
+        sum += xValues(kx) * yValues(ky)
+        ky += 1
+      }
+      kx += 1
+    }
+    sum
+  }
+
+  /**
+   * y = x
+   */
+  def copy(x: Vector, y: Vector): Unit = {
+    val n = y.size
+    require(x.size == n)
+    y match {
+      case dy: DenseVector =>
+        x match {
+          case sx: SparseVector =>
+            val sxIndices = sx.indices
+            val sxValues = sx.data
+            val dyValues = dy.data
+            val nnz = sxIndices.size
+
+            var i = 0
+            var k = 0
+            while (k < nnz) {
+              val j = sxIndices(k)
+              while (i < j) {
+                dyValues(i) = 0.0
+                i += 1
+              }
+              dyValues(i) = sxValues(k)
+              i += 1
+              k += 1
+            }
+            while (i < n) {
+              dyValues(i) = 0.0
+              i += 1
+            }
+          case dx: DenseVector =>
+            Array.copy(dx.data, 0, dy.data, 0, n)
+        }
+      case _ =>
+        throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}")
+    }
+  }
+
+  /**
+   * x = a * x
+   */
+  def scal(a: Double, x: Vector): Unit = {
+    x match {
+      case sx: SparseVector =>
+        f2jBLAS.dscal(sx.data.size, a, sx.data, 1)
+      case dx: DenseVector =>
+        f2jBLAS.dscal(dx.data.size, a, dx.data, 1)
+      case _ =>
+        throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
+    }
+  }
+
+  // For level-3 routines, we use the native BLAS.
+  private def nativeBLAS: NetlibBLAS = {
+    if (_nativeBLAS == null) {
+      _nativeBLAS = NativeBLAS
+    }
+    _nativeBLAS
+  }
+
+  /**
+   * A := alpha * x * x^T^ + A
+   * @param alpha a real scalar that will be multiplied to x * x^T^.
+   * @param x the vector x that contains the n elements.
+   * @param A the symmetric matrix A. Size of n x n.
+   */
+  def syr(alpha: Double, x: Vector, A: DenseMatrix) {
+    val mA = A.numRows
+    val nA = A.numCols
+    require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
+    require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+
+    x match {
+      case dv: DenseVector => syr(alpha, dv, A)
+      case sv: SparseVector => syr(alpha, sv, A)
+      case _ =>
+        throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
+    }
+  }
+
+  private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+    val nA = A.numRows
+    val mA = A.numCols
+
+    nativeBLAS.dsyr("U", x.size, alpha, x.data, 1, A.data, nA)
+
+    // Fill lower triangular part of A
+    var i = 0
+    while (i < mA) {
+      var j = i + 1
+      while (j < nA) {
+        A(j, i) = A(i, j)
+        j += 1
+      }
+      i += 1
+    }
+  }
+
+  private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+    val mA = A.numCols
+    val xIndices = x.indices
+    val xValues = x.data
+    val nnz = xValues.length
+    val Avalues = A.data
+
+    var i = 0
+    while (i < nnz) {
+      val multiplier = alpha * xValues(i)
+      val offset = xIndices(i) * mA
+      var j = 0
+      while (j < nnz) {
+        Avalues(xIndices(j) + offset) += multiplier * xValues(j)
+        j += 1
+      }
+      i += 1
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
new file mode 100644
index 0000000..eff519e
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.api.scala._
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math._
+import org.apache.flink.ml.optimization.IterativeSolver.{Iterations, Stepsize}
+import org.apache.flink.ml.optimization.Solver._
+
+/** This [[Solver]] performs Stochastic Gradient Descent optimization using mini batches
+  *
+  * For each labeled vector in a mini batch the gradient is computed and added to a partial
+  * gradient. The partial gradients are then summed and divided by the size of the batches. The
+  * average gradient is then used to update the weight values, including regularization.
+  *
+  * At the moment, the whole partition is used for SGD, making it effectively a batch gradient
+  * descent. Once a sampling operator has been introduced, the algorithm can be optimized.
+  *
+  * @param runParameters The parameters to tune the algorithm. Currently these include:
+  *                      [[Solver.LossFunction]] for the loss function to be used,
+  *                      [[Solver.RegularizationType]] for the type of regularization,
+  *                      [[Solver.RegularizationParameter]] for the regularization parameter,
+  *                      [[IterativeSolver.Iterations]] for the maximum number of iterations,
+  *                      [[IterativeSolver.Stepsize]] for the learning rate used.
+  */
+class GradientDescent(runParameters: ParameterMap) extends IterativeSolver {
+
+  import Solver.WEIGHTVECTOR_BROADCAST
+
+  var parameterMap: ParameterMap = parameters ++ runParameters
+
+  /** Performs one iteration of Stochastic Gradient Descent using mini batches
+    *
+    * @param data A Dataset of LabeledVector (label, features) pairs
+    * @param currentWeights A Dataset with the current weights to be optimized as its only element
+    * @return A Dataset containing the weights after one stochastic gradient descent step
+    */
+  private def SGDStep(data: DataSet[(LabeledVector)], currentWeights: DataSet[WeightVector]):
+  DataSet[WeightVector] = {
+
+    // TODO: Sample from input to realize proper SGD
+    data.map {
+      new GradientCalculation
+    }.withBroadcastSet(currentWeights, WEIGHTVECTOR_BROADCAST).reduce {
+      (left, right) =>
+        val (leftGradVector, leftLoss, leftCount) = left
+        val (rightGradVector, rightLoss, rightCount) = right
+        // Add the left gradient to the right one
+        BLAS.axpy(1.0, leftGradVector.weights, rightGradVector.weights)
+        val gradients = WeightVector(
+          rightGradVector.weights, leftGradVector.intercept + rightGradVector.intercept)
+
+        (gradients , leftLoss + rightLoss, leftCount + rightCount)
+    }.map {
+      new WeightsUpdate
+    }.withBroadcastSet(currentWeights, WEIGHTVECTOR_BROADCAST)
+  }
+
+  /** Provides a solution for the given optimization problem
+    *
+    * @param data A Dataset of LabeledVector (label, features) pairs
+    * @param initWeights The initial weights that will be optimized
+    * @return The weights, optimized for the provided data.
+    */
+  override def optimize(
+    data: DataSet[LabeledVector],
+    initWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] = {
+    // TODO: Faster way to do this?
+    val dimensionsDS = data.map(_.vector.size).reduce((a, b) => b)
+
+    val numberOfIterations: Int = parameterMap(Iterations)
+
+    // Initialize weights
+    val initialWeightsDS: DataSet[WeightVector] = initWeights match {
+      // Ensure provided weight vector is a DenseVector
+      case Some(wvDS) => {
+        wvDS.map{wv => {
+          val denseWeights = wv.weights match {
+            case dv: DenseVector => dv
+            case sv: SparseVector => sv.toDenseVector
+          }
+          WeightVector(denseWeights, wv.intercept)
+        }
+
+        }
+      }
+      case None => createInitialWeightVector(dimensionsDS)
+    }
+
+    // Perform the iterations
+    // TODO: Enable convergence stopping criterion, as in Multiple Linear regression
+    initialWeightsDS.iterate(numberOfIterations) {
+      weightVector => {
+        SGDStep(data, weightVector)
+      }
+    }
+  }
+
+  /** Mapping function that calculates the weight gradients from the data.
+    *
+    */
+  private class GradientCalculation extends
+    RichMapFunction[LabeledVector, (WeightVector, Double, Int)] {
+
+    var weightVector: WeightVector = null
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      val list = this.getRuntimeContext.
+        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
+
+      weightVector = list.get(0)
+    }
+
+    override def map(example: LabeledVector): (WeightVector, Double, Int) = {
+
+      val lossFunction = parameterMap(LossFunction)
+      val regType = parameterMap(RegularizationType)
+      val regParameter = parameterMap(RegularizationParameter)
+      val predictionFunction = parameterMap(PredictionFunctionParameter)
+      val dimensions = example.vector.size
+      // TODO(tvas): Any point in carrying the weightGradient vector for in-place replacement?
+      // The idea in spark is to avoid object creation, but here we have to do it anyway
+      val weightGradient = new DenseVector(new Array[Double](dimensions))
+
+      // TODO(tvas): Indentation here?
+      val (loss, lossDeriv) = lossFunction.lossAndGradient(
+                                example,
+                                weightVector,
+                                weightGradient,
+                                regType,
+                                regParameter,
+                                predictionFunction)
+
+      (new WeightVector(weightGradient, lossDeriv), loss, 1)
+    }
+  }
+
+  /** Performs the update of the weights, according to the given gradients and regularization type.
+    *
+    */
+  private class WeightsUpdate() extends
+  RichMapFunction[(WeightVector, Double, Int), WeightVector] {
+
+    var weightVector: WeightVector = null
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      val list = this.getRuntimeContext.
+        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
+
+      weightVector = list.get(0)
+    }
+
+    override def map(gradientLossAndCount: (WeightVector, Double, Int)): WeightVector = {
+      val regType = parameterMap(RegularizationType)
+      val regParameter = parameterMap(RegularizationParameter)
+      val stepsize = parameterMap(Stepsize)
+      val weightGradients = gradientLossAndCount._1
+      val lossSum = gradientLossAndCount._2
+      val count = gradientLossAndCount._3
+
+      // Scale the gradients according to batch size
+      BLAS.scal(1.0/count, weightGradients.weights)
+
+      // Calculate the regularized loss and, if the regularization is differentiable, add the
+      // regularization term to the gradient as well, in-place
+      // Note(tvas): adjustedLoss is never used currently, but I'd like to leave it here for now.
+      // We can probably maintain a loss history as the optimization package grows towards a
+      // Breeze-like interface (see breeze.optimize.FirstOrderMinimizer)
+      val adjustedLoss = {
+        regType match {
+          case x: DiffRegularization => {
+            x.regularizedLossAndGradient(
+              lossSum / count,
+              weightVector.weights,
+              weightGradients.weights,
+              regParameter)
+          }
+          case x: Regularization => {
+            x.regLoss(
+              lossSum / count,
+              weightVector.weights,
+              regParameter)
+          }
+        }
+      }
+
+      val weight0Gradient = weightGradients.intercept / count
+
+      val iteration = getIterationRuntimeContext.getSuperstepNumber
+
+      // Scale initial stepsize by the inverse square root of the iteration number
+      // TODO(tvas): There are more ways to determine the stepsize, possible low-effort extensions
+      // here
+      val effectiveStepsize = stepsize/math.sqrt(iteration)
+
+      // Take the gradient step for the intercept
+      weightVector.intercept -= effectiveStepsize * weight0Gradient
+
+      // Take the gradient step for the weight vector, possibly applying regularization
+      // TODO(tvas): This should be moved to a takeStep() function that takes regType plus all these
+      // arguments, this would decouple the update step from the regularization classes
+      regType.takeStep(weightVector.weights, weightGradients.weights,
+        effectiveStepsize, regParameter)
+
+      weightVector
+    }
+  }
+}
+
+object GradientDescent {
+  def apply(): GradientDescent = {
+    new GradientDescent(new ParameterMap())
+  }
+
+  def apply(parameterMap: ParameterMap): GradientDescent = {
+    new GradientDescent(parameterMap)
+  }
+}
+
+
+

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
new file mode 100644
index 0000000..1bb6152
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{WeightVector, LabeledVector}
+import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
+
+/** Abstract class that implements some of the functionality for common loss functions
+  *
+  * A loss function determines the loss term $L(w)$ of the objective function $f(w) = L(w) +
+  * \lambda R(w)$ for prediction tasks, the other term being the regularization penalty $R(w)$.
+  *
+  * We currently only support differentiable loss functions; in the future this class
+  * could be changed to DiffLossFunction in order to support other types, such as the absolute loss.
+  */
+abstract class LossFunction extends Serializable {
+
+  /** Calculates the loss for a given prediction/truth pair
+    *
+    * @param prediction The predicted value
+    * @param truth The true value
+    */
+  protected def loss(prediction: Double, truth: Double): Double
+
+  /** Calculates the derivative of the loss function with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param truth The true value
+    */
+  protected def lossDerivative(prediction: Double, truth: Double): Double
+
+  /** Compute the gradient and the loss for the given data.
+    * The provided cumGradient is updated in place.
+    *
+    * @param example The features and the label associated with the example
+    * @param weights The current weight vector
+    * @param cumGradient The vector to which the gradient will be added to, in place.
+    * @return A tuple containing the computed loss as its first element and the loss derivative as
+    *         its second element. The gradient is updated in-place.
+    */
+  def lossAndGradient(
+      example: LabeledVector,
+      weights: WeightVector,
+      cumGradient: FlinkVector,
+      regType: Regularization,
+      regParameter: Double,
+      predictionFunction: PredictionFunction):
+  (Double, Double) = {
+    val features = example.vector
+    val label = example.label
+    // TODO(tvas): We could also provide for the case where we don't want an intercept value
+    // i.e. data already centered
+    val prediction = predictionFunction.predict(features, weights)
+    val predictionGradient = predictionFunction.gradient(features, weights)
+    val lossValue: Double = loss(prediction, label)
+    // The loss derivative is used to update the intercept
+    val lossDeriv = lossDerivative(prediction, label)
+    // Restrict the value of the loss derivative to avoid numerical instabilities
+    val restrictedLossDeriv: Double = {
+      if (lossDeriv < -IterativeSolver.MAX_DLOSS) {
+        -IterativeSolver.MAX_DLOSS
+      }
+      else if (lossDeriv > IterativeSolver.MAX_DLOSS) {
+        IterativeSolver.MAX_DLOSS
+      }
+      else {
+        lossDeriv
+      }
+    }
+    // Update the gradient
+    BLAS.axpy(restrictedLossDeriv, predictionGradient, cumGradient)
+    (lossValue, lossDeriv)
+  }
+}
+
+trait ClassificationLoss extends LossFunction
+trait RegressionLoss extends LossFunction
+
+// TODO(tvas): Implement LogisticLoss, HingeLoss.
+
+/** Squared loss function where $L(w) = \frac{1}{2} (w^{T} x - y)^2$
+  *
+  */
+class SquaredLoss extends RegressionLoss {
+  /** Calculates the loss for a given prediction/truth pair
+    *
+    * @param prediction The predicted value
+    * @param truth The true value
+    */
+  protected override def loss(prediction: Double, truth: Double): Double = {
+    0.5 * (prediction - truth) * (prediction - truth)
+  }
+
+  /** Calculates the derivative of the loss function with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param truth The true value
+    */
+  protected override def lossDerivative(prediction: Double, truth: Double): Double = {
+    prediction - truth
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
new file mode 100644
index 0000000..91b0f39
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.WeightVector
+import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
+
+/** An abstract class for prediction functions to be used in optimization **/
+abstract class PredictionFunction extends Serializable {
+  def predict(features: FlinkVector, weights: WeightVector): Double
+
+  def gradient(features: FlinkVector, weights: WeightVector): FlinkVector
+}
+
+/** A linear prediction function **/
+class LinearPrediction extends PredictionFunction {
+  override def predict(features: FlinkVector, weightVector: WeightVector): Double = {
+    BLAS.dot(features, weightVector.weights) + weightVector.intercept
+  }
+
+  override def gradient(features: FlinkVector, weights: WeightVector): FlinkVector = {features}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
new file mode 100644
index 0000000..4ec2452
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
+
+/** Represents a type of regularization penalty
+  *
+  * Regularization penalties are used to restrict the optimization problem to solutions with
+  * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large
+  * weights for the L2 penalty.
+  *
+  * The regularization term, $R(w)$ is added to the objective function, $f(w) = L(w) + \lambda R(w)$
+  * where $\lambda$ is the regularization parameter used to tune the amount of regularization
+  * applied.
+  */
+abstract class Regularization extends Serializable {
+
+  /** Updates the weights by taking a step according to the gradient and regularization applied
+    *
+    * @param oldWeights The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param effectiveStepSize The effective step size for this iteration
+    * @param regParameter The regularization parameter, $\lambda$.
+    */
+  def takeStep(
+      oldWeights: FlinkVector,
+      gradient: FlinkVector,
+      effectiveStepSize: Double,
+      regParameter: Double) {
+    BLAS.axpy(-effectiveStepSize, gradient, oldWeights)
+  }
+
+  /** Adds the regularization term to the loss value
+    *
+    * @param loss The loss value, before applying regularization.
+    * @param weightVector The current vector of weights.
+    * @param regularizationParameter The regularization parameter, $\lambda$.
+    * @return The loss value with regularization applied.
+    */
+  def regLoss(loss: Double, weightVector: FlinkVector, regularizationParameter: Double): Double
+
+}
+
+/** Abstract class for regularization penalties that are differentiable
+  *
+  */
+abstract class DiffRegularization extends Regularization {
+
+  /** Compute the regularized gradient loss for the given data.
+    * The provided cumGradient is updated in place.
+    *
+    * @param loss The loss value without regularization.
+    * @param weightVector The current vector of weights.
+    * @param lossGradient The loss gradient, without regularization. Updated in-place.
+    * @param regParameter The regularization parameter, $\lambda$.
+    * @return The loss value with regularization applied.
+    */
+  def regularizedLossAndGradient(
+      loss: Double,
+      weightVector: FlinkVector,
+      lossGradient: FlinkVector,
+      regParameter: Double): Double = {
+    val adjustedLoss = regLoss(loss, weightVector, regParameter)
+    regGradient(weightVector, lossGradient, regParameter)
+
+    adjustedLoss
+  }
+
+  /** Adds the regularization gradient term to the loss gradient. The gradient is updated in place.
+    *
+    * @param weightVector The current vector of weights
+    * @param lossGradient The loss gradient, without regularization. Updated in-place.
+    * @param regParameter The regularization parameter, $\lambda$.
+    */
+  def regGradient(
+      weightVector: FlinkVector,
+      lossGradient: FlinkVector,
+      regParameter: Double)
+}
+
+/** Performs no regularization, equivalent to $R(w) = 0$ **/
+class NoRegularization extends Regularization {
+  /** Adds the regularization term to the loss value
+    *
+    * @param loss The loss value, before applying regularization
+    * @param weightVector The current vector of weights
+    * @param regParameter The regularization parameter, $\lambda$
+    * @return The loss value with regularization applied.
+    */
+  override def regLoss(
+    loss: Double,
+    weightVector: FlinkVector,
+    regParameter: Double):  Double = {loss}
+}
+
+/** $L_2$ regularization penalty.
+  *
+  * Penalizes large weights, favoring solutions with more small weights rather than few large ones.
+  *
+  */
+class L2Regularization extends DiffRegularization {
+
+  /** Adds the regularization term to the loss value
+    *
+    * @param loss The loss value, before applying regularization
+    * @param weightVector The current vector of weights
+    * @param regParameter The regularization parameter, $\lambda$
+    * @return The loss value with regularization applied.
+    */
+  override def regLoss(loss: Double, weightVector: FlinkVector, regParameter: Double)
+    : Double = {
+    loss + regParameter * BLAS.dot(weightVector, weightVector) / 2
+  }
+
+  /** Adds the regularization gradient term to the loss gradient. The gradient is updated in place.
+    *
+    * @param weightVector The current vector of weights.
+    * @param lossGradient The loss gradient, without regularization. Updated in-place.
+    * @param regParameter The regularization parameter, $\lambda$.
+    */
+  override def regGradient(
+      weightVector: FlinkVector,
+      lossGradient: FlinkVector,
+      regParameter: Double): Unit = {
+    BLAS.axpy(regParameter, weightVector, lossGradient)
+  }
+}
+
+/** $L_1$ regularization penalty.
+  *
+  * The $L_1$ penalty can be used to drive a number of the solution coefficients to 0, thereby
+  * producing sparse solutions.
+  *
+  */
+class L1Regularization extends Regularization {
+  /** Calculates and applies the regularization amount and the regularization parameter
+    *
+    * Implementation was taken from the Apache Spark Mllib library:
+    * http://git.io/vfZIT
+    *
+    * @param oldWeights The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param effectiveStepSize The effective step size for this iteration
+    * @param regParameter The regularization parameter to be applied in the case of L1
+    *                     regularization
+    */
+  override def takeStep(
+      oldWeights: FlinkVector,
+      gradient: FlinkVector,
+      effectiveStepSize: Double,
+      regParameter: Double) {
+    BLAS.axpy(-effectiveStepSize, gradient, oldWeights)
+
+    // Apply proximal operator (soft thresholding)
+    val shrinkageVal = regParameter * effectiveStepSize
+    var i = 0
+    while (i < oldWeights.size) {
+      val wi = oldWeights(i)
+      oldWeights(i) = math.signum(wi) * math.max(0.0, math.abs(wi) - shrinkageVal)
+      i += 1
+    }
+  }
+
+  /** Adds the regularization term to the loss value
+    *
+    * @param loss The loss value, before applying regularization.
+    * @param weightVector The current vector of weights.
+    * @param regularizationParameter The regularization parameter, $\lambda$.
+    * @return The loss value with regularization applied.
+    */
+  override def regLoss(loss: Double, weightVector: FlinkVector, regularizationParameter: Double):
+  Double = {
+    loss + l1Norm(weightVector) * regularizationParameter
+  }
+
+  // TODO(tvas): Replace once we decide on how we deal with vector ops (roll our own or use Breeze)
+  /** $L_1$ norm of a Vector **/
+  private def l1Norm(vector: FlinkVector) : Double = {
+    vector.valueIterator.fold(0.0){(a,b) => math.abs(a) + math.abs(b)}
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
new file mode 100644
index 0000000..580e096
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS, DenseVector}
+import org.apache.flink.api.scala._
+import org.apache.flink.ml.optimization.IterativeSolver._
+import org.apache.flink.ml.optimization.Solver._
+
+/** Base class for optimization algorithms
+ *
+ */
+abstract class Solver extends Serializable with WithParameters {
+
+  /** Provides a solution for the given optimization problem
+    *
+    * @param data A Dataset of LabeledVector (input, output) pairs
+    * @param initialWeights The initial weights that will be optimized
+    * @return A DataSet containing the weights, optimized for the given problem
+    */
+  def optimize(
+    data: DataSet[LabeledVector],
+    initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector]
+
+  /** Creates a DataSet with one zero vector. The zero vector has dimension d, which is given
+    * by the dimensionDS.
+    *
+    * @param dimensionDS DataSet with one element d, denoting the dimension of the returned zero
+    *                    vector
+    * @return DataSet of a zero vector of dimension d
+    */
+  def createInitialWeightVector(dimensionDS: DataSet[Int]):  DataSet[WeightVector] = {
+    dimensionDS.map {
+      dimension =>
+        val values = Array.fill(dimension)(0.0)
+        new WeightVector(DenseVector(values), 0.0)
+    }
+  }
+
+  //Setters for parameters
+  def setLossFunction(lossFunction: LossFunction): Solver = {
+    parameters.add(LossFunction, lossFunction)
+    this
+  }
+
+  def setRegularizationType(regularization: Regularization): Solver = {
+    parameters.add(RegularizationType, regularization)
+    this
+  }
+
+  def setRegularizationParameter(regularizationParameter: Double): Solver = {
+    parameters.add(RegularizationParameter, regularizationParameter)
+    this
+  }
+
+  def setPredictionFunction(predictionFunction: PredictionFunction): Solver = {
+    parameters.add(PredictionFunctionParameter, predictionFunction)
+    this
+  }
+}
+
+object Solver {
+  // TODO(tvas): Does this belong in IterativeSolver instead?
+  val WEIGHTVECTOR_BROADCAST = "weights_broadcast"
+
+  // Define parameters for Solver
+  case object LossFunction extends Parameter[LossFunction] {
+    // TODO(tvas): Should depend on problem, here is where differentiating between classification
+    // and regression could become useful
+    val defaultValue = Some(new SquaredLoss)
+  }
+
+  case object RegularizationType extends Parameter[Regularization] {
+    val defaultValue = Some(new NoRegularization)
+  }
+
+  case object RegularizationParameter extends Parameter[Double] {
+    val defaultValue = Some(0.0) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
+  }
+
+  case object PredictionFunctionParameter extends Parameter[PredictionFunction] {
+    val defaultValue = Some(new LinearPrediction)
+  }
+}
+
+/** An abstract class for iterative optimization algorithms
+  *
+  * See [[https://en.wikipedia.org/wiki/Iterative_method Iterative Methods on Wikipedia]] for more
+  * info
+  */
+abstract class IterativeSolver extends Solver {
+
+  //Setters for parameters
+  def setIterations(iterations: Int): IterativeSolver = {
+    parameters.add(Iterations, iterations)
+    this
+  }
+
+  def setStepsize(stepsize: Double): IterativeSolver = {
+    parameters.add(Stepsize, stepsize)
+    this
+  }
+}
+
+object IterativeSolver {
+
+  val MAX_DLOSS: Double = 1e12
+
+  // Define parameters for IterativeSolver
+  case object Stepsize extends Parameter[Double] {
+    val defaultValue = Some(0.1)
+  }
+
+  case object Iterations extends Parameter[Int] {
+    val defaultValue = Some(10)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
new file mode 100644
index 0000000..2734419
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{LabeledVector, WeightVector, ParameterMap}
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.regression.RegressionData._
+import org.scalatest.{Matchers, FlatSpec}
+
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+
+class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
+
+  // TODO(tvas): Check results again once sampling operators are in place
+
+  behavior of "The Stochastic Gradient Descent implementation"
+
+  it should "correctly solve an L1 regularized regression problem" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val parameters = ParameterMap()
+
+    parameters.add(IterativeSolver.Stepsize, 0.01)
+    parameters.add(IterativeSolver.Iterations, 2000)
+    parameters.add(Solver.LossFunction, new SquaredLoss)
+    parameters.add(Solver.RegularizationType, new L1Regularization)
+    parameters.add(Solver.RegularizationParameter, 0.3)
+
+    val sgd = GradientDescent(parameters)
+
+    val inputDS: DataSet[LabeledVector] = env.fromCollection(regularizationData)
+
+    val weightDS = sgd.optimize(inputDS, None)
+
+    val weightList: Seq[WeightVector] = weightDS.collect()
+
+    val weightVector: WeightVector = weightList.head
+
+    val intercept = weightVector.intercept
+    val weights = weightVector.weights.asInstanceOf[DenseVector].data
+
+    expectedRegWeights zip weights foreach {
+      case (expectedWeight, weight) =>
+        weight should be (expectedWeight +- 0.01)
+    }
+
+    intercept should be (expectedRegWeight0 +- 0.1)
+  }
+
+  it should "correctly perform one step with L2 regularization" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val parameters = ParameterMap()
+
+    parameters.add(IterativeSolver.Stepsize, 0.1)
+    parameters.add(IterativeSolver.Iterations, 1)
+    parameters.add(Solver.LossFunction, new SquaredLoss)
+    parameters.add(Solver.RegularizationType, new L2Regularization)
+    parameters.add(Solver.RegularizationParameter, 1.0)
+
+    val sgd = GradientDescent(parameters)
+
+    val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
+    val currentWeights = new WeightVector(DenseVector(1.0), 1.0)
+    val currentWeightsDS = env.fromElements(currentWeights)
+
+    val weightDS = sgd.optimize(inputDS, Some(currentWeightsDS))
+
+    val weightList: Seq[WeightVector] = weightDS.collect()
+
+    weightList.size should equal(1)
+
+    val weightVector: WeightVector = weightList.head
+
+    val updatedIntercept = weightVector.intercept
+    val updatedWeight = weightVector.weights(0)
+
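+    // With squared loss 0.5*(p - y)^2, prediction p = 1.0*2.0 + 1.0 = 3.0 and derivative p - y = 2.0.
+    // L2 regularization adds regParameter * w to the weight gradient: 2.0*2.0 + 1.0*1.0 = 5.0,
+    // so w = 1.0 - 0.1*5.0 = 0.5, while the unregularized intercept becomes 1.0 - 0.1*2.0 = 0.8.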
+    updatedWeight should be (0.5 +- 0.001)
+    updatedIntercept should be (0.8 +- 0.01)
+  }
+
+  it should "estimate a linear function" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val parameters = ParameterMap()
+
+    parameters.add(IterativeSolver.Stepsize, 1.0)
+    parameters.add(IterativeSolver.Iterations, 800)
+    parameters.add(Solver.LossFunction, new SquaredLoss)
+    parameters.add(Solver.RegularizationType, new NoRegularization)
+    parameters.add(Solver.RegularizationParameter, 0.0)
+
+    val sgd = GradientDescent(parameters)
+
+    val inputDS = env.fromCollection(data)
+    val weightDS = sgd.optimize(inputDS, None)
+
+    val weightList: Seq[WeightVector] = weightDS.collect()
+
+    weightList.size should equal(1)
+
+    val weightVector: WeightVector = weightList.head
+
+    val weights = weightVector.weights.asInstanceOf[DenseVector].data
+    val weight0 = weightVector.intercept
+
+    expectedWeights zip weights foreach {
+      case (expectedWeight, weight) =>
+        weight should be (expectedWeight +- 0.1)
+    }
+    weight0 should be (expectedWeight0 +- 0.1)
+  }
+
+  it should "estimate a linear function without an intercept" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val parameters = ParameterMap()
+
+    parameters.add(IterativeSolver.Stepsize, 0.0001)
+    parameters.add(IterativeSolver.Iterations, 100)
+    parameters.add(Solver.LossFunction, new SquaredLoss)
+    parameters.add(Solver.RegularizationType, new NoRegularization)
+    parameters.add(Solver.RegularizationParameter, 0.0)
+
+    val sgd = GradientDescent(parameters)
+
+    val inputDS = env.fromCollection(noInterceptData)
+    val weightDS = sgd.optimize(inputDS, None)
+
+    val weightList: Seq[WeightVector] = weightDS.collect()
+
+    weightList.size should equal(1)
+
+    val weightVector: WeightVector = weightList.head
+
+    val weights = weightVector.weights.asInstanceOf[DenseVector].data
+    val weight0 = weightVector.intercept
+
+    expectedNoInterceptWeights zip weights foreach {
+      case (expectedWeight, weight) =>
+        weight should be (expectedWeight +- 0.1)
+    }
+    weight0 should be (expectedNoInterceptWeight0 +- 0.1)
+  }
+
+  it should "correctly perform one step of the algorithm with initial weights provided" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val parameters = ParameterMap()
+
+    parameters.add(IterativeSolver.Stepsize, 0.1)
+    parameters.add(IterativeSolver.Iterations, 1)
+    parameters.add(Solver.LossFunction, new SquaredLoss)
+    parameters.add(Solver.RegularizationType, new NoRegularization)
+    parameters.add(Solver.RegularizationParameter, 0.0)
+
+    val sgd = GradientDescent(parameters)
+
+    val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
+    val currentWeights = new WeightVector(DenseVector(1.0), 1.0)
+    val currentWeightsDS = env.fromElements(currentWeights)
+
+    val weightDS = sgd.optimize(inputDS, Some(currentWeightsDS))
+
+    val weightList: Seq[WeightVector] = weightDS.collect()
+
+    weightList.size should equal(1)
+
+    val weightVector: WeightVector = weightList.head
+
+    val updatedIntercept = weightVector.intercept
+    val updatedWeight = weightVector.weights(0)
+
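+    // Without regularization: prediction = 1.0*2.0 + 1.0 = 3.0, loss derivative = 2.0,
+    // so w = 1.0 - 0.1*(2.0*2.0) = 0.6 and intercept = 1.0 - 0.1*2.0 = 0.8.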
+    updatedWeight should be (0.6 +- 0.01)
+    updatedIntercept should be (0.8 +- 0.01)
+
+  }
+
+  // TODO: Need more corner cases
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
new file mode 100644
index 0000000..e5509a3
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{LabeledVector, WeightVector, ParameterMap}
+import org.apache.flink.ml.math.{BLAS, Vector => FlinkVector, DenseVector}
+import org.scalatest.{Matchers, FlatSpec}
+
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+
+class LossFunctionITSuite extends FlatSpec with Matchers with FlinkTestBase {
+
+  behavior of "The optimization Loss Function implementations"
+
+  it should "calculate squared loss correctly" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val squaredLoss = new SquaredLoss
+
+    val example = LabeledVector(1.0, DenseVector(2))
+    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+    val gradient = DenseVector(0.0)
+
+    val (loss, lossDerivative) = squaredLoss.lossAndGradient(example, weightVector, gradient, new
+        NoRegularization, 0.0, new LinearPrediction)
+
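+    // prediction = 1.0*2.0 + 1.0 = 3.0, so loss = 0.5*(3.0 - 1.0)^2 = 2.0, the loss
+    // derivative is 3.0 - 1.0 = 2.0, and the gradient w.r.t. the weight is 2.0*2.0 = 4.0.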
+    loss should be (2.0 +- 0.001)
+
+    lossDerivative should be (2.0 +- 0.001)
+
+    gradient.data(0) should be (4.0 +- 0.001)
+
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
new file mode 100644
index 0000000..69e67e9
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.WeightVector
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class PredictionFunctionITSuite extends FlatSpec with Matchers with FlinkTestBase {
+
+  behavior of "The optimization framework prediction functions"
+
+  it should "correctly calculate linear predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val predFunction = new LinearPrediction
+
+    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    val features = DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
+
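+    // Linear prediction is w . x + intercept: (-1.0 + 1.0 + 0.4 - 0.4 + 0.0) + 1.0 = 1.0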
+    val prediction = predFunction.predict(features, weightVector)
+
+    prediction should be (1.0 +- 0.001)
+  }
+
+  it should "correctly calculate the gradient for linear predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val predFunction = new LinearPrediction
+
+    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    val features = DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
+
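+    // The gradient of the linear prediction w.r.t. the weights is just the feature vector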
+    val gradient = predFunction.gradient(features, weightVector)
+
+    gradient shouldEqual DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
new file mode 100644
index 0000000..ad3ea89
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.WeightVector
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+import org.scalatest.{Matchers, FlatSpec}
+
+
+class RegularizationITSuite extends FlatSpec with Matchers with FlinkTestBase {
+
+  behavior of "The regularization type implementations"
+
+  it should "not change the loss when no regularization is used" in {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val regularization = new NoRegularization
+
+    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+    val effectiveStepsize = 1.0
+    val regParameter = 0.0
+    val gradient = DenseVector(0.0)
+    val originalLoss = 1.0
+
+    val adjustedLoss = regularization.regLoss(originalLoss, weightVector.weights, regParameter)
+
+    adjustedLoss should be (originalLoss +- 0.0001)
+  }
+
+  it should "correctly apply L1 regularization" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val regularization = new L1Regularization
+
+    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    val effectiveStepsize = 1.0
+    val regParameter = 0.5
+    val gradient = DenseVector(0.0, 0.0, 0.0, 0.0, 0.0)
+
+    regularization.takeStep(weightVector.weights, gradient, effectiveStepsize, regParameter)
+
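+    // The L1 step soft-thresholds each weight by effectiveStepsize * regParameter = 0.5:
+    // weights with |w| <= 0.5 become 0.0, the rest shrink toward zero; the intercept is untouched.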
+    val expectedWeights = DenseVector(-0.5, 0.5, 0.0, 0.0, 0.0)
+
+    weightVector.weights shouldEqual expectedWeights
+    weightVector.intercept should be (1.0 +- 0.0001)
+  }
+
+  it should "correctly calculate L1 loss"  in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val regularization = new L1Regularization
+
+    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    val regParameter = 0.5
+    val originalLoss = 1.0
+
+    val adjustedLoss = regularization.regLoss(originalLoss, weightVector.weights, regParameter)
+
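+    // L1 adds regParameter * ||w||_1 to the loss: 1.0 + 0.5*2.8 = 2.4, leaving the weights unchanged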
+    weightVector shouldEqual WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    adjustedLoss should be (2.4 +- 0.1)
+  }
+
+  it should "correctly adjust the gradient and loss for L2 regularization" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val regularization = new L2Regularization
+
+    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    val regParameter = 0.5
+    val lossGradient = DenseVector(0.0, 0.0, 0.0, 0.0, 0.0)
+    val originalLoss = 1.0
+
+    val adjustedLoss = regularization.regularizedLossAndGradient(
+      originalLoss,
+      weightVector.weights,
+      lossGradient,
+      regParameter)
+
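+    // L2 adds (regParameter/2) * ||w||^2 to the loss: 1.0 + 0.25*2.32 = 1.58,
+    // and regParameter * w to the gradient: 0.5 * (-1.0, 1.0, 0.4, -0.4, 0.0)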
+    val expectedGradient = DenseVector(-0.5, 0.5, 0.2, -0.2, 0.0)
+
+    weightVector shouldEqual WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
+    adjustedLoss should be (1.58 +- 0.1)
+    lossGradient shouldEqual expectedGradient
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/2939fba3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
index 82b4cc3..8525c0f 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
@@ -70,6 +70,53 @@ object RegressionData {
     LabeledVector(12.0442, DenseVector(0.1192))
   )
 
+  val expectedNoInterceptWeights = Array[Double](5.0)
+  val expectedNoInterceptWeight0: Double = 0.0
+
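+  // Data lying exactly on the line y = 5 * x, i.e. zero intercept and no noise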
+  val noInterceptData: Seq[LabeledVector] = Seq(
+    LabeledVector(217.228709, DenseVector(43.4457419)),
+    LabeledVector(450.037048, DenseVector(90.0074095)),
+    LabeledVector( 67.553478, DenseVector(13.5106955)),
+    LabeledVector( 26.976958, DenseVector( 5.3953916)),
+    LabeledVector(403.808709, DenseVector(80.7617418)),
+    LabeledVector(203.932158, DenseVector(40.7864316)),
+    LabeledVector(146.974958, DenseVector(29.3949916)),
+    LabeledVector( 46.869291, DenseVector( 9.3738582)),
+    LabeledVector(450.780834, DenseVector(90.1561667)),
+    LabeledVector(386.535619, DenseVector(77.3071239)),
+    LabeledVector(202.644342, DenseVector(40.5288684)),
+    LabeledVector(227.586507, DenseVector(45.5173013)),
+    LabeledVector(408.801080, DenseVector(81.7602161)),
+    LabeledVector(146.118550, DenseVector(29.2237100)),
+    LabeledVector(156.475382, DenseVector(31.2950763)),
+    LabeledVector(291.822515, DenseVector(58.3645030)),
+    LabeledVector( 61.506887, DenseVector(12.3013775)),
+    LabeledVector(363.949913, DenseVector(72.7899827)),
+    LabeledVector(398.050744, DenseVector(79.6101487)),
+    LabeledVector(246.053111, DenseVector(49.2106221)),
+    LabeledVector(225.494661, DenseVector(45.0989323)),
+    LabeledVector(265.986844, DenseVector(53.1973689)),
+    LabeledVector(110.459912, DenseVector(22.0919823)),
+    LabeledVector(122.716974, DenseVector(24.5433947)),
+    LabeledVector(128.014314, DenseVector(25.6028628)),
+    LabeledVector(252.538913, DenseVector(50.5077825)),
+    LabeledVector(393.632082, DenseVector(78.7264163)),
+    LabeledVector( 77.698941, DenseVector(15.5397881)),
+    LabeledVector(206.187568, DenseVector(41.2375135)),
+    LabeledVector(244.073426, DenseVector(48.8146851)),
+    LabeledVector(364.946890, DenseVector(72.9893780)),
+    LabeledVector(  4.627494, DenseVector( 0.9254987)),
+    LabeledVector(485.359565, DenseVector(97.0719130)),
+    LabeledVector(347.359190, DenseVector(69.4718380)),
+    LabeledVector(419.663211, DenseVector(83.9326422)),
+    LabeledVector(488.518318, DenseVector(97.7036635)),
+    LabeledVector( 28.082962, DenseVector( 5.6165925)),
+    LabeledVector(211.002441, DenseVector(42.2004881)),
+    LabeledVector(250.624124, DenseVector(50.1248248)),
+    LabeledVector(489.776669, DenseVector(97.9553337))
+  )
+
   val expectedPolynomialWeights = Seq(0.2375, -0.3493, -0.1674)
   val expectedPolynomialWeight0 = 0.0233
   val expectedPolynomialSquaredResidualSum = 1.5389e+03
@@ -126,4 +173,19 @@ object RegressionData {
     LabeledVector(-3.1140, DenseVector(3.1921)),
     LabeledVector(-1.4323, DenseVector(3.3961))
   )
+
+  val expectedRegWeights = Array[Double](0.0, 0.0, 0.0, 0.18, 0.2, 0.24)
+  val expectedRegWeight0 = 0.74
+
+  // Example values from scikit-learn L1 test: http://git.io/vf4V2
+  val regularizationData: Seq[LabeledVector] = Seq(
+    LabeledVector(1.0, DenseVector(1.0,0.9 ,0.8 ,0.0 ,0.0 ,0.0)),
+    LabeledVector(1.0, DenseVector(1.0,0.84,0.98,0.0 ,0.0 ,0.0)),
+    LabeledVector(1.0, DenseVector(1.0,0.96,0.88,0.0 ,0.0 ,0.0)),
+    LabeledVector(1.0, DenseVector(1.0,0.91,0.99,0.0 ,0.0 ,0.0)),
+    LabeledVector(2.0, DenseVector(0.0,0.0 ,0.0 ,0.89,0.91,1.0)),
+    LabeledVector(2.0, DenseVector(0.0,0.0 ,0.0 ,0.79,0.84,1.0)),
+    LabeledVector(2.0, DenseVector(0.0,0.0 ,0.0 ,0.91,0.95,1.0)),
+    LabeledVector(2.0, DenseVector(0.0,0.0 ,0.0 ,0.93,1.0 ,1.0))
+  )
 }