Posted to commits@flink.apache.org by tr...@apache.org on 2017/02/12 19:41:37 UTC
flink git commit: [FLINK-1979] [ml] Add logistic loss, hinge loss and regularization penalties for optimization
Repository: flink
Updated Branches:
refs/heads/master 4f47ccdcb -> d3a07ef61
[FLINK-1979] [ml] Add logistic loss, hinge loss and regularization penalties for optimization
Use parameter to set regularization penalty in gradient descent solver
Update regularization penalty docs
This closes #1985.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d3a07ef6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d3a07ef6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d3a07ef6
Branch: refs/heads/master
Commit: d3a07ef617562ee10fe46acc6a394e7f01dbcafe
Parents: 4f47ccd
Author: spkavuly <so...@intel.com>
Authored: Wed May 11 15:55:13 2016 -0700
Committer: Till Rohrmann <tr...@apache.org>
Committed: Sun Feb 12 20:24:49 2017 +0100
----------------------------------------------------------------------
docs/dev/libs/ml/optimization.md | 101 ++++++---
.../flink/ml/optimization/GradientDescent.scala | 147 +++----------
.../flink/ml/optimization/LossFunction.scala | 4 +-
.../ml/optimization/PartialLossFunction.scala | 111 +++++++++-
.../ml/optimization/RegularizationPenalty.scala | 219 +++++++++++++++++++
.../apache/flink/ml/optimization/Solver.scala | 9 +
.../regression/MultipleLinearRegression.scala | 2 +-
.../optimization/GradientDescentITSuite.scala | 18 +-
.../ml/optimization/LossFunctionITSuite.scala | 51 -----
.../ml/optimization/LossFunctionTest.scala | 102 +++++++++
.../RegularizationPenaltyTest.scala | 64 ++++++
11 files changed, 610 insertions(+), 218 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/docs/dev/libs/ml/optimization.md
----------------------------------------------------------------------
diff --git a/docs/dev/libs/ml/optimization.md b/docs/dev/libs/ml/optimization.md
index e3e2f63..739d912 100644
--- a/docs/dev/libs/ml/optimization.md
+++ b/docs/dev/libs/ml/optimization.md
@@ -1,7 +1,10 @@
---
mathjax: include
title: Optimization
-nav-parent_id: ml
+# Sub navigation
+sub-nav-group: batch
+sub-nav-parent: flinkml
+sub-nav-title: Optimization
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
@@ -102,33 +105,6 @@ The current implementation of SGD uses the whole partition, making it
effectively a batch gradient descent. Once a sampling operator has been introduced in Flink, true
mini-batch SGD will be performed.
-### Regularization
-
-FlinkML supports Stochastic Gradient Descent with L1, L2 and no regularization.
-The following list contains a mapping between the implementing classes and the regularization function.
-
-<table class="table table-bordered">
- <thead>
- <tr>
- <th class="text-left" style="width: 20%">Class Name</th>
- <th class="text-center">Regularization function $R(\wv)$</th>
- </tr>
- </thead>
- <tbody>
- <tr>
- <td><code>SimpleGradient</code></td>
- <td>$R(\wv) = 0$</td>
- </tr>
- <tr>
- <td><code>GradientDescentL1</code></td>
- <td>$R(\wv) = \norm{\wv}_1$</td>
- </tr>
- <tr>
- <td><code>GradientDescentL2</code></td>
- <td>$R(\wv) = \frac{1}{2}\norm{\wv}_2^2$</td>
- </tr>
- </tbody>
-</table>
### Parameters
@@ -143,10 +119,10 @@ The following list contains a mapping between the implementing classes and the r
</thead>
<tbody>
<tr>
- <td><strong>LossFunction</strong></td>
+ <td><strong>RegularizationPenalty</strong></td>
<td>
<p>
- The loss function to be optimized. (Default value: <strong>None</strong>)
+ The regularization function to apply. (Default value: <strong>NoRegularization</strong>)
</p>
</td>
</tr>
@@ -159,6 +135,14 @@ The following list contains a mapping between the implementing classes and the r
</td>
</tr>
<tr>
+ <td><strong>LossFunction</strong></td>
+ <td>
+ <p>
+ The loss function to be optimized. (Default value: <strong>None</strong>)
+ </p>
+ </td>
+ </tr>
+ <tr>
<td><strong>Iterations</strong></td>
<td>
<p>
@@ -206,6 +190,35 @@ The following list contains a mapping between the implementing classes and the r
</tbody>
</table>
+### Regularization
+
+FlinkML supports Stochastic Gradient Descent with L1, L2 and no regularization. The regularization type has to implement the `RegularizationPenalty` interface,
+which defines how the weights are updated for a given gradient and how the regularization term is added to the loss.
+The following list contains the supported regularization functions.
+
+<table class="table table-bordered">
+ <thead>
+ <tr>
+ <th class="text-left" style="width: 20%">Class Name</th>
+ <th class="text-center">Regularization function $R(\wv)$</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td><strong>NoRegularization</strong></td>
+ <td>$R(\wv) = 0$</td>
+ </tr>
+ <tr>
+ <td><strong>L1Regularization</strong></td>
+ <td>$R(\wv) = \norm{\wv}_1$</td>
+ </tr>
+ <tr>
+ <td><strong>L2Regularization</strong></td>
+ <td>$R(\wv) = \frac{1}{2}\norm{\wv}_2^2$</td>
+ </tr>
+ </tbody>
+</table>
+
### Loss Function
The loss function to be minimized has to implement the `LossFunction` interface, which defines methods to compute the loss and its gradient.
@@ -241,6 +254,29 @@ The full list of supported prediction functions can be found [here](#prediction-
<td class="text-center">$\frac{1}{2} (\wv^T \cdot \x - y)^2$</td>
<td class="text-center">$\wv^T \cdot \x - y$</td>
</tr>
+ <tr>
+ <td><strong>LogisticLoss</strong></td>
+ <td>
+ <p>
Loss function used for classification tasks, most prominently in logistic regression.
+ </p>
+ </td>
+ <td class="text-center">$\log\left(1+\exp\left( -y ~ \wv^T \cdot \x\right)\right), \quad y \in \{-1, +1\}$</td>
+ <td class="text-center">$\frac{-y}{1+\exp\left(y ~ \wv^T \cdot \x\right)}$</td>
+ </tr>
+ <tr>
+ <td><strong>HingeLoss</strong></td>
+ <td>
+ <p>
Loss function used for classification tasks, most prominently in support vector machines.
+ </p>
+ </td>
+ <td class="text-center">$\max \left(0, 1 - y ~ \wv^T \cdot \x\right), \quad y \in \{-1, +1\}$</td>
+ <td class="text-center">$\begin{cases}
-y&\text{if } y ~ \wv^T \cdot \x \leq 1 \\
0&\text{if } y ~ \wv^T \cdot \x > 1
+ \end{cases}$</td>
+ </tr>
</tbody>
</table>
@@ -354,7 +390,7 @@ Where:
### Examples
In the Flink implementation of SGD, given a set of examples in a `DataSet[LabeledVector]` and
-optionally some initial weights, we can use `GradientDescentL1.optimize()` in order to optimize
+optionally some initial weights, we can use `GradientDescent.optimize()` in order to optimize
the weights for the given data.
The user can provide an initial `DataSet[WeightVector]`,
@@ -366,8 +402,9 @@ weight vector. This allows us to avoid applying regularization to the intercept.
{% highlight scala %}
// Create stochastic gradient descent solver
-val sgd = GradientDescentL1()
+val sgd = GradientDescent()
.setLossFunction(GenericLossFunction(SquaredLoss, LinearPrediction))
+ .setRegularizationPenalty(L1Regularization)
.setRegularizationConstant(0.2)
.setIterations(100)
.setLearningRate(0.01)
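To round the docs example off, here is a minimal sketch of actually running the configured solver. It assumes an `ExecutionEnvironment` and a `trainingDS: DataSet[LabeledVector]` already exist (neither is part of this commit), and builds the loss with `GenericLossFunction`, as the test suites below do:

{% highlight scala %}
// sketch only: configure the unified solver and optimize the weights
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)

val sgd = GradientDescent()
  .setLossFunction(lossFunction)
  .setRegularizationPenalty(L1Regularization)
  .setRegularizationConstant(0.2)
  .setIterations(100)
  .setLearningRate(0.01)

// None means the solver creates the initial weights itself
val weightDS: DataSet[WeightVector] = sgd.optimize(trainingDS, None)
{% endhighlight %}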
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
index 407c074..fbb3a31 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
@@ -38,6 +38,7 @@ import org.apache.flink.ml._
*
* The parameters to tune the algorithm are:
* [[Solver.LossFunction]] for the loss function to be used,
* [[Solver.RegularizationPenaltyValue]] for the regularization penalty,
* [[Solver.RegularizationConstant]] for the regularization parameter,
* [[IterativeSolver.Iterations]] for the maximum number of iterations,
* [[IterativeSolver.LearningRate]] for the learning rate used.
@@ -47,7 +48,7 @@ import org.apache.flink.ml._
* [[IterativeSolver.LearningRateMethodValue]] determines functional form of
* effective learning rate.
*/
-abstract class GradientDescent extends IterativeSolver {
+class GradientDescent extends IterativeSolver {
/** Provides a solution for the given optimization problem
*
@@ -63,8 +64,10 @@ abstract class GradientDescent extends IterativeSolver {
val convergenceThresholdOption: Option[Double] = parameters.get(ConvergenceThreshold)
val lossFunction = parameters(LossFunction)
val learningRate = parameters(LearningRate)
+ val regularizationPenalty = parameters(RegularizationPenaltyValue)
val regularizationConstant = parameters(RegularizationConstant)
val learningRateMethod = parameters(LearningRateMethodValue)
+
// Initialize weights
val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)
@@ -76,6 +79,7 @@ abstract class GradientDescent extends IterativeSolver {
data,
initialWeightsDS,
numberOfIterations,
+ regularizationPenalty,
regularizationConstant,
learningRate,
lossFunction,
@@ -85,6 +89,7 @@ abstract class GradientDescent extends IterativeSolver {
data,
initialWeightsDS,
numberOfIterations,
+ regularizationPenalty,
regularizationConstant,
learningRate,
convergence,
@@ -97,6 +102,7 @@ abstract class GradientDescent extends IterativeSolver {
dataPoints: DataSet[LabeledVector],
initialWeightsDS: DataSet[WeightVector],
numberOfIterations: Int,
+ regularizationPenalty: RegularizationPenalty,
regularizationConstant: Double,
learningRate: Double,
convergenceThreshold: Double,
@@ -123,6 +129,7 @@ abstract class GradientDescent extends IterativeSolver {
dataPoints,
previousWeightsDS,
lossFunction,
+ regularizationPenalty,
regularizationConstant,
learningRate,
learningRateMethod)
@@ -152,6 +159,7 @@ abstract class GradientDescent extends IterativeSolver {
data: DataSet[LabeledVector],
initialWeightsDS: DataSet[WeightVector],
numberOfIterations: Int,
+ regularizationPenalty: RegularizationPenalty,
regularizationConstant: Double,
learningRate: Double,
lossFunction: LossFunction,
@@ -162,6 +170,7 @@ abstract class GradientDescent extends IterativeSolver {
SGDStep(data,
weightVectorDS,
lossFunction,
+ regularizationPenalty,
regularizationConstant,
learningRate,
optimizationMethod)
@@ -173,12 +182,19 @@ abstract class GradientDescent extends IterativeSolver {
*
* @param data A Dataset of LabeledVector (label, features) pairs
* @param currentWeights A Dataset with the current weights to be optimized as its only element
+ * @param lossFunction The loss function to be used
+ * @param regularizationPenalty The regularization penalty to be used
+ * @param regularizationConstant The regularization parameter
+ * @param learningRate The effective step size for this iteration
* @param learningRateMethod The method used to adjust the learning rate
+ *
* @return A Dataset containing the weights after one stochastic gradient descent step
*/
private def SGDStep(
data: DataSet[(LabeledVector)],
currentWeights: DataSet[WeightVector],
lossFunction: LossFunction,
+ regularizationPenalty: RegularizationPenalty,
regularizationConstant: Double,
learningRate: Double,
learningRateMethod: LearningRateMethodTrait)
@@ -220,6 +236,7 @@ abstract class GradientDescent extends IterativeSolver {
val newWeights = takeStep(
weightVector.weights,
gradient.weights,
+ regularizationPenalty,
regularizationConstant,
effectiveLearningRate)
@@ -232,25 +249,29 @@ abstract class GradientDescent extends IterativeSolver {
/** Calculates the new weights based on the gradient
*
- * @param weightVector
- * @param gradient
- * @param regularizationConstant
- * @param learningRate
- * @return
+ * @param weightVector The weights to be updated
+ * @param gradient The gradient according to which we will update the weights
+ * @param regularizationPenalty The regularization penalty to apply
+ * @param regularizationConstant The regularization parameter
+ * @param learningRate The effective step size for this iteration
+ * @return Updated weights
*/
def takeStep(
weightVector: Vector,
gradient: Vector,
+ regularizationPenalty: RegularizationPenalty,
regularizationConstant: Double,
learningRate: Double
- ): Vector
+ ): Vector = {
+ regularizationPenalty.takeStep(weightVector, gradient, regularizationConstant, learningRate)
+ }
/** Calculates the regularized loss, from the data and given weights.
*
- * @param data
- * @param weightDS
- * @param lossFunction
- * @return
+ * @param data A Dataset of LabeledVector (label, features) pairs
+ * @param weightDS A Dataset with the current weights to be optimized as its only element
+ * @param lossFunction The loss function to be used
+ * @return A Dataset with the regularized loss as its only element
*/
private def calculateLoss(
data: DataSet[LabeledVector],
@@ -267,108 +288,10 @@ abstract class GradientDescent extends IterativeSolver {
}
}
-/** Implementation of a SGD solver with L2 regularization.
- *
- * The regularization function is `1/2 ||w||_2^2` with `w` being the weight vector.
- */
-class GradientDescentL2 extends GradientDescent {
-
- /** Calculates the new weights based on the gradient
- *
- * @param weightVector
- * @param gradient
- * @param regularizationConstant
- * @param learningRate
- * @return
- */
- override def takeStep(
- weightVector: Vector,
- gradient: Vector,
- regularizationConstant: Double,
- learningRate: Double)
- : Vector = {
- // add the gradient of the L2 regularization
- BLAS.axpy(regularizationConstant, weightVector, gradient)
-
- // update the weights according to the learning rate
- BLAS.axpy(-learningRate, gradient, weightVector)
-
- weightVector
- }
-}
-
-object GradientDescentL2 {
- def apply() = new GradientDescentL2
-}
-
-/** Implementation of a SGD solver with L1 regularization.
- *
- * The regularization function is `||w||_1` with `w` being the weight vector.
- */
-class GradientDescentL1 extends GradientDescent {
-
- /** Calculates the new weights based on the gradient.
- *
- * @param weightVector
- * @param gradient
- * @param regularizationConstant
- * @param learningRate
- * @return
- */
- override def takeStep(
- weightVector: Vector,
- gradient: Vector,
- regularizationConstant: Double,
- learningRate: Double)
- : Vector = {
- // Update weight vector with gradient. L1 regularization has no gradient, the proximal operator
- // does the job.
- BLAS.axpy(-learningRate, gradient, weightVector)
-
- // Apply proximal operator (soft thresholding)
- val shrinkageVal = regularizationConstant * learningRate
- var i = 0
- while (i < weightVector.size) {
- val wi = weightVector(i)
- weightVector(i) = scala.math.signum(wi) *
- scala.math.max(0.0, scala.math.abs(wi) - shrinkageVal)
- i += 1
- }
-
- weightVector
- }
-}
-
-object GradientDescentL1 {
- def apply() = new GradientDescentL1
-}
-/** Implementation of a SGD solver without regularization.
+/** Implementation of a Gradient Descent solver.
*
- * No regularization is applied.
*/
-class SimpleGradientDescent extends GradientDescent {
-
- /** Calculates the new weights based on the gradient.
- *
- * @param weightVector
- * @param gradient
- * @param regularizationConstant
- * @param learningRate
- * @return
- */
- override def takeStep(
- weightVector: Vector,
- gradient: Vector,
- regularizationConstant: Double,
- learningRate: Double)
- : Vector = {
- // Update the weight vector
- BLAS.axpy(-learningRate, gradient, weightVector)
- weightVector
- }
-}
-
-object SimpleGradientDescent{
- def apply() = new SimpleGradientDescent
+object GradientDescent {
+ def apply() = new GradientDescent
}
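Since `takeStep` now just forwards to the configured penalty object, the per-element weight update can be exercised in isolation. A small standalone check (the numbers match `RegularizationPenaltyTest` below; note that `takeStep` updates the weight vector in place and returns it):

{% highlight scala %}
import org.apache.flink.ml.math.DenseVector
import org.apache.flink.ml.optimization.{L2Regularization, NoRegularization}

// w = (0.8), gradient = (2.0), lambda = 0.3, learning rate = 0.01
val plain = NoRegularization.takeStep(DenseVector(0.8), DenseVector(2.0), 0.3, 0.01)
println(plain(0)) // 0.8 - 0.01 * 2.0 = 0.78

val l2 = L2Regularization.takeStep(DenseVector(0.8), DenseVector(2.0), 0.3, 0.01)
println(l2(0))    // 0.8 - 0.01 * (2.0 + 0.3 * 0.8) = 0.7776
{% endhighlight %}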
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
index 1ff5d97..bf96cac 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
@@ -23,8 +23,8 @@ import org.apache.flink.ml.math.BLAS
/** Abstract class that implements some of the functionality for common loss functions
*
- * A loss function determines the loss term $L(w) of the objective function $f(w) = L(w) +
- * \lambda R(w)$ for prediction tasks, the other being regularization, $R(w)$.
+ * A loss function determines the loss term `L(w)` of the objective function `f(w) = L(w) +
+ * lambda*R(w)` for prediction tasks, the other being regularization, `R(w)`.
*
* The regularization is specific to the used optimization algorithm and, thus, implemented there.
*
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
index ac0053e..10f7e00 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
@@ -24,17 +24,17 @@ package org.apache.flink.ml.optimization
trait PartialLossFunction extends Serializable {
/** Calculates the loss depending on the label and the prediction
*
- * @param prediction
- * @param label
- * @return
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The loss
*/
def loss(prediction: Double, label: Double): Double
/** Calculates the derivative of the [[PartialLossFunction]]
*
- * @param prediction
- * @param label
- * @return
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The derivative of the loss function
*/
def derivative(prediction: Double, label: Double): Double
}
@@ -47,9 +47,9 @@ object SquaredLoss extends PartialLossFunction {
/** Calculates the loss depending on the label and the prediction
*
- * @param prediction
- * @param label
- * @return
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The loss
*/
override def loss(prediction: Double, label: Double): Double = {
0.5 * (prediction - label) * (prediction - label)
@@ -57,11 +57,98 @@ object SquaredLoss extends PartialLossFunction {
/** Calculates the derivative of the [[PartialLossFunction]]
*
- * @param prediction
- * @param label
- * @return
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The derivative of the loss function
*/
override def derivative(prediction: Double, label: Double): Double = {
prediction - label
}
}
+
+/** Logistic loss function which can be used with the [[GenericLossFunction]]
+ *
+ *
+ * The [[LogisticLoss]] function implements `log(1 + exp(-prediction*label))`
+ * for binary classification with label in {-1, 1}
+ */
+object LogisticLoss extends PartialLossFunction {
+
+ /** Calculates the loss depending on the label and the prediction
+ *
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The loss
+ */
+ override def loss(prediction: Double, label: Double): Double = {
+ val z = prediction * label
+
+ // based on implementation in scikit-learn
+ // approximately equal and saves the computation of the log
+ if (z > 18) {
+ math.exp(-z)
+ }
+ else if (z < -18) {
+ -z
+ }
+ else {
+ math.log(1 + math.exp(-z))
+ }
+ }
+
+ /** Calculates the derivative of the loss function with respect to the prediction
+ *
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The derivative of the loss function
+ */
+ override def derivative(prediction: Double, label: Double): Double = {
+ val z = prediction * label
+
+ // based on implementation in scikit-learn
+ // approximately equal and saves the computation of the log
+ if (z > 18) {
+ -label * math.exp(-z)
+ }
+ else if (z < -18) {
+ -label
+ }
+ else {
+ -label/(math.exp(z) + 1)
+ }
+ }
+}
+
+/** Hinge loss function which can be used with the [[GenericLossFunction]]
+ *
+ * The [[HingeLoss]] function implements `max(0, 1 - prediction*label)`
+ * for binary classification with label in {-1, 1}
+ */
+object HingeLoss extends PartialLossFunction {
+ /** Calculates the loss for a given prediction/truth pair
+ *
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The loss
+ */
+ override def loss(prediction: Double, label: Double): Double = {
+ val z = prediction * label
+ math.max(0, 1 - z)
+ }
+
+ /** Calculates the derivative of the loss function with respect to the prediction
+ *
+ * @param prediction The predicted value
+ * @param label The true value
+ * @return The derivative of the loss function
+ */
+ override def derivative(prediction: Double, label: Double): Double = {
+ val z = prediction * label
+ if (z <= 1) {
+ -label
+ }
+ else {
+ 0
+ }
+ }
+}
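A quick way to sanity-check the two new partial losses is to evaluate them directly on a prediction/label pair; the expected values below follow from the formulas above (rounded to four decimals):

{% highlight scala %}
import org.apache.flink.ml.optimization.{HingeLoss, LogisticLoss}

// logistic loss: log(1 + exp(-prediction * label)), label in {-1, 1}
println(LogisticLoss.loss(3.0, 1.0))        // log(1 + exp(-3)) = ~0.0486
println(LogisticLoss.derivative(3.0, 1.0))  // -1 / (exp(3) + 1) = ~-0.0474

// hinge loss: max(0, 1 - prediction * label)
println(HingeLoss.loss(-1.0, 1.0))          // max(0, 1 - (-1)) = 2.0
println(HingeLoss.derivative(-1.0, 1.0))    // margin <= 1, so -label = -1.0
println(HingeLoss.derivative(2.0, 1.0))     // margin > 1, so 0.0
{% endhighlight %}

Either object can then be plugged into `GenericLossFunction` together with `LinearPrediction`, exactly as the new `LossFunctionTest` does.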
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
new file mode 100644
index 0000000..1ed59ab
--- /dev/null
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/RegularizationPenalty.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.{Vector, BLAS}
+import org.apache.flink.ml.math.Breeze._
+import breeze.linalg.{norm => BreezeNorm}
+
+/** Represents a type of regularization penalty
+ *
+ * Regularization penalties are used to restrict the optimization problem to solutions with
+ * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large
+ * weights for the L2 penalty.
+ *
+ * The regularization term `R(w)` is added to the objective function `f(w) = L(w) + lambda*R(w)`,
+ * where lambda is the regularization parameter used to tune the amount of regularization applied.
+ */
+trait RegularizationPenalty extends Serializable {
+
+ /** Calculates the new weights based on the gradient and regularization penalty
+ *
+ * Weights are updated using the gradient descent step `w - learningRate * gradient`
+ * with `w` being the weight vector.
+ *
+ * @param weightVector The weights to be updated
+ * @param gradient The gradient used to update the weights
+ * @param regularizationConstant The regularization parameter to be applied
+ * @param learningRate The effective step size for this iteration
+ * @return Updated weights
+ */
+ def takeStep(
+ weightVector: Vector,
+ gradient: Vector,
+ regularizationConstant: Double,
+ learningRate: Double)
+ : Vector
+
+ /** Adds regularization to the loss value
+ *
+ * @param oldLoss The loss to be updated
+ * @param weightVector The weights used to update the loss
+ * @param regularizationConstant The regularization parameter to be applied
+ * @return Updated loss
+ */
+ def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double): Double
+
+}
+
+
+/** `L_2` regularization penalty.
+ *
+ * The regularization function is half the squared L2 norm, `1/2*||w||_2^2`,
+ * with `w` being the weight vector. The function penalizes large weights,
+ * favoring solutions with many small weights over a few large ones.
+ */
+object L2Regularization extends RegularizationPenalty {
+
+ /** Calculates the new weights based on the gradient and L2 regularization penalty
+ *
+ * The updated weight is `w - learningRate * (gradient + lambda * w)` where
+ * `w` is the weight vector, and `lambda` is the regularization parameter.
+ *
+ * @param weightVector The weights to be updated
+ * @param gradient The gradient according to which we will update the weights
+ * @param regularizationConstant The regularization parameter to be applied
+ * @param learningRate The effective step size for this iteration
+ * @return Updated weights
+ */
+ override def takeStep(
+ weightVector: Vector,
+ gradient: Vector,
+ regularizationConstant: Double,
+ learningRate: Double)
+ : Vector = {
+ // add the gradient of the L2 regularization
+ BLAS.axpy(regularizationConstant, weightVector, gradient)
+
+ // update the weights according to the learning rate
+ BLAS.axpy(-learningRate, gradient, weightVector)
+
+ weightVector
+ }
+
+ /** Adds regularization to the loss value
+ *
+ * The updated loss is `oldLoss + lambda * 1/2*||w||_2^2` where
+ * `w` is the weight vector, and `lambda` is the regularization parameter
+ *
+ * @param oldLoss The loss to be updated
+ * @param weightVector The weights used to update the loss
+ * @param regularizationConstant The regularization parameter to be applied
+ * @return Updated loss
+ */
+ override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double)
+ : Double = {
+ val squareNorm = BLAS.dot(weightVector, weightVector)
+ oldLoss + regularizationConstant * 0.5 * squareNorm
+ }
+}
+
+/** `L_1` regularization penalty.
+ *
+ * The regularization function is the `L1` norm `||w||_1` with `w` being the weight vector.
+ * The `L_1` penalty can be used to drive a number of the solution coefficients to 0, thereby
+ * producing sparse solutions.
+ *
+ */
+object L1Regularization extends RegularizationPenalty {
+
+ /** Calculates the new weights based on the gradient and L1 regularization penalty
+ *
+ * Uses the proximal gradient method with L1 regularization to update weights.
+ * The updated weight `w - learningRate * gradient` is shrunk towards zero
+ * by applying the proximal operator `signum(w) * max(0.0, abs(w) - shrinkageVal)`
+ * where `w` is the weight vector, `lambda` is the regularization parameter,
+ * and `shrinkageVal` is `lambda*learningRate`.
+ *
+ * @param weightVector The weights to be updated
+ * @param gradient The gradient according to which we will update the weights
+ * @param regularizationConstant The regularization parameter to be applied
+ * @param learningRate The effective step size for this iteration
+ * @return Updated weights
+ */
+ override def takeStep(
+ weightVector: Vector,
+ gradient: Vector,
+ regularizationConstant: Double,
+ learningRate: Double)
+ : Vector = {
+ // Update weight vector with gradient.
+ BLAS.axpy(-learningRate, gradient, weightVector)
+
+ // Apply proximal operator (soft thresholding)
+ val shrinkageVal = regularizationConstant * learningRate
+ var i = 0
+ while (i < weightVector.size) {
+ val wi = weightVector(i)
+ weightVector(i) = math.signum(wi) *
+ math.max(0.0, math.abs(wi) - shrinkageVal)
+ i += 1
+ }
+
+ weightVector
+ }
+
+ /** Adds regularization to the loss value
+ *
+ * The updated loss is `oldLoss + lambda * ||w||_1` where
+ * `w` is the weight vector and `lambda` is the regularization parameter
+ *
+ * @param oldLoss The loss to be updated
+ * @param weightVector The weights used to update the loss
+ * @param regularizationConstant The regularization parameter to be applied
+ * @return Updated loss
+ */
+ override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double)
+ : Double = {
+ val norm = BreezeNorm(weightVector.asBreeze, 1.0)
+ oldLoss + norm * regularizationConstant
+ }
+}
+
+/** No regularization penalty.
+ *
+ */
+object NoRegularization extends RegularizationPenalty {
+
+ /** Calculates the new weights based on the gradient
+ *
+ * The updated weight is `w - learningRate * gradient` where `w` is the weight vector
+ *
+ * @param weightVector The weights to be updated
+ * @param gradient The gradient according to which we will update the weights
+ * @param regularizationConstant The regularization parameter which is ignored
+ * @param learningRate The effective step size for this iteration
+ * @return Updated weights
+ */
+ override def takeStep(
+ weightVector: Vector,
+ gradient: Vector,
+ regularizationConstant: Double,
+ learningRate: Double)
+ : Vector = {
+ // Update the weight vector
+ BLAS.axpy(-learningRate, gradient, weightVector)
+ weightVector
+ }
+
+ /**
+ * Returns the unmodified loss value
+ *
+ * The updated loss is `oldLoss`
+ *
+ * @param oldLoss The loss to be updated
+ * @param weightVector The weights used to update the loss
+ * @param regularizationParameter The regularization parameter which is ignored
+ * @return Updated loss
+ */
+ override def regLoss(oldLoss: Double, weightVector: Vector, regularizationParameter: Double)
+ : Double = {
+ oldLoss
+ }
+}
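Because `RegularizationPenalty` is a plain trait, users can plug in their own penalties. The following elastic-net-style object is purely a hypothetical sketch (not part of this commit) that reuses the two shipped penalties:

{% highlight scala %}
import org.apache.flink.ml.math.{DenseVector, Vector}
import org.apache.flink.ml.optimization.{L1Regularization, L2Regularization, RegularizationPenalty}

/** Hypothetical sketch: splits the regularization constant evenly between
  * an L2 gradient term and an L1 proximal (soft-thresholding) step. */
object ElasticNetSketch extends RegularizationPenalty {

  override def takeStep(
      weightVector: Vector,
      gradient: Vector,
      regularizationConstant: Double,
      learningRate: Double)
    : Vector = {
    // L2 half: adds (lambda / 2) * w to the gradient and takes the step
    val stepped = L2Regularization.takeStep(
      weightVector, gradient, regularizationConstant / 2, learningRate)
    // L1 half: a zero gradient leaves only the shrinkage of the proximal step
    val zeroGradient = DenseVector(Array.fill(stepped.size)(0.0))
    L1Regularization.takeStep(stepped, zeroGradient, regularizationConstant / 2, learningRate)
  }

  override def regLoss(oldLoss: Double, weightVector: Vector, regularizationConstant: Double)
    : Double = {
    // add both penalty terms, each weighted by lambda / 2
    val withL2 = L2Regularization.regLoss(oldLoss, weightVector, regularizationConstant / 2)
    L1Regularization.regLoss(withL2, weightVector, regularizationConstant / 2)
  }
}
{% endhighlight %}

Such an object can be handed to the solver via `setRegularizationPenalty(ElasticNetSketch)` like any built-in penalty.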
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
index ee91bd1..761620d 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
@@ -95,6 +95,11 @@ abstract class Solver extends Serializable with WithParameters {
parameters.add(RegularizationConstant, regularizationConstant)
this
}
+
+ def setRegularizationPenalty(regularizationPenalty: RegularizationPenalty) : this.type = {
+ parameters.add(RegularizationPenaltyValue, regularizationPenalty)
+ this
+ }
}
object Solver {
@@ -108,6 +113,10 @@ object Solver {
case object RegularizationConstant extends Parameter[Double] {
val defaultValue = Some(0.0001) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
}
+
+ case object RegularizationPenaltyValue extends Parameter[RegularizationPenalty] {
+ val defaultValue = Some(NoRegularization)
+ }
}
/** An abstract class for iterative optimization algorithms
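With the new `RegularizationPenaltyValue` parameter the penalty becomes a plain configuration value; left unset, the solver falls back to `NoRegularization`:

{% highlight scala %}
// default: NoRegularization (see RegularizationPenaltyValue above)
val plainSgd = GradientDescent()

// explicit selection, chainable with the other solver parameters
val l2Sgd = GradientDescent()
  .setRegularizationPenalty(L2Regularization)
  .setRegularizationConstant(0.01)
{% endhighlight %}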
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index ef06033..773dd04 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -190,7 +190,7 @@ object MultipleLinearRegression {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val optimizer = SimpleGradientDescent()
+ val optimizer = GradientDescent()
.setIterations(numberOfIterations)
.setStepsize(stepsize)
.setLossFunction(lossFunction)
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
index 728ef57..da3fac9 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
@@ -40,10 +40,11 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = GradientDescentL1()
+ val sgd = GradientDescent()
.setStepsize(0.01)
.setIterations(2000)
.setLossFunction(lossFunction)
+ .setRegularizationPenalty(L1Regularization)
.setRegularizationConstant(0.3)
val inputDS: DataSet[LabeledVector] = env.fromCollection(regularizationData)
@@ -72,10 +73,11 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = GradientDescentL2()
+ val sgd = GradientDescent()
.setStepsize(0.1)
.setIterations(1)
.setLossFunction(lossFunction)
+ .setRegularizationPenalty(L2Regularization)
.setRegularizationConstant(1.0)
val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
@@ -101,7 +103,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = SimpleGradientDescent()
+ val sgd = GradientDescent()
.setStepsize(1.0)
.setIterations(800)
.setLossFunction(lossFunction)
@@ -132,7 +134,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = SimpleGradientDescent()
+ val sgd = GradientDescent()
.setStepsize(0.0001)
.setIterations(100)
.setLossFunction(lossFunction)
@@ -163,7 +165,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = SimpleGradientDescent()
+ val sgd = GradientDescent()
.setStepsize(0.1)
.setIterations(1)
.setLossFunction(lossFunction)
@@ -199,7 +201,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgdEarlyTerminate = SimpleGradientDescent()
+ val sgdEarlyTerminate = GradientDescent()
.setConvergenceThreshold(1e2)
.setStepsize(1.0)
.setIterations(800)
@@ -217,7 +219,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val weightsEarly = weightVectorEarly.weights.asInstanceOf[DenseVector].data
val weight0Early = weightVectorEarly.intercept
- val sgdNoConvergence = SimpleGradientDescent()
+ val sgdNoConvergence = GradientDescent()
.setStepsize(1.0)
.setIterations(800)
.setLossFunction(lossFunction)
@@ -247,7 +249,7 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
- val sgd = SimpleGradientDescent()
+ val sgd = GradientDescent()
.setStepsize(1.0)
.setIterations(800)
.setLossFunction(lossFunction)
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
deleted file mode 100644
index 3d538cf..0000000
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.optimization
-
-import org.apache.flink.ml.common.{LabeledVector, WeightVector}
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.ml.util.FlinkTestBase
-import org.scalatest.{Matchers, FlatSpec}
-
-import org.apache.flink.api.scala._
-
-
-class LossFunctionITSuite extends FlatSpec with Matchers with FlinkTestBase {
-
- behavior of "The optimization Loss Function implementations"
-
- it should "calculate squared loss and gradient correctly" in {
- val env = ExecutionEnvironment.getExecutionEnvironment
-
- env.setParallelism(2)
-
- val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
-
-
- val example = LabeledVector(1.0, DenseVector(2))
- val weightVector = new WeightVector(DenseVector(1.0), 1.0)
-
- val gradient = lossFunction.gradient(example, weightVector)
- val loss = lossFunction.loss(example, weightVector)
-
- loss should be (2.0 +- 0.001)
-
- gradient.weights(0) should be (4.0 +- 0.001)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
new file mode 100644
index 0000000..05a159c
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionTest.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{LabeledVector, WeightVector}
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.util.FlinkTestBase
+import org.scalatest.{Matchers, FlatSpec}
+
+
+class LossFunctionTest extends FlatSpec with Matchers {
+
+ behavior of "The optimization Loss Function implementations"
+
+ it should "calculate squared loss and gradient correctly" in {
+
+ val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+ val example = LabeledVector(1.0, DenseVector(2))
+ val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+
+ val gradient = lossFunction.gradient(example, weightVector)
+ val loss = lossFunction.loss(example, weightVector)
+
+ loss should be (2.0 +- 0.001)
+
+ gradient.weights(0) should be (4.0 +- 0.001)
+ }
+
+ it should "calculate logistic loss and gradient correctly" in {
+
+ val lossFunction = GenericLossFunction(LogisticLoss, LinearPrediction)
+
+ val examples = List(
+ LabeledVector(1.0, DenseVector(2)),
+ LabeledVector(1.0, DenseVector(20)),
+ LabeledVector(1.0, DenseVector(-25))
+ )
+
+ val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+ val expectedLosses = List(0.049, 7.58e-10, 24.0)
+ val expectedGradients = List(-0.095, -1.52e-8, 25.0)
+
+ expectedLosses zip examples foreach {
+ case (expectedLoss, example) => {
+ val loss = lossFunction.loss(example, weightVector)
+ loss should be (expectedLoss +- 0.001)
+ }
+ }
+
+ expectedGradients zip examples foreach {
+ case (expectedGradient, example) => {
+ val gradient = lossFunction.gradient(example, weightVector)
+ gradient.weights(0) should be (expectedGradient +- 0.001)
+ }
+ }
+ }
+
+ it should "calculate hinge loss and gradient correctly" in {
+
+ val lossFunction = GenericLossFunction(HingeLoss, LinearPrediction)
+
+ val examples = List(
+ LabeledVector(1.0, DenseVector(2)),
+ LabeledVector(1.0, DenseVector(-2))
+ )
+
+ val weightVector = new WeightVector(DenseVector(1.0), 1.0)
+ val expectedLosses = List(0.0, 2.0)
+ val expectedGradients = List(0.0, 2.0)
+
+ expectedLosses zip examples foreach {
+ case (expectedLoss, example) => {
+ val loss = lossFunction.loss(example, weightVector)
+ loss should be (expectedLoss +- 0.001)
+ }
+ }
+
+ expectedGradients zip examples foreach {
+ case (expectedGradient, example) => {
+ val gradient = lossFunction.gradient(example, weightVector)
+ gradient.weights(0) should be (expectedGradient +- 0.001)
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d3a07ef6/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
new file mode 100644
index 0000000..fe475f0
--- /dev/null
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationPenaltyTest.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.math.DenseVector
+import org.scalatest.{FlatSpec, Matchers}
+
+
+class RegularizationPenaltyTest extends FlatSpec with Matchers {
+
+ behavior of "The Regularization Penalty Function implementations"
+
+ it should "correctly update weights and loss with L2 regularization penalty" in {
+ val loss = 3.4
+ val weights = DenseVector(0.8)
+ val gradient = DenseVector(2.0)
+
+ val updatedWeights = L2Regularization.takeStep(weights, gradient, 0.3, 0.01)
+ val updatedLoss = L2Regularization.regLoss(loss, updatedWeights, 0.3)
+
+ updatedWeights(0) should be (0.7776 +- 0.001)
+ updatedLoss should be (3.4907 +- 0.001)
+ }
+
+ it should "correctly update weights and loss with L1 regularization penalty" in {
+ val loss = 3.4
+ val weights = DenseVector(0.8)
+ val gradient = DenseVector(2.0)
+
+ val updatedWeights = L1Regularization.takeStep(weights, gradient, 0.3, 0.01)
+ val updatedLoss = L1Regularization.regLoss(loss, updatedWeights, 0.3)
+
+ updatedWeights(0) should be (0.777 +- 0.001)
+ updatedLoss should be (3.6331 +- 0.001)
+ }
+
+ it should "correctly update weights and loss with no regularization penalty" in {
+ val loss = 3.4
+ val weights = DenseVector(0.8)
+ val gradient = DenseVector(2.0)
+
+ val updatedWeights = NoRegularization.takeStep(weights, gradient, 0.3, 0.01)
+ val updatedLoss = NoRegularization.regLoss(loss, updatedWeights, 0.3)
+
+ updatedWeights(0) should be (0.78 +- 0.001)
+ updatedLoss should be (3.4 +- 0.001)
+ }
+}