Posted to commits@flink.apache.org by tr...@apache.org on 2015/04/02 11:10:38 UTC

flink git commit: [FLINK-1716] [ml] Adds CoCoA algorithm

Repository: flink
Updated Branches:
  refs/heads/master 57e9ae0bd -> 00759257a


[FLINK-1716] [ml] Adds CoCoA algorithm

[ml] Adds web documentation and code comments to CoCoA

[ml] Adds comments

This closes #545.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/00759257
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/00759257
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/00759257

Branch: refs/heads/master
Commit: 00759257a2d69608040ba4f074d7dc9c45dbb964
Parents: 57e9ae0
Author: Till Rohrmann <tr...@apache.org>
Authored: Thu Mar 12 16:52:45 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Thu Apr 2 11:10:16 2015 +0200

----------------------------------------------------------------------
 docs/ml/cocoa.md                                | 164 +++++++
 .../apache/flink/ml/classification/CoCoA.scala  | 492 +++++++++++++++++++
 .../org/apache/flink/ml/common/Block.scala      |  29 ++
 .../apache/flink/ml/common/ChainedLearner.scala |   6 +-
 .../flink/ml/common/ChainedTransformer.scala    |   4 +-
 .../org/apache/flink/ml/common/FlinkTools.scala |  67 ++-
 .../apache/flink/ml/common/LabeledVector.scala  |   2 +-
 .../org/apache/flink/ml/common/Learner.scala    |   2 +-
 .../apache/flink/ml/common/Transformer.scala    |   3 +-
 .../org/apache/flink/ml/math/DenseMatrix.scala  |   8 +-
 .../org/apache/flink/ml/math/DenseVector.scala  |   6 +-
 .../scala/org/apache/flink/ml/math/Matrix.scala |  12 +-
 .../org/apache/flink/ml/math/SparseMatrix.scala |  17 +
 .../org/apache/flink/ml/math/SparseVector.scala |  15 +
 .../scala/org/apache/flink/ml/math/Vector.scala |  15 +-
 .../apache/flink/ml/recommendation/ALS.scala    |  15 +-
 .../regression/MultipleLinearRegression.scala   |  14 +-
 .../ml/classification/Classification.scala      | 133 +++++
 .../flink/ml/classification/CoCoASuite.scala    |  52 ++
 .../flink/ml/math/SparseMatrixSuite.scala       |   4 +-
 .../flink/ml/math/SparseVectorSuite.scala       |   4 +-
 .../flink/ml/recommendation/ALSITSuite.scala    |  74 +--
 .../ml/recommendation/Recommendation.scala      |  90 ++++
 23 files changed, 1114 insertions(+), 114 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/docs/ml/cocoa.md
----------------------------------------------------------------------
diff --git a/docs/ml/cocoa.md b/docs/ml/cocoa.md
new file mode 100644
index 0000000..0bf8d67
--- /dev/null
+++ b/docs/ml/cocoa.md
@@ -0,0 +1,164 @@
+---
+mathjax: include
+title: Communication-efficient distributed dual coordinate ascent (CoCoA)
+---
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+* This will be replaced by the TOC
+{:toc}
+
+## Description
+
+Implements the communication-efficient distributed dual coordinate ascent algorithm with a hinge loss function. 
+The algorithm can be used to train an SVM with soft-margin.
+The algorithm solves the following minimization problem:
+  
+$$\min_{\mathbf{w} \in \mathbb{R}^d} \frac{\lambda}{2} \left\lVert \mathbf{w} \right\rVert^2 + \frac{1}{n} \sum_{i=1}^n l_{i}\left(\mathbf{w}^T\mathbf{x}_i\right)$$
+ 
+with $\mathbf{w}$ being the weight vector, $\lambda$ being the regularization constant, 
+$\mathbf{x}_i \in \mathbb{R}^d$ being the data points and $l_{i}$ being the convex loss 
+functions, which can also depend on the labels $y_{i} \in \mathbb{R}$.
+In the current implementation, the regularizer is the $\ell_2$-norm and the loss functions are the hinge loss functions:
+  
+  $$l_{i} = \max\left(0, 1 - y_{i} \mathbf{w}^T\mathbf{x}_i \right)$$
+
+With these choices, the minimization problem is equivalent to that of an SVM with soft-margin.
+The algorithm can therefore be used to train an SVM with soft-margin.
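+
+For illustration, the hinge loss of a single data point can be written as a plain Scala
+function (a minimal sketch for clarity; this helper is not part of the Flink API):
+
+{% highlight scala %}
+// Hinge loss of a data point with label y and prediction value w^T x
+def hingeLoss(y: Double, wDotX: Double): Double = math.max(0.0, 1.0 - y * wDotX)
+{% endhighlight %}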
+
+The minimization problem is solved by applying stochastic dual coordinate ascent (SDCA).
+In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm calculates 
+several iterations of SDCA locally on a data block before merging the local updates into a
+valid global state.
+This state is redistributed to the different data partitions where the next round of local SDCA 
+iterations is then executed.
+The number of outer iterations and local SDCA iterations control the overall network costs, because 
+network communication is only required once per outer iteration.
+The local SDCA iterations are embarrassingly parallel once the individual data partitions have been 
+distributed across the cluster.
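+
+Because the method is a *dual* coordinate ascent, it is worth noting which dual problem is
+being maximized. For the hinge loss, it is the standard soft-margin SVM dual (stated here for
+orientation):
+
+$$\max_{\boldsymbol{\alpha} \in [0,1]^n} \frac{1}{n} \sum_{i=1}^n \alpha_i - \frac{1}{2 \lambda n^2} \left\lVert \sum_{i=1}^n \alpha_i y_i \mathbf{x}_i \right\rVert^2$$
+
+where the primal weight vector is recovered as $\mathbf{w} = \frac{1}{\lambda n} \sum_{i=1}^n \alpha_i y_i \mathbf{x}_i$.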
+
+The implementation of this algorithm is based on the work of 
+[Jaggi et al.](http://arxiv.org/abs/1409.1458).
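+
+Schematically, one run of the algorithm can be sketched in plain Scala as follows. This is a
+simplified, local sketch for illustration and not the actual distributed Flink implementation;
+the function `localSDCA` stands for the local solver that is executed on each data block.
+
+{% highlight scala %}
+// Sketch of CoCoA's outer loop: run SDCA locally per block, then merge the scaled updates
+def cocoaSketch(
+    dataBlocks: Seq[Seq[(Double, Array[Double])]], // blocks of (label, features) pairs
+    dimension: Int,
+    iterations: Int,
+    localIterations: Int,
+    stepsize: Double,
+    localSDCA: (Array[Double], Seq[(Double, Array[Double])], Int) => Array[Double])
+  : Array[Double] = {
+  var w = Array.fill(dimension)(0.0)
+  for (_ <- 1 to iterations) {
+    // run localIterations SDCA steps independently on every data block
+    val deltaWs = dataBlocks.map(block => localSDCA(w, block, localIterations))
+    // merge the local updates, scaled by stepsize/blocks, into the next weight vector
+    val scaling = stepsize / dataBlocks.size
+    w = w.indices.map(i => w(i) + scaling * deltaWs.map(_(i)).sum).toArray
+  }
+  w
+}
+{% endhighlight %}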
+
+## Parameters
+
+The CoCoA implementation can be controlled by the following parameters:
+
+   <table class="table table-bordered">
+    <thead>
+      <tr>
+        <th class="text-left" style="width: 20%">Parameters</th>
+        <th class="text-center">Description</th>
+      </tr>
+    </thead>
+
+    <tbody>
+      <tr>
+        <td><strong>Blocks</strong></td>
+        <td>
+          <p>
+            Sets the number of blocks into which the input data will be split. 
+            On each block the local stochastic dual coordinate ascent method is executed. 
+            This number should be set to at least the degree of parallelism. 
+            If no value is specified, then the parallelism of the input DataSet is used as the number of blocks. 
+            (Default value: <strong>None</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Iterations</strong></td>
+        <td>
+          <p>
+            Defines the maximum number of iterations of the outer loop method. 
+            In other words, it defines how often the SDCA method is applied to the blocked data. 
+            After each iteration, the locally computed weight vector updates have to be reduced to update the global weight vector value.
+            The new weight vector is broadcast to all SDCA tasks at the beginning of each iteration.
+            (Default value: <strong>10</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>LocalIterations</strong></td>
+        <td>
+          <p>
+            Defines the maximum number of SDCA iterations. 
+            In other words, it defines how many data points are drawn from each local data block to calculate the stochastic dual coordinate ascent.
+            (Default value: <strong>10</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Regularization</strong></td>
+        <td>
+          <p>
+            Defines the regularization constant of the CoCoA algorithm. 
+            The higher the value, the smaller the 2-norm of the weight vector will be. 
+            For an SVM with hinge loss this means that the SVM margin will be wider even though it might contain some misclassified points.
+            (Default value: <strong>1.0</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Stepsize</strong></td>
+        <td>
+          <p>
+            Defines the initial step size for the updates of the weight vector. 
+            The larger the step size, the larger the contribution of the weight vector updates to the next weight vector value will be. 
+            The effective scaling of the updates is $\frac{stepsize}{blocks}$.
+            This value has to be tuned in case the algorithm becomes unstable. 
+            (Default value: <strong>1.0</strong>)
+          </p>
+        </td>
+      </tr>
+      <tr>
+        <td><strong>Seed</strong></td>
+        <td>
+          <p>
+            Defines the seed to initialize the random number generator. 
+            The seed directly controls which data points are chosen for the SDCA method. 
+            (Default value: <strong>0</strong>)
+          </p>
+        </td>
+      </tr>
+    </tbody>
+  </table>
+
+## Examples
+
+{% highlight scala %}
+// Read the training data set
+val trainingDS: DataSet[LabeledVector] = env.readSVMFile(pathToTrainingFile)
+
+// Create the CoCoA learner
+val cocoa = CoCoA()
+  .setBlocks(10)
+  .setIterations(10)
+  .setLocalIterations(10)
+  .setRegularization(0.5)
+  .setStepsize(0.5)
+
+// Learn the SVM model
+val svm = cocoa.fit(trainingDS)
+
+// Read the testing data set
+val testingDS: DataSet[Vector] = env.readVectorFile(pathToTestingFile)
+
+// Calculate the predictions for the testing data set
+val predictionDS: DataSet[LabeledVector] = svm.transform(testingDS)
+{% endhighlight %}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
new file mode 100644
index 0000000..e1c2053
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
@@ -0,0 +1,492 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.api.scala._
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.ml.common.FlinkTools.ModuloKeyPartitioner
+import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.math.Breeze._
+
+import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
+
+/** Implements the communication-efficient distributed dual coordinate ascent algorithm with
+  * a hinge loss function. The algorithm can be used to train an SVM with soft-margin.
+  *
+  * The algorithm solves the following minimization problem:
+  *
+  * `min_{w in bbb"R"^d} lambda/2 ||w||^2 + 1/n sum_(i=1)^n l_{i}(w^Tx_i)`
+  *
+  * with `w` being the weight vector, `lambda` being the regularization constant,
+  * `x_{i} in bbb"R"^d` being the data points and `l_{i}` being the convex loss functions, which
+  * can also depend on the labels `y_{i} in bbb"R"`.
+  * In the current implementation the regularizer is the 2-norm and the loss functions are the
+  * hinge-loss functions:
+  *
+  * `l_{i} = max(0, 1 - y_{i} * w^Tx_i)`
+  *
+  * With these choices, the minimization problem is equivalent to that of an SVM with
+  * soft-margin. The algorithm can therefore be used to train an SVM with soft-margin.
+  *
+  * The minimization problem is solved by applying stochastic dual coordinate ascent (SDCA).
+  * In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm
+  * calculates several iterations of SDCA locally on a data block before merging the local
+  * updates into a valid global state.
+  * This state is redistributed to the different data partitions where the next round of local
+  * SDCA iterations is then executed.
+  * The number of outer iterations and local SDCA iterations control the overall network
+  * costs, because network communication is only required once per outer iteration.
+  * The local SDCA iterations are embarrassingly parallel once the individual data partitions have
+  * been distributed across the cluster.
+  *
+  * Further details of the algorithm can be found [[http://arxiv.org/abs/1409.1458 here]].
+  *
+  * @example
+  *          {{{
+  *             val trainingDS: DataSet[LabeledVector] = env.readSVMFile(pathToTrainingFile)
+  *
+  *             val cocoa = CoCoA()
+  *               .setBlocks(10)
+  *               .setIterations(10)
+  *               .setLocalIterations(10)
+  *               .setRegularization(0.5)
+  *               .setStepsize(0.5)
+  *
+  *             val svm = cocoa.fit(trainingDS)
+  *
+  *             val testingDS: DataSet[Vector] = env.readVectorFile(pathToTestingFile)
+  *
+  *             val predictionDS: DataSet[LabeledVector] = svm.transform(testingDS)
+  *          }}}
+  *
+  * =Parameters=
+  *
+  *  - [[CoCoA.Blocks]]:
+  *  Sets the number of blocks into which the input data will be split. On each block the local
+  *  stochastic dual coordinate ascent method is executed. This number should be set to at
+  *  least the degree of parallelism. If no value is specified, then the parallelism of the input
+  *  [[DataSet]] is used as the number of blocks. (Default value: '''None''')
+  *
+  *  - [[CoCoA.Iterations]]:
+  *  Defines the maximum number of iterations of the outer loop method. In other words, it defines
+  *  how often the SDCA method is applied to the blocked data. After each iteration, the locally
+  *  computed weight vector updates have to be reduced to update the global weight vector value.
+  *  The new weight vector is broadcast to all SDCA tasks at the beginning of each iteration.
+  *  (Default value: '''10''')
+  *
+  *  - [[CoCoA.LocalIterations]]:
+  *  Defines the maximum number of SDCA iterations. In other words, it defines how many data points
+  *  are drawn from each local data block to calculate the stochastic dual coordinate ascent.
+  *  (Default value: '''10''')
+  *
+  *  - [[CoCoA.Regularization]]:
+  *  Defines the regularization constant of the CoCoA algorithm. The higher the value, the
+  *  smaller the 2-norm of the weight vector will be. For an SVM with hinge loss this means
+  *  that the SVM margin will be wider even though it might contain some misclassified points.
+  *  (Default value: '''1.0''')
+  *
+  *  - [[CoCoA.Stepsize]]:
+  *  Defines the initial step size for the updates of the weight vector. The larger the step
+  *  size, the larger the contribution of the weight vector updates to the next weight vector
+  *  value will be. The effective scaling of the updates is `stepsize/blocks`. This value has
+  *  to be tuned in case the algorithm becomes unstable. (Default value: '''1.0''')
+  *
+  *  - [[CoCoA.Seed]]:
+  *  Defines the seed to initialize the random number generator. The seed directly controls which
+  *  data points are chosen for the SDCA method. (Default value: '''0''')
+  */
+class CoCoA extends Learner[LabeledVector, CoCoAModel] with Serializable {
+
+  import CoCoA._
+
+  /** Sets the number of data blocks/partitions
+    *
+    * @param blocks Number of data blocks/partitions
+    * @return itself
+    */
+  def setBlocks(blocks: Int): CoCoA = {
+    parameters.add(Blocks, blocks)
+    this
+  }
+
+  /** Sets the number of outer iterations
+    *
+    * @param iterations Number of outer iterations
+    * @return itself
+    */
+  def setIterations(iterations: Int): CoCoA = {
+    parameters.add(Iterations, iterations)
+    this
+  }
+
+  /** Sets the number of local SDCA iterations
+    *
+    * @param localIterations Number of local SDCA iterations
+    * @return itself
+    */
+  def setLocalIterations(localIterations: Int): CoCoA = {
+    parameters.add(LocalIterations, localIterations)
+    this
+  }
+
+  /** Sets the regularization constant
+    *
+    * @param regularization Regularization constant
+    * @return itself
+    */
+  def setRegularization(regularization: Double): CoCoA = {
+    parameters.add(Regularization, regularization)
+    this
+  }
+
+  /** Sets the stepsize for the weight vector updates
+    *
+    * @param stepsize Initial step size for the weight vector updates
+    * @return itself
+    */
+  def setStepsize(stepsize: Double): CoCoA = {
+    parameters.add(Stepsize, stepsize)
+    this
+  }
+
+  /** Sets the seed value for the random number generator
+    *
+    * @param seed Seed value for the random number generator
+    * @return itself
+    */
+  def setSeed(seed: Long): CoCoA = {
+    parameters.add(Seed, seed)
+    this
+  }
+
+  /** Trains a SVM with soft-margin based on the given training data set.
+    *
+    * @param input Training data set
+    * @param fitParameters Parameter values
+    * @return Trained SVM model
+    */
+  override def fit(input: DataSet[LabeledVector], fitParameters: ParameterMap): CoCoAModel = {
+    val resultingParameters = this.parameters ++ fitParameters
+
+    // Check if the number of blocks/partitions has been specified
+    val blocks = resultingParameters.get(Blocks) match {
+      case Some(value) => value
+      case None => input.getParallelism
+    }
+
+    val scaling = resultingParameters(Stepsize)/blocks
+    val iterations = resultingParameters(Iterations)
+    val localIterations = resultingParameters(LocalIterations)
+    val regularization = resultingParameters(Regularization)
+    val seed = resultingParameters(Seed)
+
+    // Obtain DataSet with the dimension of the data points
+    val dimension = input.map{_.vector.size}.reduce{
+      (a, b) => {
+        require(a == b, "Dimensions of feature vectors have to be equal.")
+        a
+      }
+    }
+
+    val initialWeights = createInitialWeights(dimension)
+
+    // Count the number of vectors, but keep the value in a DataSet to broadcast it later
+    // TODO: Once efficient count and intermediate result partitions are implemented, use count
+    val numberVectors = input map { x => 1 } reduce { _ + _ }
+
+    // Group the input data into blocks in round robin fashion
+    val blockedInputNumberElements = FlinkTools.block(input, blocks, Some(ModuloKeyPartitioner)).
+      cross(numberVectors).
+      map { x => x }
+
+    val resultingWeights = initialWeights.iterate(iterations) {
+      weights => {
+        // compute the local SDCA to obtain the weight vector updates
+        val deltaWs = localDualMethod(
+          weights,
+          blockedInputNumberElements,
+          localIterations,
+          regularization,
+          scaling,
+          seed
+        )
+
+        // scale the weight vectors
+        val weightedDeltaWs = deltaWs map {
+          deltaW => {
+            deltaW :*= scaling
+          }
+        }
+
+        // calculate the new weight vector by adding the weight vector updates to the weight vector
+        // value
+        weights.union(weightedDeltaWs).reduce { _ + _ }
+      }
+    }
+
+    CoCoAModel(resultingWeights)
+  }
+
+  /** Creates a zero vector of length dimension
+    *
+    * @param dimension [[DataSet]] containing the dimension of the initial weight vector
+    * @return Zero vector of length dimension
+    */
+  private def createInitialWeights(dimension: DataSet[Int]): DataSet[BreezeDenseVector[Double]] = {
+    dimension.map {
+      d => BreezeDenseVector.zeros[Double](d)
+    }
+  }
+
+  /** Computes the local SDCA on the individual data blocks/partitions
+    *
+    * @param w Current weight vector
+    * @param blockedInputNumberElements Blocked/Partitioned input data
+    * @param localIterations Number of local SDCA iterations
+    * @param regularization Regularization constant
+    * @param scaling Scaling value for new weight vector updates
+    * @param seed Random number generator seed
+    * @return [[DataSet]] of weight vector updates. The weight vector updates are double arrays
+    */
+  private def localDualMethod(
+      w: DataSet[BreezeDenseVector[Double]],
+      blockedInputNumberElements: DataSet[(Block[LabeledVector], Int)],
+      localIterations: Int,
+      regularization: Double,
+      scaling: Double,
+      seed: Long)
+    : DataSet[BreezeDenseVector[Double]] = {
+    /*
+    Rich mapper calculating for each data block the local SDCA. We use a RichMapFunction here,
+    because we broadcast the current value of the weight vector to all mappers.
+     */
+    val localSDCA = new RichMapFunction[(Block[LabeledVector], Int), BreezeDenseVector[Double]] {
+      var originalW: BreezeDenseVector[Double] = _
+      // we keep the alphas across the outer loop iterations
+      val alphasArray = ArrayBuffer[BreezeDenseVector[Double]]()
+      // there might be several data blocks in one Flink partition, therefore store mapping
+      val idMapping = scala.collection.mutable.HashMap[Int, Int]()
+      var counter = 0
+
+      var r: Random = _
+
+      override def open(parameters: Configuration): Unit = {
+        originalW = getRuntimeContext.getBroadcastVariable(WEIGHT_VECTOR).get(0)
+
+        if(r == null){
+          r = new Random(seed ^ getRuntimeContext.getIndexOfThisSubtask)
+        }
+      }
+
+      override def map(blockNumberElements: (Block[LabeledVector], Int))
+        : BreezeDenseVector[Double] = {
+        val (block, numberElements) = blockNumberElements
+
+        // check if we already processed a data block with the corresponding block index
+        val localIndex = idMapping.get(block.index) match {
+          case Some(idx) => idx
+          case None =>
+            idMapping += (block.index -> counter)
+            counter += 1
+
+            alphasArray += BreezeDenseVector.zeros[Double](block.values.length)
+
+            counter - 1
+        }
+
+        // create temporary alpha array for the local SDCA iterations
+        val tempAlphas = alphasArray(localIndex).copy
+
+        val numLocalDatapoints = tempAlphas.length
+        val deltaAlphas = BreezeDenseVector.zeros[Double](numLocalDatapoints)
+
+        val w = originalW.copy
+
+        val deltaW = BreezeDenseVector.zeros[Double](originalW.length)
+
+        for(i <- 1 to localIterations) {
+          // pick random data point for SDCA
+          val idx = r.nextInt(numLocalDatapoints)
+
+          val LabeledVector(label, vector) = block.values(idx)
+          val alpha = tempAlphas(idx)
+
+          // maximize the dual problem and retrieve alpha and weight vector updates
+          val (deltaAlpha, deltaWUpdate) = maximize(
+            vector.asBreeze,
+            label,
+            regularization,
+            alpha,
+            w,
+            numberElements)
+
+          // update alpha values
+          tempAlphas(idx) += deltaAlpha
+          deltaAlphas(idx) += deltaAlpha
+
+          // deltaWUpdate is already scaled with 1/lambda/n
+          w += deltaWUpdate
+          deltaW += deltaWUpdate
+        }
+
+        // update local alpha values
+        alphasArray(localIndex) += deltaAlphas * scaling
+
+        deltaW
+      }
+    }
+
+    blockedInputNumberElements.map(localSDCA).withBroadcastSet(w, WEIGHT_VECTOR)
+  }
+
+  /** Maximizes the dual problem using hinge loss functions. It returns the alpha and weight
+    * vector updates.
+    *
+    * @param x Selected data point
+    * @param y Label of selected data point
+    * @param regularization Regularization constant
+    * @param alpha Alpha value of selected data point
+    * @param w Current weight vector value
+    * @param numberElements Number of elements in the training data set
+    * @return Alpha and weight vector updates
+    */
+  private def maximize(
+    x: BreezeVector[Double],
+    y: Double,
+    regularization: Double,
+    alpha: Double,
+    w: BreezeVector[Double],
+    numberElements: Int)
+  : (Double, BreezeVector[Double]) = {
+    // compute hinge loss gradient
+    val dotProduct = x dot w
+    val grad = (y * dotProduct - 1.0) * (regularization * numberElements)
+
+    // compute projected gradient
+    val proj_grad = if(alpha <= 0.0){
+      Math.min(grad, 0)
+    } else if(alpha >= 1.0) {
+      Math.max(grad, 0)
+    } else {
+      grad
+    }
+
+    if(Math.abs(grad) != 0.0){
+      val qii = x dot x
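+      // closed-form maximizer of the dual problem along this coordinate: shift alpha by
+      // -grad/qii and clip the result to the box constraint [0, 1]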
+      val newAlpha = if(qii != 0.0){
+        Math.min(Math.max((alpha - (grad / qii)), 0.0), 1.0)
+      } else {
+        1.0
+      }
+
+      val deltaW = x * y * (newAlpha - alpha) / (regularization * numberElements)
+
+      (newAlpha - alpha, deltaW)
+    } else {
+      (0.0, BreezeVector.zeros(w.length))
+    }
+  }
+}
+
+/** Companion object of CoCoA. Contains convenience functions and the parameter type definitions
+  * of the algorithm.
+  */
+object CoCoA {
+  val WEIGHT_VECTOR = "weightVector"
+
+  case object Blocks extends Parameter[Int] {
+    val defaultValue: Option[Int] = None
+  }
+
+  case object Iterations extends Parameter[Int] {
+    val defaultValue = Some(10)
+  }
+
+  case object LocalIterations extends Parameter[Int] {
+    val defaultValue = Some(10)
+  }
+
+  case object Regularization extends Parameter[Double] {
+    val defaultValue = Some(1.0)
+  }
+
+  case object Stepsize extends Parameter[Double] {
+    val defaultValue = Some(1.0)
+  }
+
+  case object Seed extends Parameter[Long] {
+    val defaultValue = Some(0L)
+  }
+
+  def apply(): CoCoA = {
+    new CoCoA()
+  }
+}
+
+/** Resulting SVM model calculated by the CoCoA algorithm.
+  *
+  * @param weights Calculated weight vector representing the separating hyperplane of the
+  *                classification task.
+  */
+case class CoCoAModel(weights: DataSet[BreezeDenseVector[Double]])
+  extends Transformer[Vector, LabeledVector]
+  with Serializable {
+  import CoCoA.WEIGHT_VECTOR
+
+  /** Calculates the prediction value of the SVM (not the label)
+    *
+    * @param input [[DataSet]] containing the vectors for which to calculate the predictions
+    * @param parameters Parameter values for the algorithm
+    * @return [[DataSet]] containing the labeled vectors
+    */
+  override def transform(input: DataSet[Vector], parameters: ParameterMap):
+  DataSet[LabeledVector] = {
+    input.map(new PredictionMapper).withBroadcastSet(weights, WEIGHT_VECTOR)
+  }
+}
+
+/** Mapper to calculate the value of the prediction function. This is a RichMapFunction, because
+  * we broadcast the weight vector to all mappers.
+  */
+class PredictionMapper extends RichMapFunction[Vector, LabeledVector] {
+
+  import CoCoA.WEIGHT_VECTOR
+
+  var weights: BreezeDenseVector[Double] = _
+
+  @throws(classOf[Exception])
+  override def open(configuration: Configuration): Unit = {
+    // get current weights
+    weights = getRuntimeContext.
+      getBroadcastVariable[BreezeDenseVector[Double]](WEIGHT_VECTOR).get(0)
+  }
+
+  override def map(vector: Vector): LabeledVector = {
+    // calculate the prediction value (scaled distance from the separating hyperplane)
+    val dotProduct = weights dot vector.asBreeze
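+    // note: this is the raw prediction value; a binary label could be obtained by
+    // thresholding it, e.g. via math.signum(dotProduct)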
+
+    LabeledVector(dotProduct, vector)
+  }
+}
+
+

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala
new file mode 100644
index 0000000..1af77ea
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common
+
+/** Base class for blocks of elements.
+  *
+  * TODO: Replace Vector type by Array type once Flink supports generic arrays
+  *
+  * @param index Index of the block
+  * @param values Elements contained in this block
+  * @tparam T Type of the contained elements
+  */
+case class Block[T](index: Int, values: Vector[T])

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedLearner.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedLearner.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedLearner.scala
index b1a0a2f..cf1b51a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedLearner.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedLearner.scala
@@ -37,9 +37,9 @@ class ChainedLearner[IN, TEMP, OUT](val head: Transformer[IN, TEMP],
                                     val tail: Learner[TEMP, OUT])
   extends Learner[IN, OUT] {
 
-  override def fit(input: DataSet[IN], parameters: ParameterMap): OUT = {
-    val tempResult = head.transform(input, parameters)
+  override def fit(input: DataSet[IN], fitParameters: ParameterMap): OUT = {
+    val tempResult = head.transform(input, fitParameters)
 
-    tail.fit(tempResult, parameters)
+    tail.fit(tempResult, fitParameters)
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedTransformer.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedTransformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedTransformer.scala
index 3f108bf..0658876 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedTransformer.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ChainedTransformer.scala
@@ -36,8 +36,8 @@ class ChainedTransformer[IN, TEMP, OUT](val head: Transformer[IN, TEMP],
                                         val tail: Transformer[TEMP, OUT])
   extends Transformer[IN, OUT] {
 
-  override def transform(input: DataSet[IN], parameters: ParameterMap): DataSet[OUT] = {
-    val tempResult = head.transform(input, parameters)
+  override def transform(input: DataSet[IN], transformParameters: ParameterMap): DataSet[OUT] = {
+    val tempResult = head.transform(input, transformParameters)
     tail.transform(tempResult)
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
index 22bbe82..57bf98e 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
@@ -18,10 +18,11 @@
 
 package org.apache.flink.ml.common
 
+import org.apache.flink.api.common.functions.Partitioner
 import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
-import org.apache.flink.api.scala.DataSet
+import org.apache.flink.api.scala._
 import org.apache.flink.core.fs.FileSystem.WriteMode
 import org.apache.flink.core.fs.Path
 
@@ -34,6 +35,10 @@ import scala.reflect.ClassTag
   *  path and subsequently re-read from disk. This method can be used to effectively split the
   *  execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization
   *  and specifying it as a source will prevent the re-execution of it.
+  *
+  *  - block:
+  *  Takes a DataSet of elements T and groups them into n blocks.
+  *
   */
 object FlinkTools {
 
@@ -324,4 +329,64 @@ object FlinkTools {
     (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env
       .createInput(if5))
   }
+
+  /** Groups the DataSet input into numBlocks blocks.
+    * 
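+    * A hypothetical usage sketch (assuming an existing DataSet[Int] named input):
+    * {{{
+    *   val blocked: DataSet[Block[Int]] = FlinkTools.block(input, numBlocks = 4)
+    * }}}
+    *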
+    * @param input DataSet which is split into blocks
+    * @param numBlocks Number of blocks
+    * @param partitionerOption Optional partitioner to control the partitioning
+    * @tparam T Type of the input elements
+    * @return [[DataSet]] of [[Block]] objects
+    */
+  def block[T: TypeInformation: ClassTag](
+    input: DataSet[T],
+    numBlocks: Int,
+    partitionerOption: Option[Partitioner[Int]] = None)
+  : DataSet[Block[T]] = {
+    val blockIDInput = input map {
+      element =>
+        val blockID = element.hashCode() % numBlocks
+
+        val blockIDResult = if(blockID < 0){
+          blockID + numBlocks
+        } else {
+          blockID
+        }
+
+        (blockIDResult, element)
+    }
+
+    val preGroupBlockIDInput = partitionerOption match {
+      case Some(partitioner) =>
+        blockIDInput partitionCustom(partitioner, 0)
+
+      case None => blockIDInput
+    }
+
+    preGroupBlockIDInput.groupBy(0).reduceGroup {
+      iter => {
+        val array = iter.toVector
+
+        val blockID = array(0)._1
+        val elements = array.map(_._2)
+
+        Block[T](blockID, elements)
+      }
+    }.withForwardedFields("0 -> index")
+  }
+
+  /** Partitioner which distributes the elements by taking the modulo of their keys with
+    * respect to the number of partitions.
+    */
+  object ModuloKeyPartitioner extends Partitioner[Int] {
+    override def partition(key: Int, numPartitions: Int): Int = {
+      val result = key % numPartitions
+
+      if(result < 0) {
+        result + numPartitions
+      } else {
+        result
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
index 4563724..3b948c0 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
@@ -26,7 +26,7 @@ import org.apache.flink.ml.math.Vector
   * @param label Label of the data point
   * @param vector Data point
   */
-case class LabeledVector(label: Double, vector: Vector) {
+case class LabeledVector(label: Double, vector: Vector) extends Serializable {
 
   override def equals(obj: Any): Boolean = {
     obj match {

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Learner.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Learner.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Learner.scala
index c8082c7..a081f76 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Learner.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Learner.scala
@@ -34,5 +34,5 @@ import org.apache.flink.api.scala.DataSet
   * @tparam OUT Type of the trained model
   */
 trait Learner[IN, OUT] extends WithParameters {
-  def fit(input: DataSet[IN], parameters: ParameterMap = ParameterMap.Empty): OUT
+  def fit(input: DataSet[IN], fitParameters: ParameterMap = ParameterMap.Empty): OUT
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
index 02d63cf..6b8780d 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
@@ -45,5 +45,6 @@ trait Transformer[IN, OUT] extends WithParameters {
     new ChainedLearner[IN, OUT, CHAINED](this, learner)
   }
 
-  def transform(input: DataSet[IN], parameters: ParameterMap = ParameterMap.Empty): DataSet[OUT]
+  def transform(input: DataSet[IN], transformParameters: ParameterMap = ParameterMap.Empty):
+  DataSet[OUT]
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index fd490e1..4ae565e 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -106,10 +106,16 @@ case class DenseMatrix(
     obj match {
       case dense: DenseMatrix =>
         numRows == dense.numRows && numCols == dense.numCols && data.sameElements(dense.data)
-      case _ => super.equals(obj)
+      case _ => false
     }
   }
 
+  override def hashCode: Int = {
+    val hashCodes = List(numRows.hashCode(), numCols.hashCode(), java.util.Arrays.hashCode(data))
+
+    hashCodes.foldLeft(3){(left, right) => left * 41 + right}
+  }
+
   /** Element wise update function
     *
     * @param row row index

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index ab657c5..e5c6187 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -52,10 +52,14 @@ case class DenseVector(val data: Array[Double]) extends Vector with Serializable
   override def equals(obj: Any): Boolean = {
     obj match {
       case dense: DenseVector => data.length == dense.data.length && data.sameElements(dense.data)
-      case _ => super.equals(obj)
+      case _ => false
     }
   }
 
+  override def hashCode: Int = {
+    java.util.Arrays.hashCode(data)
+  }
+
   /**
    * Copies the vector instance
    *

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
index 11b4e55..ba6a781 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
@@ -57,12 +57,12 @@ trait Matrix {
     */
   def copy: Matrix
 
-  override def equals(obj: Any): Boolean = {
-    obj match {
-      case matrix: Matrix if numRows == matrix.numRows && numCols == matrix.numCols =>
-        val coordinates = for(row <- 0 until numRows; col <- 0 until numCols) yield (row, col)
-        coordinates forall { case(row, col) => this.apply(row, col) == matrix(row, col)}
-      case _ => false
+  def equalsMatrix(matrix: Matrix): Boolean = {
+    if(numRows == matrix.numRows && numCols == matrix.numCols) {
+      val coordinates = for(row <- 0 until numRows; col <- 0 until numCols) yield (row, col)
+      coordinates forall { case(row, col) => this.apply(row, col) == matrix(row, col)}
+    } else {
+      false
     }
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
index c9842f8..061c464 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
@@ -112,6 +112,23 @@ class SparseMatrix(
     result.toString
   }
 
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case sm: SparseMatrix if numRows == sm.numRows && numCols == sm.numCols =>
+        rowIndices.sameElements(sm.rowIndices) && colPtrs.sameElements(sm.colPtrs) &&
+        data.sameElements(sm.data)
+      case _ => false
+    }
+  }
+
+  override def hashCode: Int = {
+    val hashCodes = List(numRows.hashCode(), numCols.hashCode(),
+      java.util.Arrays.hashCode(rowIndices), java.util.Arrays.hashCode(colPtrs),
+      java.util.Arrays.hashCode(data))
+
+    hashCodes.foldLeft(5){(left, right) => left * 41 + right}
+  }
+
   private def locate(row: Int, col: Int): Int = {
     require(0 <= row && row < numRows && 0 <= col && col < numCols,
       (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
index 2c63203..6689aed 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
@@ -78,6 +78,21 @@ class SparseVector(
     denseVector
   }
 
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case sv: SparseVector if size == sv.size =>
+        indices.sameElements(sv.indices) && data.sameElements(sv.data)
+      case _ => false
+    }
+  }
+
+  override def hashCode: Int = {
+    val hashCodes = List(size.hashCode, java.util.Arrays.hashCode(indices),
+      java.util.Arrays.hashCode(data))
+
+    hashCodes.foldLeft(3){ (left, right) => left * 41 + right}
+  }
+
   override def toString: String = {
     val entries = indices.zip(data).mkString(", ")
     "SparseVector(" + entries + ")"

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
index 7e7c32c..ef6b7aa 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
@@ -49,14 +49,13 @@ trait Vector {
     */
   def copy: Vector
 
-  override def equals(obj: Any): Boolean = {
-    obj match {
-      case vector: Vector if size == vector.size =>
-        0 until size forall { idx =>
-          this(idx) == vector(idx)
-        }
-
-      case _ => false
+  def equalsVector(vector: Vector): Boolean = {
+    if(size == vector.size) {
+      (0 until size) forall { idx =>
+        this(idx) == vector(idx)
+      }
+    } else {
+      false
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
index 5ff59d1..5c6de55 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
@@ -49,7 +49,7 @@ import scala.util.Random
   * In order to find the user and item matrix, the following problem is solved:
   *
   * `argmin_{U,V} sum_(i,j\ with\ r_{i,j} != 0) (r_{i,j} - u_{i}^Tv_{j})^2 +
-  * \lambda (sum_(i) n_{u_i} ||u_i||^2 + sum_(j) n_{v_j} ||v_j||^2)`
+  * lambda (sum_(i) n_{u_i} ||u_i||^2 + sum_(j) n_{v_j} ||v_j||^2)`
   *
   * with `\lambda` being the regularization factor, `n_{u_i}` being the number of items the user `i`
   * has rated and `n_{v_j}` being the number of times the item `j` has been rated. This
@@ -117,7 +117,8 @@ import scala.util.Random
   * [[https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/
   * recommendation/ALS.scala here]].
   */
-class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
+class
+ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
 
   import ALS._
 
@@ -900,10 +901,12 @@ object ALS {
  * @param itemFactors Calculated item matrix
   * @param lambda Regularization value used to calculate the model
   */
-class ALSModel(@transient val userFactors: DataSet[Factors],
-               @transient val itemFactors: DataSet[Factors],
-               val lambda: Double) extends Transformer[(Int, Int), (Int, Int, Double)] with
-Serializable{
+class ALSModel(
+    @transient val userFactors: DataSet[Factors],
+    @transient val itemFactors: DataSet[Factors],
+    val lambda: Double)
+  extends Transformer[(Int, Int), (Int, Int, Double)]
+  with Serializable {
 
   override def transform(input: DataSet[(Int, Int)], parameters: ParameterMap): DataSet[(Int,
     Int, Double)] = {

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 076156d..87352fa 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -86,7 +86,8 @@ import com.github.fommil.netlib.BLAS.{ getInstance => blas }
   *  Threshold for relative change of sum of squared residuals until convergence.
   *
   */
-class MultipleLinearRegression extends Learner[LabeledVector, MultipleLinearRegressionModel]
+class 
+MultipleLinearRegression extends Learner[LabeledVector, MultipleLinearRegressionModel]
 with Serializable {
   import MultipleLinearRegression._
 
@@ -105,9 +106,9 @@ with Serializable {
     this
   }
 
-  override def fit(input: DataSet[LabeledVector], parameters: ParameterMap):
+  override def fit(input: DataSet[LabeledVector], fitParameters: ParameterMap):
   MultipleLinearRegressionModel = {
-    val map = this.parameters ++ parameters
+    val map = this.parameters ++ fitParameters
 
     // retrieve parameters of the algorithm
     val numberOfIterations = map(Iterations)
@@ -400,9 +401,10 @@ RichMapFunction[(Array[Double], Double, Int), (Array[Double], Double)] {
   *
   * @param weights DataSet containing the calculated weight vector
   */
-class MultipleLinearRegressionModel private[regression]
-(val weights: DataSet[(Array[Double], Double)]) extends
-Transformer[ Vector, LabeledVector ] {
+class MultipleLinearRegressionModel private[regression](
+    val weights: DataSet[(Array[Double], Double)])
+  extends Transformer[ Vector, LabeledVector ]
+  with Serializable {
 
   import MultipleLinearRegression.WEIGHTVECTOR_BROADCAST
 

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala
new file mode 100644
index 0000000..c9dd00f
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/Classification.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.math.DenseVector
+
+object Classification {
+
+  /** Centered data of fisheriris data set
+    *
+    */
+  val trainingData = Seq[LabeledVector](
+    LabeledVector(1.0000, DenseVector(-0.2060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.0060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.3060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.2060, -0.0760)),
+    LabeledVector(1.0000, DenseVector(-1.6060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-0.3060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-1.0060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-1.4060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-0.7060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.9060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-0.2060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-1.3060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.5060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.8060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-1.0060, -0.5760)),
+    LabeledVector(1.0000, DenseVector(-0.1060, 0.1240)),
+    LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.0060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.2060, -0.4760)),
+    LabeledVector(1.0000, DenseVector(-0.6060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.5060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-0.1060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(0.0940, 0.0240)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-1.4060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-1.1060, -0.5760)),
+    LabeledVector(1.0000, DenseVector(-1.2060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-1.0060, -0.4760)),
+    LabeledVector(1.0000, DenseVector(0.1940, -0.0760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.4060, -0.0760)),
+    LabeledVector(1.0000, DenseVector(-0.2060, -0.1760)),
+    LabeledVector(1.0000, DenseVector(-0.5060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.8060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.9060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.5060, -0.4760)),
+    LabeledVector(1.0000, DenseVector(-0.3060, -0.2760)),
+    LabeledVector(1.0000, DenseVector(-0.9060, -0.4760)),
+    LabeledVector(1.0000, DenseVector(-1.6060, -0.6760)),
+    LabeledVector(1.0000, DenseVector(-0.7060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.7060, -0.4760)),
+    LabeledVector(1.0000, DenseVector(-0.7060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-0.6060, -0.3760)),
+    LabeledVector(1.0000, DenseVector(-1.9060, -0.5760)),
+    LabeledVector(1.0000, DenseVector(-0.8060, -0.3760)),
+    LabeledVector(-1.0000, DenseVector(1.0940, 0.8240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.2240)),
+    LabeledVector(-1.0000, DenseVector(0.9940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(0.6940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(0.8940, 0.5240)),
+    LabeledVector(-1.0000, DenseVector(1.6940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(-0.4060, 0.0240)),
+    LabeledVector(-1.0000, DenseVector(1.3940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(0.8940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(1.1940, 0.8240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(0.3940, 0.2240)),
+    LabeledVector(-1.0000, DenseVector(0.5940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(0.0940, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.7240)),
+    LabeledVector(-1.0000, DenseVector(0.3940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.5940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(1.7940, 0.5240)),
+    LabeledVector(-1.0000, DenseVector(1.9940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.0940, -0.1760)),
+    LabeledVector(-1.0000, DenseVector(0.7940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(-0.0060, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(1.7940, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(-0.0060, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(0.7940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(1.0940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(-0.1060, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(-0.0060, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(0.6940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(0.8940, -0.0760)),
+    LabeledVector(-1.0000, DenseVector(1.1940, 0.2240)),
+    LabeledVector(-1.0000, DenseVector(1.4940, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(0.6940, 0.5240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, -0.1760)),
+    LabeledVector(-1.0000, DenseVector(0.6940, -0.2760)),
+    LabeledVector(-1.0000, DenseVector(1.1940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.6940, 0.7240)),
+    LabeledVector(-1.0000, DenseVector(0.5940, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(-0.1060, 0.1240)),
+    LabeledVector(-1.0000, DenseVector(0.4940, 0.4240)),
+    LabeledVector(-1.0000, DenseVector(0.6940, 0.7240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.2240)),
+    LabeledVector(-1.0000, DenseVector(0.9940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.7940, 0.8240)),
+    LabeledVector(-1.0000, DenseVector(0.2940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.0940, 0.2240)),
+    LabeledVector(-1.0000, DenseVector(0.2940, 0.3240)),
+    LabeledVector(-1.0000, DenseVector(0.4940, 0.6240)),
+    LabeledVector(-1.0000, DenseVector(0.1940, 0.1240))
+  )
+
+  val expectedWeightVector = DenseVector(-1.95, -3.45)
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/CoCoASuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/CoCoASuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/CoCoASuite.scala
new file mode 100644
index 0000000..35a69f1
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/CoCoASuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification
+
+import org.scalatest.{FlatSpec, Matchers}
+
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+class CoCoAITSuite extends FlatSpec with Matchers with FlinkTestBase {
+
+  behavior of "The CoCoA implementation"
+
+  it should "train a SVM" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val learner = CoCoA().
+      setBlocks(env.getParallelism).
+      setIterations(100).
+      setLocalIterations(100).
+      setRegularization(0.002).
+      setStepsize(0.1).
+      setSeed(0)
+
+    val trainingDS = env.fromCollection(Classification.trainingData)
+
+    val model = learner.fit(trainingDS)
+
+    val weightVector = model.weights.collect().apply(0)
+
+    weightVector.valuesIterator.zip(Classification.expectedWeightVector.valuesIterator).foreach {
+      case (weight, expectedWeight) =>
+        weight should be(expectedWeight +- 0.1)
+    }
+  }
+}
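
A side note on the assertion above, not part of the commit: CoCoA trains a linear SVM, so the learned weight vector classifies a point by the sign of its inner product with the features. A minimal plain-Scala sketch using the fixture values (the names below are illustrative):

    // Illustrative only: score one positive training point against the
    // expected weights from Classification.expectedWeightVector.
    val weights = Array(-1.95, -3.45)
    val point = Array(-0.8060, -0.3760) // a label-1.0 example from trainingData

    val score = weights.zip(point).map { case (w, x) => w * x }.sum
    val predictedLabel = if (score >= 0) 1.0 else -1.0
    // score is roughly 2.87, so predictedLabel is 1.0, matching the label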

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
index 5710931..132e7fe 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
@@ -34,7 +34,7 @@ class SparseMatrixSuite extends FlatSpec with Matchers {
     val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
 
     val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
-      (4, 2, 99), (1, 4, 91))
+      (4, 2, 99), (1, 4, 91), (0, 0, 0), (0, 1, 0))
 
     val expectedDenseMatrix = DenseMatrix.zeros(5, 5)
     expectedDenseMatrix(3, 4) = 42
@@ -44,7 +44,7 @@ class SparseMatrixSuite extends FlatSpec with Matchers {
     expectedDenseMatrix(1, 4) = 91
 
     sparseMatrix should equal(expectedSparseMatrix)
-    sparseMatrix should equal(expectedDenseMatrix)
+    sparseMatrix.equalsMatrix(expectedDenseMatrix) should be(true)
 
     sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
 

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
index 28415e8..97ef1cb 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
@@ -29,14 +29,14 @@ class SparseVectorSuite extends FlatSpec with Matchers {
     val size = 5
     val sparseVector = SparseVector.fromCOO(size, data)
 
-    val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
+    val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42), (2, 0))
     val expectedDenseVector = DenseVector.zeros(5)
 
     expectedDenseVector(0) = 4
     expectedDenseVector(4) = 42
 
     sparseVector should equal(expectedSparseVector)
-    sparseVector should equal(expectedDenseVector)
+    sparseVector.equalsVector(expectedDenseVector) should be(true)
 
     val denseVector = sparseVector.toDenseVector
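
Why the two equality assertions changed, both in SparseMatrixSuite above and here: `should equal` relies on `equals`, which is type-sensitive, so a sparse matrix or vector never equals its dense counterpart even when every element value matches; the `equalsMatrix`/`equalsVector` methods added in this commit compare values instead. The explicit zero entries added to the expected COO data presumably keep the stored structure of the two sparse objects identical. A hedged sketch of what such a value comparison does, assuming the Matrix trait exposes numRows, numCols and apply(row, col) — the authoritative implementation lives in Matrix.scala of this commit and may differ:

    // Hedged sketch, not the committed implementation.
    def equalsMatrixSketch(a: Matrix, b: Matrix): Boolean = {
      a.numRows == b.numRows && a.numCols == b.numCols &&
        (0 until a.numRows).forall { row =>
          (0 until a.numCols).forall { col => a(row, col) == b(row, col) }
        }
    }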
 

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
index aadcd2d..db5ad6e 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
@@ -36,7 +36,7 @@ class ALSITSuite
   behavior of "The alternating least squares (ALS) implementation"
 
   it should "properly factorize a matrix" in {
-    import ALSData._
+    import Recommendation._
 
     val env = ExecutionEnvironment.getExecutionEnvironment
 
@@ -75,75 +75,3 @@ class ALSITSuite
     risk should be(expectedEmpiricalRisk +- 1)
   }
 }
-
-object ALSData {
-
-  val iterations = 9
-  val lambda = 1.0
-  val numFactors = 5
-
-  val data: Seq[(Int, Int, Double)] = {
-    Seq(
-      (2,13,534.3937734561154),
-      (6,14,509.63176469621936),
-      (4,14,515.8246770897443),
-      (7,3,495.05234565105),
-      (2,3,532.3281786219485),
-      (5,3,497.1906356844367),
-      (3,3,512.0640508585093),
-      (10,3,500.2906742233019),
-      (1,4,521.9189079662882),
-      (2,4,515.0734651491396),
-      (1,7,522.7532725967008),
-      (8,4,492.65683825096403),
-      (4,8,492.65683825096403),
-      (10,8,507.03319667905413),
-      (7,1,522.7532725967008),
-      (1,1,572.2230209271174),
-      (2,1,563.5849190220224),
-      (6,1,518.4844061038742),
-      (9,1,529.2443732217674),
-      (8,1,543.3202505434103),
-      (7,2,516.0188923307859),
-      (1,2,563.5849190220224),
-      (1,11,515.1023793011227),
-      (8,2,536.8571133978352),
-      (2,11,507.90776961762225),
-      (3,2,532.3281786219485),
-      (5,11,476.24185144363304),
-      (4,2,515.0734651491396),
-      (4,11,469.92049343738233),
-      (3,12,509.4713776280098),
-      (4,12,494.6533165132021),
-      (7,5,482.2907867916308),
-      (6,5,477.5940040923741),
-      (4,5,480.9040684364228),
-      (1,6,518.4844061038742),
-      (6,6,470.6605085832807),
-      (8,6,489.6360564705307),
-      (4,6,472.74052954447046),
-      (7,9,482.5837650471611),
-      (5,9,487.00175463269863),
-      (9,9,500.69514584780944),
-      (4,9,477.71644808419325),
-      (7,10,485.3852917539852),
-      (8,10,507.03319667905413),
-      (3,10,500.2906742233019),
-      (5,15,488.08215944254437),
-      (6,15,480.16929757607346)
-    )
-  }
-
-  val expectedResult: Seq[(Int, Int, Double)] = {
-    Seq(
-      (2, 2, 526.1037),
-      (5, 9, 468.5680),
-      (10, 3, 484.8975),
-      (5, 13, 451.6228),
-      (1, 15, 493.4956),
-      (4, 11, 456.3862)
-    )
-  }
-
-  val expectedEmpiricalRisk = 505374.1877
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/00759257/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
new file mode 100644
index 0000000..8d8e4b9
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.recommendation
+
+object Recommendation {
+  val iterations = 9
+  val lambda = 1.0
+  val numFactors = 5
+
+  val data: Seq[(Int, Int, Double)] = {
+    Seq(
+      (2,13,534.3937734561154),
+      (6,14,509.63176469621936),
+      (4,14,515.8246770897443),
+      (7,3,495.05234565105),
+      (2,3,532.3281786219485),
+      (5,3,497.1906356844367),
+      (3,3,512.0640508585093),
+      (10,3,500.2906742233019),
+      (1,4,521.9189079662882),
+      (2,4,515.0734651491396),
+      (1,7,522.7532725967008),
+      (8,4,492.65683825096403),
+      (4,8,492.65683825096403),
+      (10,8,507.03319667905413),
+      (7,1,522.7532725967008),
+      (1,1,572.2230209271174),
+      (2,1,563.5849190220224),
+      (6,1,518.4844061038742),
+      (9,1,529.2443732217674),
+      (8,1,543.3202505434103),
+      (7,2,516.0188923307859),
+      (1,2,563.5849190220224),
+      (1,11,515.1023793011227),
+      (8,2,536.8571133978352),
+      (2,11,507.90776961762225),
+      (3,2,532.3281786219485),
+      (5,11,476.24185144363304),
+      (4,2,515.0734651491396),
+      (4,11,469.92049343738233),
+      (3,12,509.4713776280098),
+      (4,12,494.6533165132021),
+      (7,5,482.2907867916308),
+      (6,5,477.5940040923741),
+      (4,5,480.9040684364228),
+      (1,6,518.4844061038742),
+      (6,6,470.6605085832807),
+      (8,6,489.6360564705307),
+      (4,6,472.74052954447046),
+      (7,9,482.5837650471611),
+      (5,9,487.00175463269863),
+      (9,9,500.69514584780944),
+      (4,9,477.71644808419325),
+      (7,10,485.3852917539852),
+      (8,10,507.03319667905413),
+      (3,10,500.2906742233019),
+      (5,15,488.08215944254437),
+      (6,15,480.16929757607346)
+    )
+  }
+
+  val expectedResult: Seq[(Int, Int, Double)] = {
+    Seq(
+      (2, 2, 526.1037),
+      (5, 9, 468.5680),
+      (10, 3, 484.8975),
+      (5, 13, 451.6228),
+      (1, 15, 493.4956),
+      (4, 11, 456.3862)
+    )
+  }
+
+  val expectedEmpiricalRisk = 505374.1877
+}
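
For orientation, a hedged sketch of how ALSITSuite consumes this shared fixture after the `import Recommendation._` change above. The setter names mirror the fixture's parameter names and are assumptions; ALS.scala in this commit is authoritative:

    import org.apache.flink.api.scala._

    val env = ExecutionEnvironment.getExecutionEnvironment

    // Assumed setter names (setIterations/setLambda/setNumFactors); verify
    // against ALS.scala before relying on them.
    val als = ALS()
      .setIterations(Recommendation.iterations)
      .setLambda(Recommendation.lambda)
      .setNumFactors(Recommendation.numFactors)

    val ratingsDS = env.fromCollection(Recommendation.data)
    val model = als.fit(ratingsDS)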