Posted to commits@flink.apache.org by tr...@apache.org on 2015/03/17 23:45:27 UTC

[8/8] flink git commit: [ml] Introduces FlinkTools containing persist methods.

[ml] Introduces FlinkTools containing persist methods.

[ml] Changes comments into proper ScalaDoc in MultipleLinearRegression


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/86485434
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/86485434
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/86485434

Branch: refs/heads/master
Commit: 8648543425a36da40eabc9ffb7220664d78c02e7
Parents: 458c524
Author: Till Rohrmann <tr...@apache.org>
Authored: Fri Mar 6 15:27:25 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Mar 17 23:28:34 2015 +0100

----------------------------------------------------------------------
 .../org/apache/flink/ml/common/FlinkTools.scala | 259 ++++++++++++
 .../apache/flink/ml/recommendation/ALS.scala    | 401 ++++++++++---------
 .../regression/MultipleLinearRegression.scala   | 177 ++++----
 .../flink/ml/recommendation/ALSSuite.scala      |   2 +-
 .../MultipleLinearRegressionSuite.scala         |   2 +-
 5 files changed, 571 insertions(+), 270 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/86485434/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
new file mode 100644
index 0000000..d972960
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common
+
+import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.core.fs.Path
+
+import scala.reflect.ClassTag
+
+/**
+ * Collection of convenience functions
+ */
+object FlinkTools {
+
+  /** Writes a DataSet to the specified path, executes the job and returns a new DataSet
+   * reading the written data back in. This materializes the given DataSet.
+   * @param dataset DataSet to persist
+   * @param path Path to which the DataSet is written
+   * @tparam T Type of the DataSet's elements
+   * @return DataSet reading the previously written data
+   */
+  def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = {
+    val env = dataset.getExecutionEnvironment
+    val outputFormat = new TypeSerializerOutputFormat[T]
+
+    val filePath = new Path(path)
+
+    outputFormat.setOutputFilePath(filePath)
+    outputFormat.setWriteMode(WriteMode.OVERWRITE)
+
+    dataset.output(outputFormat)
+    env.execute("FlinkTools persist")
+
+    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType.createSerializer())
+    inputFormat.setFilePath(filePath)
+
+    env.createInput(inputFormat)
+  }
+
+  def persist[A: ClassTag: TypeInformation, B: ClassTag: TypeInformation](ds1: DataSet[A], ds2:
+  DataSet[B], path1: String, path2: String): (DataSet[A], DataSet[B]) = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    if2.setFilePath(f2)
+
+    (env.createInput(if1), env.createInput(if2))
+  }
+
+  def persist[A: ClassTag: TypeInformation, B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], path1:
+  String, path2: String, path3: String): (DataSet[A], DataSet[B], DataSet[C]) = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    if3.setFilePath(f3)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3))
+  }
+
+  def persist[A: ClassTag: TypeInformation, B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B],
+                                                              ds3: DataSet[C], ds4: DataSet[D],
+                                                              path1: String, path2: String, path3:
+                                                              String, path4: String):
+  (DataSet[A], DataSet[B], DataSet[C], DataSet[D]) = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    val f4 = new Path(path4)
+
+    val of4 = new TypeSerializerOutputFormat[D]
+    of4.setOutputFilePath(f4)
+    of4.setWriteMode(WriteMode.OVERWRITE)
+
+    ds4.output(of4)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    if3.setFilePath(f3)
+
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+    if4.setFilePath(f4)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4))
+  }
+
+  def persist[A: ClassTag: TypeInformation, B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation, E: ClassTag: TypeInformation]
+  (ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], ds4: DataSet[D], ds5: DataSet[E], path1:
+  String, path2: String, path3: String, path4: String, path5: String): (DataSet[A], DataSet[B],
+    DataSet[C], DataSet[D], DataSet[E]) = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    val f4 = new Path(path4)
+
+    val of4 = new TypeSerializerOutputFormat[D]
+    of4.setOutputFilePath(f4)
+    of4.setWriteMode(WriteMode.OVERWRITE)
+
+    ds4.output(of4)
+
+    val f5 = new Path(path5)
+
+    val of5 = new TypeSerializerOutputFormat[E]
+    of5.setOutputFilePath(f5)
+    of5.setWriteMode(WriteMode.OVERWRITE)
+
+    ds5.output(of5)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    if3.setFilePath(f3)
+
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+    if4.setFilePath(f4)
+
+    val if5 = new TypeSerializerInputFormat[E](ds5.getType.createSerializer())
+    if5.setFilePath(f5)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env
+      .createInput(if5))
+  }
+}
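
To illustrate the new helper (not part of the commit): a minimal sketch of using
FlinkTools.persist to materialize an expensive intermediate result. The sample data and the
paths under file:///tmp/flink/ are illustrative assumptions.

    import org.apache.flink.api.scala._
    import org.apache.flink.ml.common.FlinkTools

    object PersistExample {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment

        // An intermediate result which is expensive to recompute.
        val squares: DataSet[(Int, Double)] = env
          .fromElements((1, 2.0), (2, 3.5), (3, 4.25))
          .map(t => (t._1, t._2 * t._2))

        // persist writes the DataSet to the given path, executes the job and
        // returns a new DataSet which reads the materialized result back.
        val persisted = FlinkTools.persist(squares, "file:///tmp/flink/squares")

        // Downstream operators read from disk instead of recomputing squares.
        persisted.sum(1).writeAsText("file:///tmp/flink/sum")
        env.execute("consume persisted result")
      }
    }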

http://git-wip-us.apache.org/repos/asf/flink/blob/86485434/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
index 261252a..1051ae5 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
@@ -19,15 +19,10 @@ package org.apache.flink.ml.recommendation
 
 import java.lang
 
-import org.apache.flink.api.common.ExecutionConfig
 import org.apache.flink.api.scala._
 import org.apache.flink.api.common.operators.Order
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
-import org.apache.flink.core.fs.FileSystem.WriteMode
-import org.apache.flink.core.fs.Path
 import org.apache.flink.core.memory.{DataOutputView, DataInputView}
-import org.apache.flink.ml.common.{Parameter, ParameterMap, Transformer, Learner}
+import org.apache.flink.ml.common._
 import org.apache.flink.ml.recommendation.ALS.Factors
 import org.apache.flink.types.Value
 import org.apache.flink.util.Collector
@@ -37,89 +32,160 @@ import org.jblas.{Solve, SimpleBlas, DoubleMatrix}
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.reflect.ClassTag
 import scala.util.Random
 
-/**
- * Alternating least squares algorithm to calculate a matrix factorization.
- * The implementation uses weighted-lambda-regularization to avoid overfitting.
- *
- * Parameters:
- *
- *  - NumFactors: The number of latent factors. It is the dimension of the calculated
- *    user and item vectors.
- *
- *  - Lambda: Regularization factor. Tune this value in order to avoid overfitting/generalization.
- *
- *  - Iterations: The number of iterations to perform.
- *
- *  - Blocks: The number of blocks into which the user and item matrix a grouped. The fewer
- *  blocks one uses, the less data is sent redundantly. However, bigger blocks entail bigger
- *  update messages which have to be stored on the Heap. If the algorithm fails because of
- *  an OutOfMemoryException, then try to increase the number of blocks.
- *
- *  - Seed: Random seed used to generate the initial item matrix for the algorithm
- *
- *  - PersistencePath: Path to a directory into which intermediate results are stored. If
- *  this value is set, then the algorithm is split into two preprocessing steps, the ALS iteration
- *  and a post-processing step which calculates a last ALS half-step. The preprocessing steps
- *  calculate the [[org.apache.flink.ml.recommendation.ALS.OutBlockInformation]] and [[org.apache
- *  .flink.ml.recommendation.ALS.InBlockInformation]] for the given rating matrix. The result of
- *  the individual steps are stored in the specified directory. By splitting the algorithm
- *  into multiple smaller steps, Flink does not have to split the available memory amongst too many
- *  operators. This allows the system to process bigger individual messasges and improves the
- *  overall performance.
- *
- * The ALS implementation is based on Spark's MLLib implementation of ALS [https://github.com/
- * apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala].
- */
+/** Alternating least squares algorithm to calculate a matrix factorization.
+  *
+  * Given a matrix `R`, ALS calculates two matrices `U` and `V` such that `R ~~ U^T V`. The
+  * unknown row dimension is given by the number of latent factors. Since matrix factorization
+  * is often used in the context of recommendation, we'll call the first matrix the user matrix
+  * and the second the item matrix. The `i`th column of the user matrix is `u_i` and the `i`th
+  * column of the item matrix is `v_i`. The matrix `R` is called the ratings matrix and
+  * `(R)_{i,j} = r_{i,j}`.
+  *
+  * In order to find the user and item matrix the following problem is solved:
+  *
+  * `argmin_{U,V} sum_(i,j\ with\ r_{i,j} != 0) (r_{i,j} - u_{i}^Tv_{j})^2 +
+  * \lambda (sum_(i) n_{u_i} ||u_i||^2 + sum_(j) n_{v_j} ||v_j||^2)`
+  *
+  * Overfitting is avoided by using a weighted-lambda-regularization scheme.
+  *
+  * By fixing one of the matrices `U` or `V` one obtains a quadratic form which can be solved. The
+  * solution of the modified problem is guaranteed to decrease the overall cost function. By
+  * applying this step alternately to the matrices `U` and `V`, we can iteratively improve the
+  * overall solution. Details can be found in the work of
+  * [[http://dx.doi.org/10.1007/978-3-540-68880-8_32 Zhou et al.]].
+  *
+  * The matrix `R` is given in its sparse representation as a tuple of `(i, j, r)` where `i` is the
+  * row index, `j` is the column index and `r` is the matrix value at position `(i,j)`.
+  *
+  * @example
+  *          {{{
+  *             val inputDS: DataSet[(Int, Int, Double)] = env.readCsvFile[(Int, Int, Double)](
+  *               pathToTrainingFile)
+  *
+  *             val als = ALS()
+  *               .setIterations(10)
+  *               .setNumFactors(10)
+  *
+  *             val model = als.fit(inputDS)
+  *
+  *             val data2Predict: DataSet[(Int, Int)] = env.readCsvFile[(Int, Int)](pathToData)
+  *
+  *             model.transform(data2Predict)
+  *          }}}
+  *
+  * =Parameters=
+  *
+  *  - [[ALS.NumFactors]]:
+  *  The number of latent factors. It is the dimension of the calculated user and item vectors.
+  *
+  *  - [[ALS.Lambda]]:
+  *  Regularization factor. Tune this value in order to avoid overfitting/generalization.
+  *
+  *  - [[ALS.Iterations]]: The number of iterations to perform.
+  *
+  *  - [[ALS.Blocks]]:
+  *  The number of blocks into which the user and item matrix are grouped. The fewer
+  *  blocks one uses, the less data is sent redundantly. However, bigger blocks entail bigger
+  *  update messages which have to be stored on the Heap. If the algorithm fails because of
+  *  an OutOfMemoryException, then try to increase the number of blocks.
+  *
+  *  - [[ALS.Seed]]:
+  *  Random seed used to generate the initial item matrix for the algorithm.
+  *
+  *  - [[ALS.TemporaryPath]]:
+  *  Path to a temporary directory into which intermediate results are stored. If this value is
+  *  set, then the algorithm is split into two preprocessing steps, the ALS iteration and a
+  *  post-processing step which calculates a last ALS half-step. The preprocessing steps calculate
+  *  the [[org.apache.flink.ml.recommendation.ALS.OutBlockInformation]] and
+  *  [[org.apache.flink.ml.recommendation.ALS.InBlockInformation]] for the given rating matrix.
+  *  The results of the individual steps are stored in the specified directory. By splitting the
+  *  algorithm into multiple smaller steps, Flink does not have to split the available memory
+  *  amongst too many operators. This allows the system to process bigger individual messages and
+  *  improves the overall performance.
+  *
+  * The ALS implementation is based on Spark's MLlib implementation of ALS, which can be found
+  * [[https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala here]].
+  */
 class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
 
   import ALS._
-  
+
+  /** Sets the number of latent factors/row dimension of the latent model
+    *
+    * @param numFactors Number of latent factors
+    * @return The ALS instance itself
+    */
   def setNumFactors(numFactors: Int): ALS = {
     parameters.add(NumFactors, numFactors)
     this
   }
 
+  /** Sets the regularization coefficient lambda
+    *
+    * @param lambda Regularization coefficient
+    * @return The ALS instance itself
+    */
   def setLambda(lambda: Double): ALS = {
     parameters.add(Lambda, lambda)
     this
   }
 
+  /** Sets the number of iterations of the ALS algorithm
+    * 
+    * @param iterations Number of ALS iterations
+    * @return The ALS instance itself
+    */
   def setIterations(iterations: Int): ALS = {
     parameters.add(Iterations, iterations)
     this
   }
 
+  /** Sets the number of blocks into which the user and item matrix shall be partitioned
+    * 
+    * @param blocks Number of blocks
+    * @return The ALS instance itself
+    */
   def setBlocks(blocks: Int): ALS = {
     parameters.add(Blocks, blocks)
     this
   }
 
+  /** Sets the random seed for the initial item matrix initialization
+    * 
+    * @param seed Random seed
+    * @return The ALS instance itself
+    */
   def setSeed(seed: Long): ALS = {
     parameters.add(Seed, seed)
     this
   }
-  
-  def setPersistencePath(persistencePath: String): ALS = {
-    parameters.add(PersistencePath, persistencePath)
+
+  /** Sets the temporary path into which intermediate results are written in order to increase
+    * performance.
+    * 
+    * @param temporaryPath Path to the temporary directory
+    * @return The ALS instance itself
+    */
+  def setTemporaryPath(temporaryPath: String): ALS = {
+    parameters.add(TemporaryPath, temporaryPath)
     this
   }
 
-  /**
-   * Calculates the matrix factorization for the given ratings. A rating is defined as
-   * a tuple of user ID, item ID and the corresponding rating.
-   *
-   * @param input Set of user/item ratings for which the factorization has to be calculated
-   * @return Factorization containing the user and item matrix
-   */
+  /** Calculates the matrix factorization for the given ratings. A rating is defined as
+    * a tuple of user ID, item ID and the corresponding rating.
+    *
+    * @param input Set of user/item ratings for which the factorization has to be calculated
+    * @return Factorization containing the user and item matrix
+    */
   def fit(input: DataSet[(Int, Int, Double)], fitParameters: ParameterMap): ALSModel = {
     val resultParameters = this.parameters ++ fitParameters
 
     val userBlocks = resultParameters.get(Blocks).getOrElse(input.count.toInt)
     val itemBlocks = userBlocks
-    val persistencePath = resultParameters.get(PersistencePath)
+    val persistencePath = resultParameters.get(TemporaryPath)
     val seed = resultParameters(Seed)
     val factors = resultParameters(NumFactors)
     val iterations = resultParameters(Iterations)
@@ -152,12 +218,12 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
       blockIDPartitioner)
 
     val (userIn, userOut) = persistencePath match {
-      case Some(path) => persist(uIn, uOut, path + "userIn", path + "userOut")
+      case Some(path) => FlinkTools.persist(uIn, uOut, path + "userIn", path + "userOut")
       case None => (uIn, uOut)
     }
 
     val (itemIn, itemOut) = persistencePath match {
-      case Some(path) => persist(iIn, iOut, path + "itemIn", path + "itemOut")
+      case Some(path) => FlinkTools.persist(iIn, iOut, path + "itemIn", path + "itemOut")
       case None => (iIn, iOut)
     }
 
@@ -183,7 +249,7 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     }
 
     val pItems = persistencePath match {
-      case Some(path) => persist(items, path + "items")
+      case Some(path) => FlinkTools.persist(items, path + "items")
       case None => items
     }
 
@@ -195,19 +261,18 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
       blockIDPartitioner), lambda)
   }
 
-  /**
-   * Calculates a single half step of the ALS optimization. The result is the new value for
-   * either the user or item matrix, depending with which matrix the method was called.
-   *
-   * @param numUserBlocks Number of blocks in the respective dimension
-   * @param items Fixed matrix value for the half step
-   * @param itemOut Out information to know where to send the vectors
-   * @param userIn In information for the cogroup step
-   * @param factors Number of latent factors
-   * @param lambda Regularization constant
-   * @param blockIDPartitioner Custom Flink partitioner
-   * @return New value for the optimized matrix (either user or item)
-   */
+  /** Calculates a single half step of the ALS optimization. The result is the new value for
+    * either the user or item matrix, depending on which matrix the method was called with.
+    *
+    * @param numUserBlocks Number of blocks in the respective dimension
+    * @param items Fixed matrix value for the half step
+    * @param itemOut Out information to know where to send the vectors
+    * @param userIn In information for the cogroup step
+    * @param factors Number of latent factors
+    * @param lambda Regularization constant
+    * @param blockIDPartitioner Custom Flink partitioner
+    * @return New value for the optimized matrix (either user or item)
+    */
   def updateFactors(numUserBlocks: Int,
                     items: DataSet[(Int, Array[Array[Double]])],
                     itemOut: DataSet[(Int, OutBlockInformation)],
@@ -352,16 +417,14 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0")
   }
 
-  /**
-   * Creates the meta information needed to route the item and user vectors to the respective user
-   * and item blocks.
-   *
-   * @param userBlocks
-   * @param itemBlocks
-   * @param ratings
-   * @param blockIDPartitioner
-   * @return
-   */
+  /** Creates the meta information needed to route the item and user vectors to the respective
+    * user and item blocks.
+    * @param userBlocks Number of user blocks
+    * @param itemBlocks Number of item blocks
+    * @param ratings Set of blocked ratings
+    * @param blockIDPartitioner Custom partitioner for the block IDs
+    * @return Tuple of the in-block and out-block information per block
+    */
   def createBlockInformation(userBlocks: Int, itemBlocks: Int, ratings: DataSet[(Int, Rating)],
                              blockIDPartitioner: BlockIDPartitioner):
   (DataSet[(Int, InBlockInformation)], DataSet[(Int, OutBlockInformation)]) = {
@@ -377,12 +440,11 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     (inBlockInfos, outBlockInfos)
   }
 
-  /**
-   * Calculates the userIDs in ascending order of each user block
-   *
-   * @param ratings
-   * @return
-   */
+  /** Collects for every user block the user IDs in ascending order
+    *
+    * @param ratings Set of blocked ratings
+    * @return DataSet of user block ID and sorted user ID array pairs
+    */
   def createUsersPerBlock(ratings: DataSet[(Int, Rating)]): DataSet[(Int, Array[Int])] = {
     ratings.map{ x => (x._1, x._2.user)}.withForwardedFields("0").groupBy(0).
       sortGroup(1, Order.ASCENDING).reduceGroup {
@@ -408,18 +470,20 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     }.withForwardedFields("0")
   }
 
-  /**
-   * Creates for every user block the out-going block information. The out block information
-   * contains for every item block a bitset which indicates which user vector has to be sent to
-   * this block. If a vector v has to be sent to a block b, then bitsets(b)'s bit v is
-   * set to 1, otherwise 0. Additionally the user IDataSet are replaced by the user vector's index value.
-   *
-   * @param ratings
-   * @param usersPerBlock
-   * @param itemBlocks
-   * @param blockIDGenerator
-   * @return
-   */
+  /** Creates the outgoing block information
+    *
+    * Creates for every user block the outgoing block information. The out block information
+    * contains for every item block a [[scala.collection.mutable.BitSet]] which indicates which
+    * user vector has to be sent to this block. If a vector v has to be sent to a block b, then
+    * bitsets(b)'s bit v is set to 1, otherwise 0. Additionally the user IDs are replaced by
+    * the user vector's index value.
+    *
+    * @param ratings Set of blocked ratings
+    * @param usersPerBlock User IDs per user block
+    * @param itemBlocks Number of item blocks
+    * @param blockIDGenerator Generator mapping IDs to block IDs
+    * @return DataSet of user block ID and OutBlockInformation pairs
+    */
   def createOutBlockInformation(ratings: DataSet[(Int, Rating)],
                                 usersPerBlock: DataSet[(Int, Array[Int])],
                                 itemBlocks: Int, blockIDGenerator: BlockIDGenerator):
@@ -454,20 +518,21 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0")
   }
 
-  /**
-   * Creates for every user block the incoming block information. The incoming block information
-   * contains the userIDs of the users in the respective block and for every item block a
-   * BlockRating instance. The BlockRating instance describes for every incoming set of item
-   * vectors of an item block, which user rated these items and what the rating was. For that
-   * purpose it contains for every incoming item vector a tuple of an id array us and a rating
-   * array rs. The array us contains the indices of the users having rated the respective
-   * item vector with the ratings in rs.
-   *
-   * @param ratings
-   * @param usersPerBlock
-   * @param blockIDGenerator
-   * @return
-   */
+  /** Creates the incoming block information
+    *
+    * Creates for every user block the incoming block information. The incoming block information
+    * contains the userIDs of the users in the respective block and for every item block a
+    * BlockRating instance. The BlockRating instance describes for every incoming set of item
+    * vectors of an item block, which user rated these items and what the rating was. For that
+    * purpose it contains for every incoming item vector a tuple of an id array us and a rating
+    * array rs. The array us contains the indices of the users having rated the respective
+    * item vector with the ratings in rs.
+    *
+    * @param ratings Set of blocked ratings
+    * @param usersPerBlock User IDs per user block
+    * @param blockIDGenerator Generator mapping IDs to block IDs
+    * @return DataSet of user block ID and InBlockInformation pairs
+    */
   def createInBlockInformation(ratings: DataSet[(Int, Rating)],
                                usersPerBlock: DataSet[(Int, Array[Int])],
                                blockIDGenerator: BlockIDGenerator):
@@ -549,7 +614,8 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
         (Int, Array[Int]), (Int, InBlockInformation)] {
         val buffer = ArrayBuffer[BlockRating]()
 
-        override def coGroup(partialInfosIterable: lang.Iterable[(Int, Int, Array[(Array[Int], Array[Double])])],
+        override def coGroup(partialInfosIterable:
+                             lang.Iterable[(Int, Int,  Array[(Array[Int], Array[Double])])],
                              userIterable: lang.Iterable[(Int, Array[Int])],
                              collector: Collector[(Int, InBlockInformation)]): Unit = {
 
@@ -603,15 +669,14 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
     }.withForwardedFieldsFirst("0").withForwardedFieldsSecond("0")
   }
 
-  /**
-   * Unblocks the blocked user and item matrix representation so that it is at DataSet of
-   * column vectors.
-   *
-   * @param users
-   * @param outInfo
-   * @param blockIDPartitioner
-   * @return
-   */
+  /** Unblocks the blocked user and item matrix representation so that it is a DataSet of
+    * column vectors.
+    *
+    * @param users Blocked matrix representation
+    * @param outInfo Out-block information used to restore the vector IDs
+    * @param blockIDPartitioner Custom partitioner for the block IDs
+    * @return DataSet of factor column vectors
+    */
   def unblock(users: DataSet[(Int, Array[Array[Double]])],
               outInfo: DataSet[(Int, OutBlockInformation)],
               blockIDPartitioner: BlockIDPartitioner): DataSet[Factors] = {
@@ -685,60 +750,6 @@ class ALS extends Learner[(Int, Int, Double), ALSModel] with Serializable {
   def randomFactors(factors: Int, random: Random): Array[Double] = {
     Array.fill(factors)(random.nextDouble())
   }
-
-  // ========================= Convenience functions to persist DataSets ===========================
-
-  def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = {
-    val env = dataset.getExecutionEnvironment
-    val outputFormat = new TypeSerializerOutputFormat[T]
-
-    val filePath = new Path(path)
-
-    outputFormat.setOutputFilePath(filePath)
-    outputFormat.setWriteMode(WriteMode.OVERWRITE)
-
-    dataset.output(outputFormat)
-    env.execute("FlinkTools persist")
-
-    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType)
-    inputFormat.setFilePath(filePath)
-
-    env.createInput(inputFormat)
-  }
-
-  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation](ds1: DataSet[A],
-                                                                          ds2: DataSet[B],
-                                                                          path1: String,
-                                                                          path2: String):
-  (DataSet[A], DataSet[B])  = {
-    val env = ds1.getExecutionEnvironment
-
-    val f1 = new Path(path1)
-
-    val of1 = new TypeSerializerOutputFormat[A]
-    of1.setOutputFilePath(f1)
-    of1.setWriteMode(WriteMode.OVERWRITE)
-
-    ds1.output(of1)
-
-    val f2 = new Path(path2)
-
-    val of2 = new TypeSerializerOutputFormat[B]
-    of2.setOutputFilePath(f2)
-    of2.setWriteMode(WriteMode.OVERWRITE)
-
-    ds2.output(of2)
-
-    env.execute("FlinkTools persist")
-
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
-    if1.setFilePath(f1)
-
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
-    if2.setFilePath(f2)
-
-    (env.createInput(if1), env.createInput(if2))
-  }
 }
 
 object ALS {
@@ -765,27 +776,25 @@ object ALS {
     val defaultValue: Option[Long] = Some(0L)
   }
 
-  case object PersistencePath extends Parameter[String] {
+  case object TemporaryPath extends Parameter[String] {
     val defaultValue: Option[String] = None
   }
 
   // ==================================== ALS type definitions =====================================
 
-  /**
-   * Representation of a user-item rating
-   *
-   * @param user User ID of the rating user
-   * @param item Item iD of the rated item
-   * @param rating Rating value
-   */
+  /** Representation of a user-item rating
+    *
+    * @param user User ID of the rating user
+    * @param item Item ID of the rated item
+    * @param rating Rating value
+    */
   case class Rating(user: Int, item: Int, rating: Double)
 
-  /**
-   * Representation of a factors vector of latent factor model
-   *
-   * @param id
-   * @param factors
-   */
+  /** Latent factor model vector
+    *
+    * @param id ID of the user or item to which the vector belongs
+    * @param factors The factor values of the vector
+    */
   case class Factors(id: Int, factors: Array[Double]) {
     override def toString = s"($id, ${factors.mkString(",")})"
   }
@@ -866,8 +875,24 @@ object ALS {
       id % blocks
     }
   }
+
+  // ========================= Factory methods =====================================
+
+  def apply(): ALS = {
+    new ALS()
+  }
 }
 
+/** Resulting model of the ALS algorithm.
+  *
+  * It contains the calculated factors, the user and the item matrix, of the given
+  * ratings matrix. Additionally it stores the regularization value lambda used to train
+  * the model in order to calculate its empirical risk.
+  *
+  * @param userFactors Calculated user matrix
+  * @param itemFactors Calculated item matrix
+  * @param lambda Regularization value used to calculate the model
+  */
 class ALSModel(@transient val userFactors: DataSet[Factors],@transient val itemFactors: DataSet[Factors],
                val lambda: Double) extends Transformer[(Int, Int), (Int, Int, Double)] with
 Serializable{
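
To make the half-step concrete (not part of the commit): fixing the item matrix turns the cost
function into one quadratic problem per user vector, solved via its normal equation with the
weighted-lambda-regularization term `lambda * n_{u_i}` on the diagonal. A sketch with jblas;
the method name and signature are illustrative.

    import org.jblas.{DoubleMatrix, Solve}

    // Solves (sum_j v_j v_j^T + lambda * n_u * I) u = sum_j r_j v_j for one
    // user, where j runs over all items the user has rated.
    def solveUserVector(
        itemVectors: Array[DoubleMatrix], // column vectors v_j of rated items
        ratings: Array[Double],           // the corresponding ratings r_j
        lambda: Double,
        factors: Int): DoubleMatrix = {
      val lhs = DoubleMatrix.zeros(factors, factors)
      val rhs = DoubleMatrix.zeros(factors, 1)

      for ((v, r) <- itemVectors.zip(ratings)) {
        lhs.addi(v.mmul(v.transpose())) // accumulate v_j * v_j^T
        rhs.addi(v.mul(r))              // accumulate r_j * v_j
      }

      // weighted-lambda-regularization: lambda scaled by the number of
      // ratings n_u of this user is added to the diagonal
      val n = ratings.length
      for (i <- 0 until factors) {
        lhs.put(i, i, lhs.get(i, i) + lambda * n)
      }

      Solve.solveSymmetric(lhs, rhs) // the new user vector u_i
    }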

http://git-wip-us.apache.org/repos/asf/flink/blob/86485434/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index b04bf3e..b1d9242 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -29,45 +29,61 @@ import org.apache.flink.api.scala._
 
 import org.jblas.{SimpleBlas, DoubleMatrix}
 
-/**
- * Multiple linear regression using the ordinary least squares (OLS) estimator.
- *
- * The linear regression finds a solution to the problem
- *
- *    y = w0 + w1*x1 + w2*x2 ... + wn*xn = w0 + w^T*x
- *
- * such that the sum of squared residuals is minimized
- *
- *    min_{w, w0} = \sum (y - w^T*x - w0)^2
- *
- * The minimization problem is solved by (stochastic) gradient descent. For each labeled vector
- * (x,y), the gradient is calculated. The weighted average of all gradients is subtracted from
- * the current value of b which gives the new value of b. The weight is defined as
- * Stepsize/math.sqrt(iteration).
- *
- * The optimization runs at most Iterations or, if a ConvergenceThreshold has been set, until the
- * convergence criterion has been met. As convergence criterion the relative change of the sum of 
- * squared residuals is used:
- * 
- * Convergence if: (S_{k-1} - S_k)/S_{k-1} < ConvergenceThreshold
- * 
- * with S_k being the sum of squared residuals in iteration k.
- *
- * At the moment, the whole partition is used for SGD, making it effectively a batch gradient
- * descent. Once a sampling operator has been introduced, the algorithm can be optimized.
- * 
- * Parameters:
- * 
- *  - Iterations: Maximum number of iterations
- *
- *  - Stepsize: Initial stepsize for the gradient descent method. This value decides how far the
- *      gradient descent method goes in the direction of the gradient. Tuning this parameter can
- *      lead to better practical results.
- *
- *  - ConvergenceThreshold: Threshold for relative change of sum of squared residuals until
- *      convergence
- *  
- */
+/** Multiple linear regression using the ordinary least squares (OLS) estimator.
+  *
+  * The linear regression finds a solution to the problem
+  *
+  * `y = w_0 + w_1*x_1 + w_2*x_2 ... + w_n*x_n = w_0 + w^T*x`
+  *
+  * such that the sum of squared residuals is minimized
+  *
+  * `min_{w, w_0} \sum (y - w^T*x - w_0)^2`
+  *
+  * The minimization problem is solved by (stochastic) gradient descent. For each labeled vector
+  * `(x,y)`, the gradient is calculated. The weighted average of all gradients is subtracted from
+  * the current value `w`, which gives the new value `w_new`. The weight is defined as
+  * `stepsize/math.sqrt(iteration)`.
+  *
+  * The optimization runs at most a maximum number of iterations or, if a convergence threshold has
+  * been set, until the convergence criterion has been met. As convergence criterion the relative
+  * change of the sum of squared residuals is used:
+  *
+  * `(S_{k-1} - S_k)/S_{k-1} < \rho`
+  *
+  * with `S_k` being the sum of squared residuals in iteration `k` and `\rho` being the convergence
+  * threshold.
+  *
+  * At the moment, the whole partition is used for SGD, making it effectively a batch gradient
+  * descent. Once a sampling operator has been introduced, the algorithm can be optimized.
+  *
+  * @example
+  *          {{{
+  *             val mlr = MultipleLinearRegression()
+  *               .setIterations(10)
+  *               .setStepsize(0.5)
+  *               .setConvergenceThreshold(0.001)
+  *
+  *             val trainingDS: DataSet[LabeledVector] = ...
+  *             val data: DataSet[Vector] = ...
+  *
+  *             val model = mlr.fit(trainingDS)
+  *
+  *             val predictions = model.transform(data)
+  *          }}}
+  *
+  * =Parameters=
+  *
+  *  - [[MultipleLinearRegression.Iterations]]: Maximum number of iterations.
+  *
+  *  - [[MultipleLinearRegression.Stepsize]]:
+  *  Initial stepsize for the gradient descent method. This value decides how far the gradient
+  *  descent method goes in the direction of the gradient. Tuning this parameter can lead to better
+  *  practical results.
+  *
+  *  - [[MultipleLinearRegression.ConvergenceThreshold]]:
+  *  Threshold for relative change of sum of squared residuals until convergence.
+  *
+  */
 class MultipleLinearRegression extends Learner[LabeledVector, MultipleLinearRegressionModel]
 with Serializable {
   import MultipleLinearRegression._
@@ -192,14 +208,13 @@ with Serializable {
     new MultipleLinearRegressionModel(resultingWeightVector)
   }
 
-  /**
-   * Creates a DataSet with one zero vector. The zero vector has dimension d, which is given
-   * by the dimensionDS.
-   *
-   * @param dimensionDS DataSet with one element d, denoting the dimension of the returned zero
-   *                    vector
-   * @return DataSet of a zero vector of dimension d
-   */
+  /** Creates a DataSet with one zero vector. The zero vector has dimension d, which is given
+    * by the dimensionDS.
+    *
+    * @param dimensionDS DataSet with one element d, denoting the dimension of the returned zero
+    *                    vector
+    * @return DataSet of a zero vector of dimension d
+    */
   private def createInitialWeightVector(dimensionDS: DataSet[Int]):
   DataSet[(DoubleMatrix, Double)] = {
     dimensionDS.map {
@@ -225,19 +240,24 @@ object MultipleLinearRegression {
   case object ConvergenceThreshold extends Parameter[Double] {
     val defaultValue = None
   }
+
+  // ====================== Factory methods ==========================
+
+  def apply(): MultipleLinearRegression = {
+    new MultipleLinearRegression()
+  }
 }
 
 //--------------------------------------------------------------------------------------------------
 //  Flink function definitions
 //--------------------------------------------------------------------------------------------------
 
-/**
- * Calculates for a labeled vector and the current weight vector its squared residual
- *
- *    (y - (w^Tx + w0))^2
- *
- * The weight vector is received as a broadcast variable.
- */
+/** Calculates for a labeled vector and the current weight vector its squared residual:
+  *
+  * `(y - (w^Tx + w_0))^2`
+  *
+  * The weight vector is received as a broadcast variable.
+  */
 private class SquaredResiduals extends RichMapFunction[LabeledVector, Double] {
   import MultipleLinearRegression.WEIGHTVECTOR_BROADCAST
 
@@ -265,15 +285,14 @@ private class SquaredResiduals extends RichMapFunction[LabeledVector, Double] {
   }
 }
 
-/**
- * Calculates for a labeled vector and the current weight vector the gradient minimizing the
- * OLS equation. The gradient is given by: 
- * 
- *    dw = 2*(w^T*x + w0 - y)*x
- *    dw0 = 2*(w^T*x + w0 - y)
- * 
- * The weight vector is received as a broadcast variable.
- */
+/** Calculates for a labeled vector and the current weight vector the gradient minimizing the
+  * OLS equation. The gradient is given by:
+  *
+  * `dw = 2*(w^T*x + w_0 - y)*x`
+  * `dw_0 = 2*(w^T*x + w_0 - y)`
+  *
+  * The weight vector is received as a broadcast variable.
+  */
 private class LinearRegressionGradientDescent extends
 RichMapFunction[LabeledVector, (DoubleMatrix, Double, Int)] {
 
@@ -306,13 +325,12 @@ RichMapFunction[LabeledVector, (DoubleMatrix, Double, Int)] {
   }
 }
 
-/**
- * Calculates the new weight vector based on the partial gradients. In order to do that,
- * all partial gradients are averaged and weighted by the current stepsize. This update value is
- * added to the current weight vector.
- * 
- * @param stepsize Initial value of the step size used to update the weight vector
- */
+/** Calculates the new weight vector based on the partial gradients. In order to do that,
+  * all partial gradients are averaged and weighted by the current stepsize. This update value is
+  * added to the current weight vector.
+  *
+  * @param stepsize Initial value of the step size used to update the weight vector
+  */
 private class LinearRegressionWeightsUpdate(val stepsize: Double) extends
 RichMapFunction[(DoubleMatrix, Double, Int), (DoubleMatrix, Double)] {
 
@@ -355,16 +373,15 @@ RichMapFunction[(DoubleMatrix, Double, Int), (DoubleMatrix, Double)] {
 //  Model definition
 //--------------------------------------------------------------------------------------------------
 
-/**
- * Multiple linear regression model returned by [[MultipleLinearRegression]]. The model stores the
- * calculated weight vector and applies the linear model to given vectors v:
- *
- *    \hat{y} = w^T*v + w0
- *
- * with \hat{y} being the predicted regression value.
- * 
- * @param weights DataSet containing the calculated weight vector
- */
+/** Multiple linear regression model returned by [[MultipleLinearRegression]]. The model stores the
+  * calculated weight vector and applies the linear model to given vectors v:
+  *
+  * `\hat{y} = w^T*v + w_0`
+  *
+  * with `\hat{y}` being the predicted regression value.
+  *
+  * @param weights DataSet containing the calculated weight vector
+  */
 class MultipleLinearRegressionModel private[regression]
 (val weights: DataSet[(DoubleMatrix, Double)]) extends
 Transformer[ Vector, LabeledVector ] {
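
To make the update rule above concrete (not part of the commit): one gradient descent step and
the convergence test, sketched on plain arrays with illustrative method names.

    // One step: subtract the averaged gradient, weighted by
    // stepsize / sqrt(iteration), from the current weight vector.
    def updateWeights(
        weights: Array[Double],
        averagedGradient: Array[Double],
        stepsize: Double,
        iteration: Int): Array[Double] = {
      val effectiveStepsize = stepsize / math.sqrt(iteration)
      weights.zip(averagedGradient).map { case (w, g) => w - effectiveStepsize * g }
    }

    // Convergence criterion: relative change of the sum of squared residuals,
    // (S_{k-1} - S_k) / S_{k-1} < threshold
    def hasConverged(previousSSR: Double, currentSSR: Double, threshold: Double): Boolean =
      (previousSSR - currentSSR) / previousSSR < threshold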

http://git-wip-us.apache.org/repos/asf/flink/blob/86485434/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSSuite.scala
index 0b17a86..770d4d2 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSSuite.scala
@@ -32,7 +32,7 @@ class ALSSuite extends FlatSpec with ShouldMatchers {
 
     val env = ExecutionEnvironment.getExecutionEnvironment
 
-    val als = new ALS()
+    val als = ALS()
     .setIterations(iterations)
     .setLambda(lambda)
     .setBlocks(4)

http://git-wip-us.apache.org/repos/asf/flink/blob/86485434/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
index d365b15..8d59b49 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
@@ -30,7 +30,7 @@ class MultipleLinearRegressionSuite extends FlatSpec with ShouldMatchers {
   it should "estimate the correct linear function" in {
     val env = ExecutionEnvironment.getExecutionEnvironment
 
-    val learner = new MultipleLinearRegression()
+    val learner = MultipleLinearRegression()
 
     import RegressionData._