You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/06/04 14:04:13 UTC

flink git commit: [FLINK-1993] [ml] Replaces custom SGD logic with optimization framework's SGD in MultipleLinearRegression

Repository: flink
Updated Branches:
  refs/heads/master 1559701f4 -> 463300ec5


[FLINK-1993] [ml] Replaces custom SGD logic with optimization framework's SGD in MultipleLinearRegression

Fixes PipelineITSuite, whose expected values changed because the MLR loss function changed

This closes #760.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/463300ec
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/463300ec
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/463300ec

Branch: refs/heads/master
Commit: 463300ec560efd2acf64ebf5520129868e7e25ae
Parents: 1559701
Author: Till Rohrmann <tr...@apache.org>
Authored: Fri May 29 18:02:47 2015 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Thu Jun 4 14:03:19 2015 +0200

----------------------------------------------------------------------
 .../apache/flink/ml/classification/SVM.scala    |   8 +-
 .../apache/flink/ml/pipeline/Estimator.scala    |   2 +-
 .../apache/flink/ml/pipeline/Predictor.scala    |   2 +-
 .../regression/MultipleLinearRegression.scala   | 362 ++-----------------
 .../flink/ml/pipeline/PipelineITSuite.scala     |   6 +-
 .../MultipleLinearRegressionITSuite.scala       |  18 +-
 .../flink/ml/regression/RegressionData.scala    |   4 +-
 7 files changed, 60 insertions(+), 342 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
index c69b56a..b259090 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
@@ -270,10 +270,10 @@ object SVM{
   implicit def predictLabeledValues = {
     new PredictOperation[SVM, LabeledVector, (Double, Double)]{
       override def predict(
-                            instance: SVM,
-                            predictParameters: ParameterMap,
-                            input: DataSet[LabeledVector])
-      : DataSet[(Double, Double)] = {
+          instance: SVM,
+          predictParameters: ParameterMap,
+          input: DataSet[LabeledVector])
+        : DataSet[(Double, Double)] = {
 
         instance.weightsOption match {
           case Some(weights) => {

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
index 088b184..e3031f7 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
@@ -34,7 +34,7 @@ import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
   *
   * @tparam Self
   */
-trait Estimator[Self] extends WithParameters with Serializable {
+trait Estimator[Self] extends WithParameters {
   that: Self =>
 
   /** Fits the estimator to the given input data. The fitting logic is contained in the

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
index 8a6b204..9bb5c5c 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
@@ -35,7 +35,7 @@ import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
   *
   * @tparam Self Type of the implementing class
   */
-trait Predictor[Self] extends Estimator[Self] with WithParameters with Serializable {
+trait Predictor[Self] extends Estimator[Self] with WithParameters {
   that: Self =>
 
   /** Predict testing data according the learned model. The implementing class has to provide

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 32746a1..439d038 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -18,15 +18,14 @@
 
 package org.apache.flink.ml.regression
 
-import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.scala.DataSet
-import org.apache.flink.configuration.Configuration
-import org.apache.flink.ml.math.{DenseVector, BLAS, Vector, vector2Array}
+import org.apache.flink.ml.math.Vector
 import org.apache.flink.ml.common._
 
 import org.apache.flink.api.scala._
 
-import com.github.fommil.netlib.BLAS.{ getInstance => blas }
+import org.apache.flink.ml.optimization.{LinearPrediction, SquaredLoss, GenericLossFunction,
+SimpleGradientDescent}
 import org.apache.flink.ml.pipeline.{FitOperation, PredictOperation, Predictor}
 
 /** Multiple linear regression using the ordinary least squares (OLS) estimator.
@@ -44,7 +43,7 @@ import org.apache.flink.ml.pipeline.{FitOperation, PredictOperation, Predictor}
   * the current value `w` which gives the new value of `w_new`. The weight is defined as
   * `stepsize/math.sqrt(iteration)`.
   *
-  * The optimization runs at most a maximum number of iteratinos or, if a convergence threshold has
+  * The optimization runs at most a maximum number of iterations or, if a convergence threshold has
   * been set, until the convergence criterion has been met. As convergence criterion the relative
   * change of the sum of squared residuals is used:
   *
@@ -87,11 +86,11 @@ import org.apache.flink.ml.pipeline.{FitOperation, PredictOperation, Predictor}
   *
   */
 class MultipleLinearRegression extends Predictor[MultipleLinearRegression] {
-
+  import org.apache.flink.ml._
   import MultipleLinearRegression._
 
   // Stores the weights of the linear model after the fitting phase
-  var weightsOption: Option[DataSet[(Array[Double], Double)]] = None
+  var weightsOption: Option[DataSet[WeightVector]] = None
 
   def setIterations(iterations: Int): MultipleLinearRegression = {
     parameters.add(Iterations, iterations)
@@ -111,9 +110,9 @@ class MultipleLinearRegression extends Predictor[MultipleLinearRegression] {
   def squaredResidualSum(input: DataSet[LabeledVector]): DataSet[Double] = {
     weightsOption match {
       case Some(weights) => {
-        input.map {
-          new SquaredResiduals
-        }.withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST).reduce {
+        input.mapWithBcVariable(weights){
+          (dataPoint, weights) => lossFunction.loss(dataPoint, weights)
+        }.reduce {
           _ + _
         }
       }
@@ -128,8 +127,13 @@ class MultipleLinearRegression extends Predictor[MultipleLinearRegression] {
 }
 
 object MultipleLinearRegression {
+
+  import org.apache.flink.ml._
+
   val WEIGHTVECTOR_BROADCAST = "weights_broadcast"
 
+  val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
   // ====================================== Parameters =============================================
 
   case object Stepsize extends Parameter[Double] {
@@ -158,10 +162,10 @@ object MultipleLinearRegression {
     */
   implicit val fitMLR = new FitOperation[MultipleLinearRegression, LabeledVector] {
     override def fit(
-        instance: MultipleLinearRegression,
-        fitParameters: ParameterMap,
-        input: DataSet[LabeledVector])
-      : Unit = {
+      instance: MultipleLinearRegression,
+      fitParameters: ParameterMap,
+      input: DataSet[LabeledVector])
+    : Unit = {
       val map = instance.parameters ++ fitParameters
 
       // retrieve parameters of the algorithm
@@ -169,128 +173,19 @@ object MultipleLinearRegression {
       val stepsize = map(Stepsize)
       val convergenceThreshold = map.get(ConvergenceThreshold)
 
-      // calculate dimension of the feature vectors
-      val dimension = input.map{_.vector.size}.reduce {
-        (a, b) =>
-          require(a == b, "All input vector must have the same dimension.")
-          a
-      }
-
-      input.flatMap{
-        t =>
-          Seq(t)
-      }
-
-      // initial weight vector is set to 0
-      val initialWeightVector = createInitialWeightVector(dimension)
-
-      // check if a convergence threshold has been set
-      val resultingWeightVector = convergenceThreshold match {
-        case Some(convergence) =>
+      val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
 
-          // we have to calculate for each weight vector the sum of squared residuals
-          val initialSquaredResidualSum = input.map {
-            new SquaredResiduals
-          }.withBroadcastSet(initialWeightVector, WEIGHTVECTOR_BROADCAST).reduce {
-            _ + _
-          }
-
-          // combine weight vector with current sum of squared residuals
-          val initialWeightVectorWithSquaredResidualSum = initialWeightVector.
-            crossWithTiny(initialSquaredResidualSum).setParallelism(1)
-
-          // start SGD iteration
-          val resultWithResidual = initialWeightVectorWithSquaredResidualSum.
-            iterateWithTermination(numberOfIterations) {
-            weightVectorSquaredResidualDS =>
-
-              // extract weight vector and squared residual sum
-              val weightVector = weightVectorSquaredResidualDS.map{_._1}
-              val squaredResidualSum = weightVectorSquaredResidualDS.map{_._2}
-
-              // TODO: Sample from input to realize proper SGD
-              val newWeightVector = input.map {
-                new LinearRegressionGradientDescent
-              }.withBroadcastSet(weightVector, WEIGHTVECTOR_BROADCAST).reduce {
-                (left, right) =>
-                  val (leftBetas, leftBeta0, leftCount) = left
-                  val (rightBetas, rightBeta0, rightCount) = right
-
-                  blas.daxpy(leftBetas.length, 1.0, rightBetas, 1, leftBetas, 1)
-
-                  (leftBetas, leftBeta0 + rightBeta0, leftCount + rightCount)
-              }.map {
-                new LinearRegressionWeightsUpdate(stepsize)
-              }.withBroadcastSet(weightVector, WEIGHTVECTOR_BROADCAST)
-
-              // calculate the sum of squared residuals for the new weight vector
-              val newResidual = input.map {
-                new SquaredResiduals
-              }.withBroadcastSet(newWeightVector, WEIGHTVECTOR_BROADCAST).reduce {
-                _ + _
-              }
-
-              // check if the relative change in the squared residual sum is smaller than the
-              // convergence threshold. If yes, then terminate => return empty termination data set
-              val termination = squaredResidualSum.crossWithTiny(newResidual).setParallelism(1).
-                filter{
-                pair => {
-                  val (residual, newResidual) = pair
-
-                  if (residual <= 0) {
-                    false
-                  } else {
-                    math.abs((residual - newResidual)/residual) >= convergence
-                  }
-                }
-              }
-
-              // result for new iteration
-              (newWeightVector cross newResidual, termination)
-          }
-
-          // remove squared residual sum to only return the weight vector
-          resultWithResidual.map{_._1}
+      val optimizer = SimpleGradientDescent()
+        .setIterations(numberOfIterations)
+        .setStepsize(stepsize)
+        .setLossFunction(lossFunction)
 
+      convergenceThreshold match {
+        case Some(threshold) => optimizer.setConvergenceThreshold(threshold)
         case None =>
-          // No convergence criterion
-          initialWeightVector.iterate(numberOfIterations) {
-            weightVector => {
-
-              // TODO: Sample from input to realize proper SGD
-              input.map {
-                new LinearRegressionGradientDescent
-              }.withBroadcastSet(weightVector, WEIGHTVECTOR_BROADCAST).reduce {
-                (left, right) =>
-                  val (leftBetas, leftBeta0, leftCount) = left
-                  val (rightBetas, rightBeta0, rightCount) = right
-
-                  blas.daxpy(leftBetas.length, 1, rightBetas, 1, leftBetas, 1)
-                  (leftBetas, leftBeta0 + rightBeta0, leftCount + rightCount)
-              }.map {
-                new LinearRegressionWeightsUpdate(stepsize)
-              }.withBroadcastSet(weightVector, WEIGHTVECTOR_BROADCAST)
-            }
-          }
       }
 
-      instance.weightsOption = Some(resultingWeightVector)
-    }
-  }
-
-  /** Creates a DataSet with one zero vector. The zero vector has dimension d, which is given
-    * by the dimensionDS.
-    *
-    * @param dimensionDS DataSet with one element d, denoting the dimension of the returned zero
-    *                    vector
-    * @return DataSet of a zero vector of dimension d
-    */
-  private def createInitialWeightVector(dimensionDS: DataSet[Int]):
-  DataSet[(Array[Double], Double)] = {
-    dimensionDS.map {
-      dimension =>
-        val values = Array.fill(dimension)(0.0)
-        (values, 0.0)
+      instance.weightsOption = Some(optimizer.optimize(input, None))
     }
   }
 
@@ -298,7 +193,8 @@ object MultipleLinearRegression {
     *
     * @tparam T Testing data type for which the prediction is calculated. Has to be a subtype of
     *           [[Vector]]
-    * @return
+    * @return [[PredictOperation]] which calculates the label of a given vector according to the
+    *        linear model. Each vector is mapped to a [[LabeledVector]] with its prediction.
     */
   implicit def predictVectors[T <: Vector] = {
     new PredictOperation[MultipleLinearRegression, T, LabeledVector] {
@@ -309,8 +205,10 @@ object MultipleLinearRegression {
       : DataSet[LabeledVector] = {
         instance.weightsOption match {
           case Some(weights) => {
-            input.map(new LinearRegressionPrediction[T])
-              .withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST)
+            input.mapWithBcVariable(weights) {
+              (dataPoint, weights) =>
+                LabeledVector(LinearPrediction.predict(dataPoint, weights), dataPoint)
+            }
           }
 
           case None => {
@@ -322,31 +220,6 @@ object MultipleLinearRegression {
     }
   }
 
-  private class LinearRegressionPrediction[T <: Vector] extends RichMapFunction[T, LabeledVector] {
-    private var weights: Array[Double] = null
-    private var weight0: Double = 0
-
-
-    @throws(classOf[Exception])
-    override def open(configuration: Configuration): Unit = {
-      val t = getRuntimeContext
-        .getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
-
-      val weightsPair = t.get(0)
-
-      weights = weightsPair._1
-      weight0 = weightsPair._2
-    }
-
-    override def map(value: T): LabeledVector = {
-      val dotProduct = blas.ddot(weights.length, weights, 1, vector2Array(value), 1)
-
-      val prediction = dotProduct + weight0
-
-      LabeledVector(prediction, value)
-    }
-  }
-
   /** Calculates the predictions for labeled data with respect to the learned linear model.
     *
     * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair.
@@ -354,14 +227,17 @@ object MultipleLinearRegression {
   implicit def predictLabeledVectors = {
     new PredictOperation[MultipleLinearRegression, LabeledVector, (Double, Double)] {
       override def predict(
-                            instance: MultipleLinearRegression,
-                            predictParameters: ParameterMap,
-                            input: DataSet[LabeledVector])
+        instance: MultipleLinearRegression,
+        predictParameters: ParameterMap,
+        input: DataSet[LabeledVector])
       : DataSet[(Double, Double)] = {
         instance.weightsOption match {
           case Some(weights) => {
-            input.map(new LinearRegressionLabeledPrediction)
-              .withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST)
+            input.mapWithBcVariable(weights) {
+              (labeledVector, weights) => {
+                (labeledVector.label, LinearPrediction.predict(labeledVector.vector, weights))
+              }
+            }
           }
 
           case None => {
@@ -372,162 +248,4 @@ object MultipleLinearRegression {
       }
     }
   }
-
-  private class LinearRegressionLabeledPrediction
-    extends RichMapFunction[LabeledVector, (Double, Double)] {
-    private var weights: Array[Double] = null
-    private var weight0: Double = 0
-
-
-    @throws(classOf[Exception])
-    override def open(configuration: Configuration): Unit = {
-      val t = getRuntimeContext
-        .getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
-
-      val weightsPair = t.get(0)
-
-      weights = weightsPair._1
-      weight0 = weightsPair._2
-    }
-
-    override def map(labeledVector: LabeledVector ): (Double, Double) = {
-
-      val truth = labeledVector.label
-      val dotProduct = BLAS.dot(DenseVector(weights), labeledVector.vector)
-
-      val prediction = dotProduct + weight0
-
-      (truth, prediction)
-    }
-  }
-}
-
-//--------------------------------------------------------------------------------------------------
-//  Flink function definitions
-//--------------------------------------------------------------------------------------------------
-
-/** Calculates for a labeled vector and the current weight vector its squared residual:
-  *
-  * `(y - (w^Tx + w_0))^2`
-  *
-  * The weight vector is received as a broadcast variable.
-  */
-private class SquaredResiduals extends RichMapFunction[LabeledVector, Double] {
-  import MultipleLinearRegression.WEIGHTVECTOR_BROADCAST
-
-  var weightVector: Array[Double] = null
-  var weight0: Double = 0.0
-
-  @throws(classOf[Exception])
-  override def open(configuration: Configuration): Unit = {
-    val list = this.getRuntimeContext.
-      getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
-
-    val weightsPair = list.get(0)
-
-    weightVector = weightsPair._1
-    weight0 = weightsPair._2
-  }
-
-  override def map(value: LabeledVector): Double = {
-    val array = vector2Array(value.vector)
-    val label = value.label
-
-    val dotProduct = blas.ddot(weightVector.length, weightVector, 1, array, 1)
-
-    val residual = dotProduct + weight0 - label
-
-    residual * residual
-  }
-}
-
-/** Calculates for a labeled vector and the current weight vector the gradient minimizing the
-  * OLS equation. The gradient is given by:
-  *
-  * `dw = 2*(w^T*x + w_0 - y)*x`
-  * `dw_0 = 2*(w^T*x + w_0 - y)`
-  *
-  * The weight vector is received as a broadcast variable.
-  */
-private class LinearRegressionGradientDescent extends
-RichMapFunction[LabeledVector, (Array[Double], Double, Int)] {
-
-  import MultipleLinearRegression.WEIGHTVECTOR_BROADCAST
-
-  var weightVector: Array[Double] = null
-  var weight0: Double = 0.0
-
-  @throws(classOf[Exception])
-  override def open(configuration: Configuration): Unit = {
-    val list = this.getRuntimeContext.
-      getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
-
-    val weightsPair = list.get(0)
-
-    weightVector = weightsPair._1
-    weight0 = weightsPair._2
-  }
-
-  override def map(value: LabeledVector): (Array[Double], Double, Int) = {
-    val x = vector2Array(value.vector)
-    val label = value.label
-
-    val dotProduct = blas.ddot(weightVector.length, weightVector, 1, x, 1)
-
-    val error = dotProduct + weight0 - label
-
-    // reuse vector x
-    val weightsGradient = x
-
-    blas.dscal(weightsGradient.length, 2*error, weightsGradient, 1)
-
-    val weight0Gradient = 2 * error
-
-    (weightsGradient, weight0Gradient, 1)
-  }
-}
-
-/** Calculates the new weight vector based on the partial gradients. In order to do that,
-  * all partial gradients are averaged and weighted by the current stepsize. This update value is
-  * added to the current weight vector.
-  *
-  * @param stepsize Initial value of the step size used to update the weight vector
-  */
-private class LinearRegressionWeightsUpdate(val stepsize: Double) extends
-RichMapFunction[(Array[Double], Double, Int), (Array[Double], Double)] {
-
-  import MultipleLinearRegression.WEIGHTVECTOR_BROADCAST
-
-  var weights: Array[Double] = null
-  var weight0: Double = 0.0
-
-  @throws(classOf[Exception])
-  override def open(configuration: Configuration): Unit = {
-    val list = this.getRuntimeContext.
-      getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
-
-    val weightsPair = list.get(0)
-
-    weights = weightsPair._1
-    weight0 = weightsPair._2
-  }
-
-  override def map(value: (Array[Double], Double, Int)): (Array[Double], Double) = {
-    val weightsGradient = value._1
-    blas.dscal(weightsGradient.length, 1.0/value._3, weightsGradient, 1)
-
-    val weight0Gradient = value._2 / value._3
-
-    val iteration = getIterationRuntimeContext.getSuperstepNumber
-
-    // scale initial stepsize by the inverse square root of the iteration number to make it
-    // decreasing
-    val effectiveStepsize = stepsize/math.sqrt(iteration)
-
-    val newWeights = weights.clone
-    blas.daxpy(newWeights.length, -effectiveStepsize, weightsGradient, 1, newWeights, 1)
-    val newWeight0 = weight0 - effectiveStepsize * weight0Gradient
-
-    (newWeights, newWeight0)
-  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
index a36a0d1..c25ad79 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
@@ -170,11 +170,11 @@ class PipelineITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     val weightVector = predictor.weightsOption.get.collect().head
 
-    weightVector._1.foreach{
-      _ should be (0.367282 +- 0.01)
+    weightVector.weights.valueIterator.foreach{
+      _ should be (0.268050 +- 0.01)
     }
 
-    weightVector._2 should be (1.3131727 +- 0.01)
+    weightVector.intercept should be (0.807924 +- 0.01)
   }
 
   it should "throw an exception when the input data is not supported by a predictor" in {

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
index 30338e5..e42b87d 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
@@ -19,7 +19,7 @@
 package org.apache.flink.ml.regression
 
 import org.apache.flink.api.scala.ExecutionEnvironment
-import org.apache.flink.ml.common.ParameterMap
+import org.apache.flink.ml.common.{WeightVector, ParameterMap}
 import org.apache.flink.ml.preprocessing.PolynomialFeatures
 import org.scalatest.{Matchers, FlatSpec}
 
@@ -44,7 +44,7 @@ class MultipleLinearRegressionITSuite
 
     val parameters = ParameterMap()
 
-    parameters.add(MultipleLinearRegression.Stepsize, 1.0)
+    parameters.add(MultipleLinearRegression.Stepsize, 2.0)
     parameters.add(MultipleLinearRegression.Iterations, 10)
     parameters.add(MultipleLinearRegression.ConvergenceThreshold, 0.001)
 
@@ -55,13 +55,13 @@ class MultipleLinearRegressionITSuite
 
     weightList.size should equal(1)
 
-    val (weights, weight0) = weightList(0)
+    val WeightVector(weights, intercept) = weightList(0)
 
-    expectedWeights zip weights foreach {
+    expectedWeights.toIterator zip weights.valueIterator foreach {
       case (expectedWeight, weight) =>
         weight should be (expectedWeight +- 1)
     }
-    weight0 should be (expectedWeight0 +- 0.4)
+    intercept should be (expectedWeight0 +- 0.4)
 
     val srs = mlr.squaredResidualSum(inputDS).collect().apply(0)
 
@@ -82,7 +82,7 @@ class MultipleLinearRegressionITSuite
 
     val parameters = ParameterMap()
       .add(PolynomialFeatures.Degree, 3)
-      .add(MultipleLinearRegression.Stepsize, 0.002)
+      .add(MultipleLinearRegression.Stepsize, 0.004)
       .add(MultipleLinearRegression.Iterations, 100)
 
     pipeline.fit(inputDS, parameters)
@@ -91,14 +91,14 @@ class MultipleLinearRegressionITSuite
 
     weightList.size should equal(1)
 
-    val (weights, weight0) = weightList(0)
+    val WeightVector(weights, intercept) = weightList(0)
 
-    RegressionData.expectedPolynomialWeights.zip(weights) foreach {
+    RegressionData.expectedPolynomialWeights.toIterator.zip(weights.valueIterator) foreach {
       case (expectedWeight, weight) =>
         weight should be(expectedWeight +- 0.1)
     }
 
-    weight0 should be(RegressionData.expectedPolynomialWeight0 +- 0.1)
+    intercept should be(RegressionData.expectedPolynomialWeight0 +- 0.1)
 
     val transformedInput = polynomialBase.transform(inputDS, parameters)
 

http://git-wip-us.apache.org/repos/asf/flink/blob/463300ec/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
index 8525c0f..062f510 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
@@ -25,7 +25,7 @@ object RegressionData {
 
   val expectedWeights = Array[Double](3.0094)
   val expectedWeight0: Double = 9.8158
-  val expectedSquaredResidualSum: Double = 49.7596
+  val expectedSquaredResidualSum: Double = 49.7596/2
 
   val data: Seq[LabeledVector] = Seq(
     LabeledVector(10.7949, DenseVector(0.2714)),
@@ -119,7 +119,7 @@ object RegressionData {
 
   val expectedPolynomialWeights = Seq(0.2375, -0.3493, -0.1674)
   val expectedPolynomialWeight0 = 0.0233
-  val expectedPolynomialSquaredResidualSum = 1.5389e+03
+  val expectedPolynomialSquaredResidualSum = 1.5389e+03/2
 
   val polynomialData: Seq[LabeledVector] = Seq(
     LabeledVector(2.1415, DenseVector(3.6663)),