You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2017/02/12 19:41:37 UTC

flink git commit: [FLINK-1979] [ml] Add logistic loss, hinge loss and regularization penalties for optimization

Repository: flink
Updated Branches:
  refs/heads/master 4f47ccdcb -> d3a07ef61


[FLINK-1979] [ml] Add logistic loss, hinge loss and regularization penalties for optimization

Use parameter to set regularization penalty in gradient descent solver

Update regularization penalty docs

This closes #1985.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d3a07ef6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d3a07ef6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d3a07ef6

Branch: refs/heads/master
Commit: d3a07ef617562ee10fe46acc6a394e7f01dbcafe
Parents: 4f47ccd
Author: spkavuly <so...@intel.com>
Authored: Wed May 11 15:55:13 2016 -0700
Committer: Till Rohrmann <tr...@apache.org>
Committed: Sun Feb 12 20:24:49 2017 +0100

----------------------------------------------------------------------
 docs/dev/libs/ml/optimization.md                | 101 ++++++---
 .../flink/ml/optimization/GradientDescent.scala | 147 +++----------
 .../flink/ml/optimization/LossFunction.scala    |   4 +-
 .../ml/optimization/PartialLossFunction.scala   | 111 +++++++++-
 .../ml/optimization/RegularizationPenalty.scala | 219 +++++++++++++++++++
 .../apache/flink/ml/optimization/Solver.scala   |   9 +
 .../regression/MultipleLinearRegression.scala   |   2 +-
 .../optimization/GradientDescentITSuite.scala   |  18 +-
 .../ml/optimization/LossFunctionITSuite.scala   |  51 -----
 .../ml/optimization/LossFunctionTest.scala      | 102 +++++++++
 .../RegularizationPenaltyTest.scala             |  64 ++++++
 11 files changed, 610 insertions(+), 218 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/docs/dev/libs/ml/optimization.md
----------------------------------------------------------------------
diff --git a/docs/dev/libs/ml/optimization.md b/docs/dev/libs/ml/optimization.md
index e3e2f63..739d912 100644
--- a/docs/dev/libs/ml/optimization.md
+++ b/docs/dev/libs/ml/optimization.md
@@ -1,7 +1,10 @@
 ---
 mathjax: include
 title: Optimization
-nav-parent_id: ml
+# Sub navigation
+sub-nav-group: batch
+sub-nav-parent: flinkml
+sub-nav-title: Optimization
 ---
 <!--
 Licensed to the Apache Software Foundation (ASF) under one
@@ -102,33 +105,6 @@ The current implementation of SGD  uses the whole partition, making it
 effectively a batch gradient descent. Once a sampling operator has been introduced in Flink, true
 mini-batch SGD will be performed.
 
-### Regularization
-
-FlinkML supports Stochastic Gradient Descent with L1, L2 and no regularization.
-The following list contains a mapping between the implementing classes and the regularization function.
-
-<table class="table table-bordered">
-  <thead>
-    <tr>
-      <th class="text-left" style="width: 20%">Class Name</th>
-      <th class="text-center">Regularization function $R(\wv)$</th>
-    </tr>
-  </thead>
-  <tbody>
-    <tr>
-      <td><code>SimpleGradient</code></td>
-      <td>$R(\wv) = 0$</td>
-    </tr>
-    <tr>
-      <td><code>GradientDescentL1</code></td>
-      <td>$R(\wv) = \norm{\wv}_1$</td>
-    </tr>
-    <tr>
-      <td><code>GradientDescentL2</code></td>
-      <td>$R(\wv) = \frac{1}{2}\norm{\wv}_2^2$</td>
-    </tr>
-  </tbody>
-</table>
 
 ### Parameters
 
@@ -143,10 +119,10 @@ The following list contains a mapping between the implementing classes and the r
     </thead>
     <tbody>
       <tr>
-        <td><strong>LossFunction</strong></td>
+        <td><strong>RegularizationPenalty</strong></td>
         <td>
           <p>
-            The loss function to be optimized. (Default value: <strong>None</strong>)
+            The regularization function to apply. (Default value: <strong>NoRegularization</strong>)
           </p>
         </td>
       </tr>
@@ -159,6 +135,14 @@ The following list contains a mapping between the implementing classes and the r
         </td>
       </tr>
       <tr>
+        <td><strong>LossFunction</strong></td>
+        <td>
+          <p>
+            The loss function to be optimized. (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
         <td><strong>Iterations</strong></td>
         <td>
           <p>
@@ -206,6 +190,35 @@ The following list contains a mapping between the implementing classes and the r
     </tbody>
   </table>
 
+### Regularization
+
+FlinkML supports Stochastic Gradient Descent with L1, L2 and no regularization. The regularization type has to implement the `RegularizationPenalty` interface,
+which calculates the new weights based on the gradient and regularization type.
+The following list contains the supported regularization functions.
+
+<table class="table table-bordered">
+  <thead>
+    <tr>
+      <th class="text-left" style="width: 20%">Class Name</th>
+      <th class="text-center">Regularization function $R(\wv)$</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><strong>NoRegularization</strong></td>
+      <td>$R(\wv) = 0$</td>
+    </tr>
+    <tr>
+      <td><strong>L1Regularization</strong></td>
+      <td>$R(\wv) = \norm{\wv}_1$</td>
+    </tr>
+    <tr>
+      <td><strong>L2Regularization</strong></td>
+      <td>$R(\wv) = \frac{1}{2}\norm{\wv}_2^2$</td>
+    </tr>
+  </tbody>
+</table>
+
 ### Loss Function
 
 The loss function which is minimized has to implement the `LossFunction` interface, which defines methods to compute the loss and the gradient of it.
@@ -241,6 +254,29 @@ The full list of supported prediction functions can be found [here](#prediction-
         <td class="text-center">$\frac{1}{2} (\wv^T \cdot \x - y)^2$</td>
         <td class="text-center">$\wv^T \cdot \x - y$</td>
       </tr>
+      <tr>
+        <td><strong>LogisticLoss</strong></td>
+        <td>
+          <p>
+            Loss function used for classification tasks.
+          </p>
+        </td>
+        <td class="text-center">$\log\left(1+\exp\left( -y ~ \wv^T \cdot \x\right)\right), \quad y \in \{-1, +1\}$</td>
+        <td class="text-center">$\frac{-y}{1+\exp\left(y ~ \wv^T \cdot \x\right)}$</td>
+      </tr>
+      <tr>
+        <td><strong>HingeLoss</strong></td>
+        <td>
+          <p>
+            Loss function used for classification tasks.
+          </p>
+        </td>
+        <td class="text-center">$\max \left(0, 1 - y ~ \wv^T \cdot \x\right), \quad y \in \{-1, +1\}$</td>
+        <td class="text-center">$\begin{cases}
+                                 -y&\text{if } y ~ \wv^T <= 1 \\
+                                 0&\text{if } y ~ \wv^T > 1
+                                 \end{cases}$</td>
+      </tr>
     </tbody>
   </table>
 
@@ -354,7 +390,7 @@ Where:
 ### Examples
 
 In the Flink implementation of SGD, given a set of examples in a `DataSet[LabeledVector]` and
-optionally some initial weights, we can use `GradientDescentL1.optimize()` in order to optimize
+optionally some initial weights, we can use `GradientDescent.optimize()` in order to optimize
 the weights for the given data.
 
 The user can provide an initial `DataSet[WeightVector]`,
@@ -366,8 +402,9 @@ weight vector. This allows us to avoid applying regularization to the intercept.
 
 {% highlight scala %}
 // Create stochastic gradient descent solver
-val sgd = GradientDescentL1()
+val sgd = GradientDescent()
   .setLossFunction(SquaredLoss())
+  .setRegularizationPenalty(L1Regularization)
   .setRegularizationConstant(0.2)
   .setIterations(100)
   .setLearningRate(0.01)

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
index 407c074..fbb3a31 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
@@ -38,6 +38,7 @@ import org.apache.flink.ml._
   *
   *  The parameters to tune the algorithm are:
   *                      [[Solver.LossFunction]] for the loss function to be used,
+  *                      [[Solver.RegularizationPenaltyValue]] for the regularization penalty.
   *                      [[Solver.RegularizationConstant]] for the regularization parameter,
   *                      [[IterativeSolver.Iterations]] for the maximum number of iteration,
   *                      [[IterativeSolver.LearningRate]] for the learning rate used.
@@ -47,7 +48,7 @@ import org.apache.flink.ml._
   *                      [[IterativeSolver.LearningRateMethodValue]] determines functional form of
   *                      effective learning rate.
   */
-abstract class GradientDescent extends IterativeSolver {
+class GradientDescent extends IterativeSolver {
 
   /** Provides a solution for the given optimization problem
     *
@@ -63,8 +64,10 @@ abstract class GradientDescent extends IterativeSolver {
     val convergenceThresholdOption: Option[Double] = parameters.get(ConvergenceThreshold)
     val lossFunction = parameters(LossFunction)
     val learningRate = parameters(LearningRate)
+    val regularizationPenalty = parameters(RegularizationPenaltyValue)
     val regularizationConstant = parameters(RegularizationConstant)
     val learningRateMethod = parameters(LearningRateMethodValue)
+
     // Initialize weights
     val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)
 
@@ -76,6 +79,7 @@ abstract class GradientDescent extends IterativeSolver {
           data,
           initialWeightsDS,
           numberOfIterations,
+          regularizationPenalty,
           regularizationConstant,
           learningRate,
           lossFunction,
@@ -85,6 +89,7 @@ abstract class GradientDescent extends IterativeSolver {
           data,
           initialWeightsDS,
           numberOfIterations,
+          regularizationPenalty,
           regularizationConstant,
           learningRate,
           convergence,
@@ -97,6 +102,7 @@ abstract class GradientDescent extends IterativeSolver {
       dataPoints: DataSet[LabeledVector],
       initialWeightsDS: DataSet[WeightVector],
       numberOfIterations: Int,
+      regularizationPenalty: RegularizationPenalty,
       regularizationConstant: Double,
       learningRate: Double,
       convergenceThreshold: Double,
@@ -123,6 +129,7 @@ abstract class GradientDescent extends IterativeSolver {
           dataPoints,
           previousWeightsDS,
           lossFunction,
+          regularizationPenalty,
           regularizationConstant,
           learningRate,
           learningRateMethod)
@@ -152,6 +159,7 @@ abstract class GradientDescent extends IterativeSolver {
       data: DataSet[LabeledVector],
       initialWeightsDS: DataSet[WeightVector],
       numberOfIterations: Int,
+      regularizationPenalty: RegularizationPenalty,
       regularizationConstant: Double,
       learningRate: Double,
       lossFunction: LossFunction,
@@ -162,6 +170,7 @@ abstract class GradientDescent extends IterativeSolver {
         SGDStep(data,
                 weightVectorDS,
                 lossFunction,
+                regularizationPenalty,
                 regularizationConstant,
                 learningRate,
                 optimizationMethod)
@@ -173,12 +182,19 @@ abstract class GradientDescent extends IterativeSolver {
     *
     * @param data A Dataset of LabeledVector (label, features) pairs
     * @param currentWeights A Dataset with the current weights to be optimized as its only element
+    * @param lossFunction The loss function to be used
+    * @param regularizationPenalty The regularization penalty to be used
+    * @param regularizationConstant The regularization parameter
+    * @param learningRate The effective step size for this iteration
+    * @param learningRateMethod The learning rate used
+    *
     * @return A Dataset containing the weights after one stochastic gradient descent step
     */
   private def SGDStep(
     data: DataSet[(LabeledVector)],
     currentWeights: DataSet[WeightVector],
     lossFunction: LossFunction,
+    regularizationPenalty: RegularizationPenalty,
     regularizationConstant: Double,
     learningRate: Double,
     learningRateMethod: LearningRateMethodTrait)
@@ -220,6 +236,7 @@ abstract class GradientDescent extends IterativeSolver {
         val newWeights = takeStep(
           weightVector.weights,
           gradient.weights,
+          regularizationPenalty,
           regularizationConstant,
           effectiveLearningRate)
 
@@ -232,25 +249,29 @@ abstract class GradientDescent extends IterativeSolver {
 
   /** Calculates the new weights based on the gradient
     *
-    * @param weightVector
-    * @param gradient
-    * @param regularizationConstant
-    * @param learningRate
-    * @return
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param regularizationPenalty The regularization penalty to apply
+    * @param regularizationConstant The regularization parameter
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
     */
   def takeStep(
     weightVector: Vector,
     gradient: Vector,
+    regularizationPenalty: RegularizationPenalty,
     regularizationConstant: Double,
     learningRate: Double
-    ): Vector
+    ): Vector = {
+    regularizationPenalty.takeStep(weightVector, gradient, regularizationConstant, learningRate)
+  }
 
   /** Calculates the regularized loss, from the data and given weights.
     *
-    * @param data
-    * @param weightDS
-    * @param lossFunction
-    * @return
+    * @param data A Dataset of LabeledVector (label, features) pairs
+    * @param weightDS A Dataset with the current weights to be optimized as its only element
+    * @param lossFunction The loss function to be used
+    * @return A Dataset with the regularized loss as its only element
     */
   private def calculateLoss(
       data: DataSet[LabeledVector],
@@ -267,108 +288,10 @@ abstract class GradientDescent extends IterativeSolver {
   }
 }
 
-/** Implementation of a SGD solver with L2 regularization.
-  *
-  * The regularization function is `1/2 ||w||_2^2` with `w` being the weight vector.
-  */
-class GradientDescentL2 extends GradientDescent {
-
-  /** Calculates the new weights based on the gradient
-    *
-    * @param weightVector
-    * @param gradient
-    * @param regularizationConstant
-    * @param learningRate
-    * @return
-    */
-  override def takeStep(
-      weightVector: Vector,
-      gradient: Vector,
-      regularizationConstant: Double,
-      learningRate: Double)
-    : Vector = {
-    // add the gradient of the L2 regularization
-    BLAS.axpy(regularizationConstant, weightVector, gradient)
-
-    // update the weights according to the learning rate
-    BLAS.axpy(-learningRate, gradient, weightVector)
-
-    weightVector
-  }
-}
-
-object GradientDescentL2 {
-  def apply() = new GradientDescentL2
-}
-
-/** Implementation of a SGD solver with L1 regularization.
-  *
-  * The regularization function is `||w||_1` with `w` being the weight vector.
-  */
-class GradientDescentL1 extends GradientDescent {
-
-  /** Calculates the new weights based on the gradient.
-    *
-    * @param weightVector
-    * @param gradient
-    * @param regularizationConstant
-    * @param learningRate
-    * @return
-    */
-  override def takeStep(
-      weightVector: Vector,
-      gradient: Vector,
-      regularizationConstant: Double,
-      learningRate: Double)
-    : Vector = {
-    // Update weight vector with gradient. L1 regularization has no gradient, the proximal operator
-    // does the job.
-    BLAS.axpy(-learningRate, gradient, weightVector)
-
-    // Apply proximal operator (soft thresholding)
-    val shrinkageVal = regularizationConstant * learningRate
-    var i = 0
-    while (i < weightVector.size) {
-      val wi = weightVector(i)
-      weightVector(i) = scala.math.signum(wi) *
-        scala.math.max(0.0, scala.math.abs(wi) - shrinkageVal)
-      i += 1
-    }
-
-    weightVector
-  }
-}
-
-object GradientDescentL1 {
-  def apply() = new GradientDescentL1
-}
 
-/** Implementation of a SGD solver without regularization.
+/** Implementation of a Gradient Descent solver.
   *
-  * No regularization is applied.
   */
-class SimpleGradientDescent extends GradientDescent {
-
-  /** Calculates the new weights based on the gradient.
-    *
-    * @param weightVector
-    * @param gradient
-    * @param regularizationConstant
-    * @param learningRate
-    * @return
-    */
-  override def takeStep(
-      weightVector: Vector,
-      gradient: Vector,
-      regularizationConstant: Double,
-      learningRate: Double)
-    : Vector = {
-    // Update the weight vector
-    BLAS.axpy(-learningRate, gradient, weightVector)
-    weightVector
-  }
-}
-
-object SimpleGradientDescent{
-  def apply() = new SimpleGradientDescent
+object GradientDescent {
+  def apply() = new GradientDescent
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
index 1ff5d97..bf96cac 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
@@ -23,8 +23,8 @@ import org.apache.flink.ml.math.BLAS
 
 /** Abstract class that implements some of the functionality for common loss functions
   *
-  * A loss function determines the loss term $L(w) of the objective function  $f(w) = L(w) +
-  * \lambda R(w)$ for prediction tasks, the other being regularization, $R(w)$.
+  * A loss function determines the loss term `L(w)` of the objective function  `f(w) = L(w) +
+  * lambda*R(w)` for prediction tasks, the other being regularization, `R(w)`.
   *
   * The regularization is specific to the used optimization algorithm and, thus, implemented there.
   *

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
index ac0053e..10f7e00 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
@@ -24,17 +24,17 @@ package org.apache.flink.ml.optimization
 trait PartialLossFunction extends Serializable {
   /** Calculates the loss depending on the label and the prediction
     *
-    * @param prediction
-    * @param label
-    * @return
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss
     */
   def loss(prediction: Double, label: Double): Double
 
   /** Calculates the derivative of the [[PartialLossFunction]]
     * 
-    * @param prediction
-    * @param label
-    * @return
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss function
     */
   def derivative(prediction: Double, label: Double): Double
 }
@@ -47,9 +47,9 @@ object SquaredLoss extends PartialLossFunction {
 
   /** Calculates the loss depending on the label and the prediction
     *
-    * @param prediction
-    * @param label
-    * @return
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss
     */
   override def loss(prediction: Double, label: Double): Double = {
     0.5 * (prediction - label) * (prediction - label)
@@ -57,11 +57,98 @@ object SquaredLoss extends PartialLossFunction {
 
   /** Calculates the derivative of the [[PartialLossFunction]]
     *
-    * @param prediction
-    * @param label
-    * @return
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss function
     */
   override def derivative(prediction: Double, label: Double): Double = {
     prediction - label
   }
 }
+
+/** Logistic loss function which can be used with the [[GenericLossFunction]]
+  *
+  *
+  * The [[LogisticLoss]] function implements `log(1 + -exp(prediction*label))`
+  * for binary classification with label in {-1, 1}
+  */
+object LogisticLoss extends PartialLossFunction {
+
+  /** Calculates the loss depending on the label and the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss
+    */
+  override def loss(prediction: Double, label: Double): Double = {
+    val z = prediction * label
+
+    // based on implementation in scikit-learn
+    // approximately equal and saves the computation of the log
+    if (z > 18) {
+      math.exp(-z)
+    }
+    else if (z < -18) {
+      -z
+    }
+    else {
+      math.log(1 + math.exp(-z))
+    }
+  }
+
+  /** Calculates the derivative of the loss function with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss function
+    */
+  override def derivative(prediction: Double, label: Double): Double = {
+    val z = prediction * label
+
+    // based on implementation in scikit-learn
+    // approximately equal and saves the computation of the log
+    if (z > 18) {
+      label * math.exp(-z)
+    }
+    else if (z < -18) {
+      -label
+    }
+    else {
+      -label/(math.exp(z) + 1)
+    }
+  }
+}
+
+/** Hinge loss function which can be used with the [[GenericLossFunction]]
+  *
+  * The [[HingeLoss]] function implements `max(0, 1 - prediction*label)`
+  * for binary classification with label in {-1, 1}
+  */
+object HingeLoss extends PartialLossFunction {
+  /** Calculates the loss for a given prediction/truth pair
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss
+    */
+  override def loss(prediction: Double, label: Double): Double = {
+    val z = prediction * label
+    math.max(0, 1 - z)
+  }
+
+  /** Calculates the derivative of the loss function with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss function
+    */
+  override def derivative(prediction: Double, label: Double): Double = {
+    val z = prediction * label
+    if (z <= 1) {
+      -label
+    }
+    else {
+      0
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
new file mode 100644
index 0000000..1ed59ab
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.{Vector, BLAS}
+import org.apache.flink.ml.math.Breeze._
+import breeze.linalg.{norm => BreezeNorm}
+
+/** Represents a type of regularization penalty
+  *
+  * Regularization penalties are used to restrict the optimization problem to solutions with
+  * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large
+  * weights for the L2 penalty.
+  *
+  * The regularization term, `R(w)` is added to the objective function, `f(w) = L(w) + lambda*R(w)`
+  * where lambda is the regularization parameter used to tune the amount of regularization applied.
+  */
+trait RegularizationPenalty extends Serializable {
+
+  /** Calculates the new weights based on the gradient and regularization penalty
+    *
+    * Weights are updated using the gradient descent step `w - learningRate * gradient`
+    * with `w` being the weight vector.
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient used to update the weights
+    * @param regularizationConstant The regularization parameter to be applied 
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector
+
+  /** Adds regularization to the loss value
+    *
+    * @param oldLoss The loss to be updated
+    * @param weightVector The weights used to update the loss
+    * @param regularizationConstant The regularization parameter to be applied
+    * @return Updated loss
+    */
+  def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double): Double
+
+}
+
+
+/** `L_2` regularization penalty.
+  *
+  * The regularization function is the square of the L2 norm `1/2*||w||_2^2`
+  * with `w` being the weight vector. The function penalizes large weights,
+  * favoring solutions with more small weights rather than few large ones.
+  */
+object L2Regularization extends RegularizationPenalty {
+
+  /** Calculates the new weights based on the gradient and L2 regularization penalty
+    *
+    * The updated weight is `w - learningRate * (gradient + lambda * w)` where
+    * `w` is the weight vector, and `lambda` is the regularization parameter.
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param regularizationConstant The regularization parameter to be applied
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // add the gradient of the L2 regularization
+    BLAS.axpy(regularizationConstant, weightVector, gradient)
+
+    // update the weights according to the learning rate
+    BLAS.axpy(-learningRate, gradient, weightVector)
+
+    weightVector
+  }
+
+  /** Adds regularization to the loss value
+    *
+    * The updated loss is `oldLoss + lambda * 1/2*||w||_2^2` where
+    * `w` is the weight vector, and `lambda` is the regularization parameter
+    *
+    * @param oldLoss The loss to be updated
+    * @param weightVector The weights used to update the loss
+    * @param regularizationConstant The regularization parameter to be applied
+    * @return Updated loss
+    */
+  override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double)
+    : Double = {
+    val squareNorm = BLAS.dot(weightVector, weightVector)
+    oldLoss + regularizationConstant * 0.5 * squareNorm
+  }
+}
+
+/** `L_1` regularization penalty.
+  *
+  * The regularization function is the `L1` norm `||w||_1` with `w` being the weight vector.
+  * The `L_1` penalty can be used to drive a number of the solution coefficients to 0, thereby
+  * producing sparse solutions.
+  *
+  */
+object L1Regularization extends RegularizationPenalty {
+
+  /** Calculates the new weights based on the gradient and L1 regularization penalty
+    *
+    * Uses the proximal gradient method with L1 regularization to update weights.
+    * The updated weight `w - learningRate * gradient` is shrunk towards zero
+    * by applying the proximal operator `signum(w) * max(0.0, abs(w) - shrinkageVal)`
+    * where `w` is the weight vector, `lambda` is the regularization parameter,
+    * and `shrinkageVal` is `lambda*learningRate`.
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param regularizationConstant The regularization parameter to be applied
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // Update weight vector with gradient.
+    BLAS.axpy(-learningRate, gradient, weightVector)
+
+    // Apply proximal operator (soft thresholding)
+    val shrinkageVal = regularizationConstant * learningRate
+    var i = 0
+    while (i < weightVector.size) {
+      val wi = weightVector(i)
+      weightVector(i) = math.signum(wi) *
+        math.max(0.0, math.abs(wi) - shrinkageVal)
+      i += 1
+    }
+
+    weightVector
+  }
+
+  /** Adds regularization to the loss value
+    *
+    * The updated loss is `oldLoss + lambda * ||w||_1` where
+    * `w` is the weight vector and `lambda` is the regularization parameter
+    *
+    * @param oldLoss The loss to be updated
+    * @param weightVector The weights used to update the loss
+    * @param regularizationConstant The regularization parameter to be applied
+    * @return Updated loss
+    */
+  override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double)
+    : Double = {
+    val norm =  BreezeNorm(weightVector.asBreeze, 1.0)
+    oldLoss + norm * regularizationConstant
+  }
+}
+
+/** No regularization penalty.
+  *
+  */
+object NoRegularization extends RegularizationPenalty {
+
+  /** Calculates the new weights based on the gradient
+    *
+    * The updated weight is `w - learningRate *gradient` where `w` is the weight vector
+    *
+    * @param weightVector The weights to be updated
+    * @param gradient The gradient according to which we will update the weights
+    * @param regularizationConstant The regularization parameter which is ignored
+    * @param learningRate The effective step size for this iteration
+    * @return Updated weights
+    */
+  override def takeStep(
+                         weightVector: Vector,
+                         gradient: Vector,
+                         regularizationConstant: Double,
+                         learningRate: Double)
+  : Vector = {
+    // Update the weight vector
+    BLAS.axpy(-learningRate, gradient, weightVector)
+    weightVector
+  }
+
+  /**
+   * Returns the unmodified loss value
+   *
+   * The updated loss is `oldLoss`
+   *
+   * @param oldLoss The loss to be updated
+   * @param weightVector The weights used to update the loss
+   * @param regularizationParameter The regularization parameter which is ignored
+   * @return Updated loss
+   */
+  override def regLoss(oldLoss: Double, weightVector: Vector, regularizationParameter: Double)
+    : Double = {
+    oldLoss
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
index ee91bd1..761620d 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
@@ -95,6 +95,11 @@ abstract class Solver extends Serializable with WithParameters {
     parameters.add(RegularizationConstant, regularizationConstant)
     this
   }
+  
+  def setRegularizationPenalty(regularizationPenalty: RegularizationPenalty) : this.type = {
+    parameters.add(RegularizationPenaltyValue, regularizationPenalty)
+    this
+  }
 }
 
 object Solver {
@@ -108,6 +113,10 @@ object Solver {
   case object RegularizationConstant extends Parameter[Double] {
     val defaultValue = Some(0.0001) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
   }
+  
+  case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
+    val defaultValue = Some(NoRegularization)
+  }
 }
 
 /** An abstract class for iterative optimization algorithms

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index ef06033..773dd04 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -190,7 +190,7 @@ object MultipleLinearRegression {
 
       val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-      val optimizer = SimpleGradientDescent()
+      val optimizer = GradientDescent()
         .setIterations(numberOfIterations)
         .setStepsize(stepsize)
         .setLossFunction(lossFunction)

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
index 728ef57..da3fac9 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
@@ -40,10 +40,11 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = GradientDescentL1()
+    val sgd = GradientDescent()
       .setStepsize(0.01)
       .setIterations(2000)
       .setLossFunction(lossFunction)
+      .setRegularizationPenalty(L1Regularization)
       .setRegularizationConstant(0.3)
       
     val inputDS: DataSet[LabeledVector] = env.fromCollection(regularizationData)
@@ -72,10 +73,11 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = GradientDescentL2()
+    val sgd = GradientDescent()
       .setStepsize(0.1)
       .setIterations(1)
       .setLossFunction(lossFunction)
+      .setRegularizationPenalty(L2Regularization)
       .setRegularizationConstant(1.0)
 
     val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
@@ -101,7 +103,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = SimpleGradientDescent()
+    val sgd = GradientDescent()
       .setStepsize(1.0)
       .setIterations(800)
       .setLossFunction(lossFunction)
@@ -132,7 +134,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = SimpleGradientDescent()
+    val sgd = GradientDescent()
       .setStepsize(0.0001)
       .setIterations(100)
       .setLossFunction(lossFunction)
@@ -163,7 +165,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = SimpleGradientDescent()
+    val sgd = GradientDescent()
       .setStepsize(0.1)
       .setIterations(1)
       .setLossFunction(lossFunction)
@@ -199,7 +201,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgdEarlyTerminate = SimpleGradientDescent()
+    val sgdEarlyTerminate = GradientDescent()
       .setConvergenceThreshold(1e2)
       .setStepsize(1.0)
       .setIterations(800)
@@ -217,7 +219,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
     val weightsEarly = weightVectorEarly.weights.asInstanceOf[DenseVector].data
     val weight0Early = weightVectorEarly.intercept
 
-    val sgdNoConvergence = SimpleGradientDescent()
+    val sgdNoConvergence = GradientDescent()
       .setStepsize(1.0)
       .setIterations(800)
       .setLossFunction(lossFunction)
@@ -247,7 +249,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-    val sgd = SimpleGradientDescent()
+    val sgd = GradientDescent()
       .setStepsize(1.0)
       .setIterations(800)
       .setLossFunction(lossFunction)

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
deleted file mode 100644
index 3d538cf..0000000
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.optimization
-
-import org.apache.flink.ml.common.{LabeledVector, WeightVector}
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.ml.util.FlinkTestBase
-import org.scalatest.{Matchers, FlatSpec}
-
-import org.apache.flink.api.scala._
-
-
-class LossFunctionITSuite extends FlatSpec with Matchers with FlinkTestBase {
-
-  behavior of "The optimization Loss Function implementations"
-
-  it should "calculate squared loss and gradient correctly" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
-
-
-    val example = LabeledVector(1.0, DenseVector(2))
-    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
-
-    val gradient = lossFunction.gradient(example, weightVector)
-    val loss = lossFunction.loss(example, weightVector)
-
-    loss should be (2.0 +- 0.001)
-
-    gradient.weights(0) should be (4.0 +- 0.001)
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
new file mode 100644
index 0000000..05a159c
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{LabeledVector, WeightVector}
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.util.FlinkTestBase
+import org.scalatest.{Matchers, FlatSpec}
+
+
+class LossFunctionTest extends FlatSpec with Matchers {
+
+  behavior of "The optimization Loss Function implementations"
+
+  it should "calculate squared loss and gradient correctly" in {
+
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val example = LabeledVector(1.0, DenseVector(2))
+    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+
+    val gradient = lossFunction.gradient(example, weightVector)
+    val loss = lossFunction.loss(example, weightVector)
+
+    loss should be (2.0 +- 0.001)
+
+    gradient.weights(0) should be (4.0 +- 0.001)
+  }
+
+  it should "calculate logistic loss and gradient correctly" in {
+
+    val lossFunction = GenericLossFunction(LogisticLoss, LinearPrediction)
+
+    val examples = List(
+      LabeledVector(1.0, DenseVector(2)),
+      LabeledVector(1.0, DenseVector(20)),
+      LabeledVector(1.0, DenseVector(-25))
+    )
+
+    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+    val expectedLosses = List(0.049, 7.58e-10, 24.0)
+    val expectedGradients = List(-0.095, -1.52e-8, 25.0)
+
+    expectedLosses zip examples foreach {
+      case (expectedLoss, example) => {
+        val loss = lossFunction.loss(example, weightVector)
+        loss should be (expectedLoss +- 0.001)
+      }
+    }
+
+    expectedGradients zip examples foreach {
+      case (expectedGradient, example) => {
+        val gradient = lossFunction.gradient(example, weightVector)
+        gradient.weights(0) should be (expectedGradient +- 0.001)
+      }
+    }
+  }
+
+  it should "calculate hinge loss and gradient correctly" in {
+
+    val lossFunction = GenericLossFunction(HingeLoss, LinearPrediction)
+
+    val examples = List(
+      LabeledVector(1.0, DenseVector(2)),
+      LabeledVector(1.0, DenseVector(-2))
+    )
+
+    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+    val expectedLosses = List(0.0, 2.0)
+    val expectedGradients = List(0.0, 2.0)
+
+    expectedLosses zip examples foreach {
+      case (expectedLoss, example) => {
+        val loss = lossFunction.loss(example, weightVector)
+        loss should be (expectedLoss +- 0.001)
+      }
+    }
+
+    expectedGradients zip examples foreach {
+      case (expectedGradient, example) => {
+        val gradient = lossFunction.gradient(example, weightVector)
+        gradient.weights(0) should be (expectedGradient +- 0.001)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
new file mode 100644
index 0000000..fe475f0
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.DenseVector
+import org.scalatest.{FlatSpec, Matchers}
+
+
+class RegularizationPenaltyTest extends FlatSpec with Matchers {
+
+  behavior of "The Regularization Penalty Function implementations"
+
+  it should "correctly update weights and loss with L2 regularization penalty" in {
+    val loss =  3.4
+    val weights = DenseVector(0.8)
+    val gradient = DenseVector(2.0)
+
+    val updatedWeights = L2Regularization.takeStep(weights, gradient, 0.3, 0.01)
+    val updatedLoss = L2Regularization.regLoss(loss, updatedWeights, 0.3)
+
+    updatedWeights(0) should be (0.7776 +- 0.001)
+    updatedLoss should be (3.4907 +- 0.001)
+  }
+
+  it should "correctly update weights and loss with L1 regularization penalty" in {
+    val loss =  3.4
+    val weights = DenseVector(0.8)
+    val gradient = DenseVector(2.0)
+
+    val updatedWeights = L1Regularization.takeStep(weights, gradient, 0.3, 0.01)
+    val updatedLoss = L1Regularization.regLoss(loss, updatedWeights, 0.3)
+
+    updatedWeights(0) should be (0.777 +- 0.001)
+    updatedLoss should be (3.6331 +- 0.001)
+  }
+
+  it should "correctly update weights and loss with no regularization penalty" in {
+    val loss =  3.4
+    val weights = DenseVector(0.8)
+    val gradient = DenseVector(2.0)
+
+    val updatedWeights = NoRegularization.takeStep(weights, gradient, 0.3, 0.01)
+    val updatedLoss = NoRegularization.regLoss(loss, updatedWeights, 0.3)
+
+    updatedWeights(0) should be (0.78 +- 0.001)
+    updatedLoss should be (3.4 +- 0.001)
+  }
+}