You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2014/03/20 03:05:39 UTC

git commit: [SPARK-1273] MLlib bug fixes, improvements, and doc updates for v0.9.1

Repository: spark
Updated Branches:
  refs/heads/branch-0.9 a4eef655c -> 1cc979e0a


[SPARK-1273] MLlib bug fixes, improvements, and doc updates for v0.9.1

Cherry-picked a few MLlib commits that are bug fixes, optimizations, or doc updates for the v0.9.1 release.

JIRA: https://spark-project.atlassian.net/browse/SPARK-1273

Author: Xiangrui Meng <me...@databricks.com>
Author: Sean Owen <so...@cloudera.com>
Author: Andrew Tulloch <an...@tullo.ch>
Author: Chen Chao <cr...@gmail.com>

Closes #175 from mengxr/branch-0.9 and squashes the following commits:

d8928ea [Xiangrui Meng] add Apache header to LocalSparkContext
a66d386 [Xiangrui Meng] Merge remote-tracking branch 'apache/branch-0.9' into branch-0.9
a899894 [Xiangrui Meng] [SPARK-1237, 1238] Improve the computation of YtY for implicit ALS
46fe493 [Xiangrui Meng] [SPARK-1260]: faster construction of features with intercept
6340a18 [Sean Owen] MLLIB-22. Support negative implicit input in ALS
f27441a [Chen Chao] MLLIB-24:  url of "Collaborative Filtering for Implicit Feedback Datasets" in ALS is invalid now
a26ac90 [Sean Owen] Merge pull request #460 from srowen/RandomInitialALSVectors
0564985 [Andrew Tulloch] Fixed import order
2512e67 [Andrew Tulloch] LocalSparkContext for MLlib


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cc979e0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cc979e0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cc979e0

Branch: refs/heads/branch-0.9
Commit: 1cc979e0aad14a025c203a1e279c1ed4068d71de
Parents: a4eef65
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Mar 19 19:05:26 2014 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Wed Mar 19 19:05:26 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/recommendation/ALS.scala | 201 +++++++++++++------
 .../regression/GeneralizedLinearAlgorithm.scala |   8 +-
 .../mllib/recommendation/JavaALSSuite.java      |  32 ++-
 .../LogisticRegressionSuite.scala               |  15 +-
 .../mllib/classification/NaiveBayesSuite.scala  |  14 +-
 .../spark/mllib/classification/SVMSuite.scala   |  15 +-
 .../spark/mllib/clustering/KMeansSuite.scala    |  15 +-
 .../optimization/GradientDescentSuite.scala     |  13 +-
 .../spark/mllib/recommendation/ALSSuite.scala   |  56 ++++--
 .../spark/mllib/regression/LassoSuite.scala     |  16 +-
 .../regression/LinearRegressionSuite.scala      |  15 +-
 .../mllib/regression/RidgeRegressionSuite.scala |  14 +-
 .../spark/mllib/util/LocalSparkContext.scala    |  40 ++++
 13 files changed, 259 insertions(+), 195 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 3e93402..44db51c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.recommendation
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
+import scala.math.{abs, sqrt}
 import scala.util.Random
 import scala.util.Sorting
 
@@ -63,7 +64,7 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
  * Alternating Least Squares matrix factorization.
  *
  * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
- * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices.
+ * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
  * The general approach is iterative. During each iteration, one of the factor matrices is held
  * constant, while the other is solved for using least squares. The newly-solved factor matrix is
  * then held constant while solving for the other factor matrix.
@@ -80,17 +81,22 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
  *
  * For implicit preference data, the algorithm used is based on
  * "Collaborative Filtering for Implicit Feedback Datasets", available at
- * [[http://research.yahoo.com/pub/2433]], adapted for the blocked approach used here.
+ * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here.
  *
  * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
  * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if r > 0
  * and 0 if r = 0. The ratings then act as 'confidence' values related to strength of indicated user
  * preferences rather than explicit ratings given to items.
  */
-class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double,
-                   var implicitPrefs: Boolean, var alpha: Double)
-  extends Serializable with Logging
-{
+class ALS private (
+    var numBlocks: Int,
+    var rank: Int,
+    var iterations: Int,
+    var lambda: Double,
+    var implicitPrefs: Boolean,
+    var alpha: Double,
+    var seed: Long = System.nanoTime()
+  ) extends Serializable with Logging {
   def this() = this(-1, 10, 10, 0.01, false, 1.0)
 
   /**
@@ -130,6 +136,12 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
     this
   }
 
+  /** Sets a random seed to have deterministic results. */
+  def setSeed(seed: Long): ALS = {
+    this.seed = seed
+    this
+  }
+
   /**
    * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples.
    * Returns a MatrixFactorizationModel with feature vectors for each user and product.
@@ -151,9 +163,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
     val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock)
     val (productInLinks, productOutLinks) = makeLinkRDDs(numBlocks, ratingsByProductBlock)
 
-    // Initialize user and product factors randomly, but use a deterministic seed for each partition
-    // so that fault recovery works
-    val seedGen = new Random()
+    // Initialize user and product factors randomly, but use a deterministic seed for each
+    // partition so that fault recovery works
+    val seedGen = new Random(seed)
     val seed1 = seedGen.nextInt()
     val seed2 = seedGen.nextInt()
     // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable
@@ -208,22 +220,47 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
    */
   def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
     if (implicitPrefs) {
-      Option(
-        factors.flatMapValues { case factorArray =>
-          factorArray.view.map { vector =>
-            val x = new DoubleMatrix(vector)
-            x.mmul(x.transpose())
-          }
-        }.reduceByKeyLocally((a, b) => a.addi(b))
-         .values
-         .reduce((a, b) => a.addi(b))
-      )
+      val n = rank * (rank + 1) / 2
+      val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
+        Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
+        L
+      }, combOp = (L1, L2) => {
+        L1.addi(L2)
+      })
+      val YtY = new DoubleMatrix(rank, rank)
+      fillFullMatrix(LYtY, YtY)
+      Option(YtY)
     } else {
       None
     }
   }
 
   /**
+   * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
+   *
+   * @param L the lower triangular part of the matrix packed in an array (row major)
+   */
+  private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = {
+    val n = x.length
+    var i = 0
+    var j = 0
+    var idx = 0
+    var axi = 0.0
+    val xd = x.data
+    val Ld = L.data
+    while (i < n) {
+      axi = alpha * xd(i)
+      j = 0
+      while (j <= i) {
+        Ld(idx) += axi * xd(j)
+        j += 1
+        idx += 1
+      }
+      i += 1
+    }
+  }
+
+  /**
    * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
    */
   def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
@@ -301,7 +338,14 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
    * Make a random factor vector with the given random.
    */
   private def randomFactor(rank: Int, rand: Random): Array[Double] = {
-    Array.fill(rank)(rand.nextDouble)
+    // Choose a unit vector uniformly at random from the unit sphere, but from the
+    // "first quadrant" where all elements are nonnegative. This can be done by choosing
+    // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+    // This appears to create factorizations that have a slightly better reconstruction
+    // (<1%) compared to picking elements uniformly at random in [0,1].
+    val factor = Array.fill(rank)(abs(rand.nextGaussian()))
+    val norm = sqrt(factor.map(x => x * x).sum)
+    factor.map(x => x / norm)
   }
 
   /**
@@ -365,7 +409,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
     for (productBlock <- 0 until numBlocks) {
       for (p <- 0 until blockFactors(productBlock).length) {
         val x = new DoubleMatrix(blockFactors(productBlock)(p))
-        fillXtX(x, tempXtX)
+        tempXtX.fill(0.0)
+        dspr(1.0, x, tempXtX)
         val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
         for (i <- 0 until us.length) {
           implicitPrefs match {
@@ -373,43 +418,32 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
               userXtX(us(i)).addi(tempXtX)
               SimpleBlas.axpy(rs(i), x, userXy(us(i)))
             case true =>
-              userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i)))
-              SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i)))
+              // Extension to the original paper to handle rs(i) < 0. confidence is a function
+              // of |rs(i)| instead so that it is never negative:
+              val confidence = 1 + alpha * abs(rs(i))
+              SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
+              // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
+              // means we try to reconstruct 0. We add terms only where P = 1, so the term below
+              // is now only added for rs(i) > 0:
+              if (rs(i) > 0) {
+                SimpleBlas.axpy(confidence, x, userXy(us(i)))
+              }
           }
         }
       }
     }
 
     // Solve the least-squares problem for each user and return the new feature vectors
-    userXtX.zipWithIndex.map{ case (triangularXtX, index) =>
+    Array.range(0, numUsers).map { index =>
       // Compute the full XtX matrix from the lower-triangular part we got above
-      fillFullMatrix(triangularXtX, fullXtX)
+      fillFullMatrix(userXtX(index), fullXtX)
       // Add regularization
       (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
       // Solve the resulting matrix, which is symmetric and positive-definite
       implicitPrefs match {
         case false => Solve.solvePositive(fullXtX, userXy(index)).data
-        case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
-      }
-    }
-  }
-
-  /**
-   * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing
-   * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values
-   * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order.
-   */
-  private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) {
-    var i = 0
-    var pos = 0
-    while (i < x.length) {
-      var j = 0
-      while (j <= i) {
-        xtxDest.data(pos) = x.data(i) * x.data(j)
-        pos += 1
-        j += 1
+        case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), userXy(index)).data
       }
-      i += 1
     }
   }
 
@@ -436,9 +470,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
 
 
 /**
- * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton.
+ * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
  */
 object ALS {
+
   /**
    * Train a matrix factorization model given an RDD of ratings given by users to some products,
    * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
@@ -451,15 +486,39 @@ object ALS {
    * @param iterations number of iterations of ALS (recommended: 10-20)
    * @param lambda     regularization factor (recommended: 0.01)
    * @param blocks     level of parallelism to split computation into
+   * @param seed       random seed
    */
   def train(
       ratings: RDD[Rating],
       rank: Int,
       iterations: Int,
       lambda: Double,
-      blocks: Int)
-    : MatrixFactorizationModel =
-  {
+      blocks: Int,
+      seed: Long
+    ): MatrixFactorizationModel = {
+    new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
+  }
+
+  /**
+   * Train a matrix factorization model given an RDD of ratings given by users to some products,
+   * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the
+   * product of two lower-rank matrices of a given rank (number of features). To solve for these
+   * features, we run a given number of iterations of ALS. This is done using a level of
+   * parallelism given by `blocks`.
+   *
+   * @param ratings    RDD of (userID, productID, rating) pairs
+   * @param rank       number of features to use
+   * @param iterations number of iterations of ALS (recommended: 10-20)
+   * @param lambda     regularization factor (recommended: 0.01)
+   * @param blocks     level of parallelism to split computation into
+   */
+  def train(
+      ratings: RDD[Rating],
+      rank: Int,
+      iterations: Int,
+      lambda: Double,
+      blocks: Int
+    ): MatrixFactorizationModel = {
     new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
   }
 
@@ -476,8 +535,7 @@ object ALS {
    * @param lambda     regularization factor (recommended: 0.01)
    */
   def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double)
-    : MatrixFactorizationModel =
-  {
+    : MatrixFactorizationModel = {
     train(ratings, rank, iterations, lambda, -1)
   }
 
@@ -493,8 +551,7 @@ object ALS {
    * @param iterations number of iterations of ALS (recommended: 10-20)
    */
   def train(ratings: RDD[Rating], rank: Int, iterations: Int)
-    : MatrixFactorizationModel =
-  {
+    : MatrixFactorizationModel = {
     train(ratings, rank, iterations, 0.01, -1)
   }
 
@@ -511,6 +568,7 @@ object ALS {
    * @param lambda     regularization factor (recommended: 0.01)
    * @param blocks     level of parallelism to split computation into
    * @param alpha      confidence parameter (only applies when immplicitPrefs = true)
+   * @param seed       random seed
    */
   def trainImplicit(
       ratings: RDD[Rating],
@@ -518,9 +576,34 @@ object ALS {
       iterations: Int,
       lambda: Double,
       blocks: Int,
-      alpha: Double)
-  : MatrixFactorizationModel =
-  {
+      alpha: Double,
+      seed: Long
+    ): MatrixFactorizationModel = {
+    new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
+  }
+
+  /**
+   * Train a matrix factorization model given an RDD of 'implicit preferences' given by users
+   * to some products, in the form of (userID, productID, preference) pairs. We approximate the
+   * ratings matrix as the product of two lower-rank matrices of a given rank (number of features).
+   * To solve for these features, we run a given number of iterations of ALS. This is done using
+   * a level of parallelism given by `blocks`.
+   *
+   * @param ratings    RDD of (userID, productID, rating) pairs
+   * @param rank       number of features to use
+   * @param iterations number of iterations of ALS (recommended: 10-20)
+   * @param lambda     regularization factor (recommended: 0.01)
+   * @param blocks     level of parallelism to split computation into
+   * @param alpha      confidence parameter (only applies when implicitPrefs = true)
+   */
+  def trainImplicit(
+      ratings: RDD[Rating],
+      rank: Int,
+      iterations: Int,
+      lambda: Double,
+      blocks: Int,
+      alpha: Double
+    ): MatrixFactorizationModel = {
     new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
   }
 
@@ -537,8 +620,7 @@ object ALS {
    * @param lambda     regularization factor (recommended: 0.01)
    */
   def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
-  : MatrixFactorizationModel =
-  {
+    : MatrixFactorizationModel = {
     trainImplicit(ratings, rank, iterations, lambda, -1, alpha)
   }
 
@@ -555,8 +637,7 @@ object ALS {
    * @param iterations number of iterations of ALS (recommended: 10-20)
    */
   def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int)
-  : MatrixFactorizationModel =
-  {
+    : MatrixFactorizationModel = {
     trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index f98b0b5..b962153 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    */
   def run(input: RDD[LabeledPoint]) : M = {
     val nfeatures: Int = input.first().features.length
-    val initialWeights = Array.fill(nfeatures)(1.0)
+    val initialWeights = new Array[Double](nfeatures)
     run(input, initialWeights)
   }
 
@@ -134,15 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
       throw new SparkException("Input validation failed.")
     }
 
-    // Add a extra variable consisting of all 1.0's for the intercept.
+    // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
+      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
     } else {
       input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
     }
 
     val initialWeightsWithIntercept = if (addIntercept) {
-      Array(1.0, initialWeights:_*)
+      initialWeights.+:(1.0)
     } else {
       initialWeights
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index b40f552..b150334 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.recommendation;
 
 import java.io.Serializable;
 import java.util.List;
-import java.lang.Math;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -46,7 +45,7 @@ public class JavaALSSuite implements Serializable {
     System.clearProperty("spark.driver.port");
   }
 
-  void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, 
+  static void validatePrediction(MatrixFactorizationModel model, int users, int products, int features,
       DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) {
     DoubleMatrix predictedU = new DoubleMatrix(users, features);
     List<scala.Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
@@ -84,15 +83,15 @@ public class JavaALSSuite implements Serializable {
         for (int p = 0; p < products; ++p) {
           double prediction = predictedRatings.get(u, p);
           double truePref = truePrefs.get(u, p);
-          double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p);
+          double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
           double err = confidence * (truePref - prediction) * (truePref - prediction);
           sqErr += err;
-          denom += 1.0;
+          denom += confidence;
         }
       }
       double rmse = Math.sqrt(sqErr / denom);
       Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
-              rmse, matchThreshold), Math.abs(rmse) < matchThreshold);
+              rmse, matchThreshold), rmse < matchThreshold);
     }
   }
 
@@ -103,7 +102,7 @@ public class JavaALSSuite implements Serializable {
     int users = 50;
     int products = 100;
     scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
-        users, products, features, 0.7, false);
+        users, products, features, 0.7, false, false);
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
@@ -117,7 +116,7 @@ public class JavaALSSuite implements Serializable {
     int users = 100;
     int products = 200;
     scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
-        users, products, features, 0.7, false);
+        users, products, features, 0.7, false, false);
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
 
@@ -134,7 +133,7 @@ public class JavaALSSuite implements Serializable {
     int users = 80;
     int products = 160;
     scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
-      users, products, features, 0.7, true);
+        users, products, features, 0.7, true, false);
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
@@ -148,7 +147,7 @@ public class JavaALSSuite implements Serializable {
     int users = 100;
     int products = 200;
     scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
-      users, products, features, 0.7, true);
+        users, products, features, 0.7, true, false);
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
 
@@ -158,4 +157,19 @@ public class JavaALSSuite implements Serializable {
       .run(data.rdd());
     validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
   }
+
+  @Test
+  public void runImplicitALSWithNegativeWeight() {
+    int features = 2;
+    int iterations = 15;
+    int users = 80;
+    int products = 160;
+    scala.Tuple3<List<Rating>, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList(
+        users, products, features, 0.7, true, true);
+
+    JavaRDD<Rating> data = sc.parallelize(testData._1());
+    MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
+    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 02ede71..05322b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object LogisticRegressionSuite {
 
@@ -66,19 +67,7 @@ object LogisticRegressionSuite {
 
 }
 
-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
-
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
       prediction != expected.label

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index b615f76..9dd6c79 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object NaiveBayesSuite {
 
@@ -59,17 +59,7 @@ object NaiveBayesSuite {
   }
 }
 
-class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class NaiveBayesSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 3357b86..bc7abb5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -25,8 +25,9 @@ import org.scalatest.FunSuite
 
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.{SparkException, SparkContext}
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object SVMSuite {
 
@@ -58,17 +59,7 @@ object SVMSuite {
 
 }
 
-class SVMSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class SVMSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 73657ca..4ef1d1f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,20 +21,9 @@ package org.apache.spark.mllib.clustering
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.util.LocalSparkContext
 
-
-class KMeansSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class KMeansSuite extends FunSuite with LocalSparkContext {
 
   val EPSILON = 1e-4
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index a6028a1..a453de6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.ShouldMatchers
 
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.LocalSparkContext
 
 object GradientDescentSuite {
 
@@ -62,17 +63,7 @@ object GradientDescentSuite {
   }
 }
 
-class GradientDescentSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
 
   test("Assert the loss is decreasing.") {
     val nPoints = 10000

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 4e8dbde..5aab9ab 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -18,14 +18,15 @@
 package org.apache.spark.mllib.recommendation
 
 import scala.collection.JavaConversions._
+import scala.math.abs
 import scala.util.Random
 
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
+import org.jblas.DoubleMatrix
 
-import org.jblas._
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.SparkContext._
 
 object ALSSuite {
 
@@ -34,7 +35,8 @@ object ALSSuite {
       products: Int,
       features: Int,
       samplingRate: Double,
-      implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
+      implicitPrefs: Boolean,
+      negativeWeights: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = {
     val (sampledRatings, trueRatings, truePrefs) =
-      generateRatings(users, products, features, samplingRate, implicitPrefs)
+      generateRatings(users, products, features, samplingRate, implicitPrefs, negativeWeights)
     (seqAsJavaList(sampledRatings), trueRatings, truePrefs)
@@ -45,7 +47,8 @@ object ALSSuite {
       products: Int,
       features: Int,
       samplingRate: Double,
-      implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
+      implicitPrefs: Boolean = false,
+      negativeWeights: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = {
     val rand = new Random(42)
 
     // Create a random matrix with uniform values from -1 to 1
@@ -56,7 +59,9 @@ object ALSSuite {
     val productMatrix = randomMatrix(features, products)
     val (trueRatings, truePrefs) = implicitPrefs match {
       case true =>
-        val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*)
+        // Generate raw values from [0,9], or if negativeWeights, from [-2,7]
+        val raw = new DoubleMatrix(users, products,
+          Array.fill(users * products)((if (negativeWeights) -2 else 0) + rand.nextInt(10).toDouble): _*)
         val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*)
         (raw, prefs)
       case false => (userMatrix.mmul(productMatrix), null)
@@ -73,17 +78,7 @@ object ALSSuite {
 }
 
 
-class ALSSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class ALSSuite extends FunSuite with LocalSparkContext {
 
   test("rank-1 matrices") {
     testALS(50, 100, 1, 15, 0.7, 0.3)
@@ -117,6 +112,22 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
     testALS(100, 200, 2, 15, 0.7, 0.4, true, true)
   }
 
+  test("rank-2 matrices implicit negative") {
+    testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true)
+  }
+
+  test("pseudorandomness") {
+    val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
+    val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
+    val model12 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
+    val u11 = model11.userFeatures.values.flatMap(_.toList).collect().toList
+    val u12 = model12.userFeatures.values.flatMap(_.toList).collect().toList
+    val model2 = ALS.train(ratings, 5, 1, 1.0, 2, 2)
+    val u2 = model2.userFeatures.values.flatMap(_.toList).collect().toList
+    assert(u11 == u12)
+    assert(u11 != u2)
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *
@@ -128,13 +139,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
    * @param matchThreshold max difference allowed to consider a predicted rating correct
    * @param implicitPrefs  flag to test implicit feedback
    * @param bulkPredict    flag to test bulk prediciton
+   * @param negativeWeights whether the generated data can contain negative values
    */
   def testALS(users: Int, products: Int, features: Int, iterations: Int,
     samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
-    bulkPredict: Boolean = false)
+    bulkPredict: Boolean = false, negativeWeights: Boolean = false)
   {
     val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
-      features, samplingRate, implicitPrefs)
+      features, samplingRate, implicitPrefs, negativeWeights)
     val model = implicitPrefs match {
       case false => ALS.train(sc.parallelize(sampledRatings), features, iterations)
       case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations)
@@ -176,13 +188,13 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
       for (u <- 0 until users; p <- 0 until products) {
         val prediction = predictedRatings.get(u, p)
         val truePref = truePrefs.get(u, p)
-        val confidence = 1 + 1.0 * trueRatings.get(u, p)
+        val confidence = 1 + 1.0 * abs(trueRatings.get(u, p))
         val err = confidence * (truePref - prediction) * (truePref - prediction)
         sqErr += err
-        denom += 1
+        denom += confidence
       }
       val rmse = math.sqrt(sqErr / denom)
-      if (math.abs(rmse) > matchThreshold) {
+      if (rmse > matchThreshold) {
         fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
           rmse, truePrefs, predictedRatings, predictedU, predictedP))
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index b2c8df9..64e4cbb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -22,21 +22,9 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-
-class LassoSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LassoSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 406afba..281f9df 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -20,20 +20,9 @@ package org.apache.spark.mllib.regression
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-class LinearRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
-
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class LinearRegressionSuite extends FunSuite with LocalSparkContext {
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 1d6a10b..67dd06c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -22,20 +22,10 @@ import org.jblas.DoubleMatrix
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
-import org.apache.spark.SparkContext
-import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 
-class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
-  @transient private var sc: SparkContext = _
 
-  override def beforeAll() {
-    sc = new SparkContext("local", "test")
-  }
-
-  override def afterAll() {
-    sc.stop()
-    System.clearProperty("spark.driver.port")
-  }
+class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
 
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
     predictions.zip(input).map { case (prediction, expected) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc979e0/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
new file mode 100644
index 0000000..212fbe9
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext
+
+/**
+ * Mixes a local-mode [[SparkContext]] into a ScalaTest suite.
+ *
+ * One context (master "local", app name "test") is created before the suite's tests run and
+ * is stopped once they have all finished, so every test in the suite shares the same `sc`.
+ */
+trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
+  /** Shared context; `null` until beforeAll() assigns it, reset to `null` once afterAll() stops it. */
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    sc = new SparkContext("local", "test")
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    try {
+      if (sc != null) {
+        sc.stop()
+      }
+      sc = null
+      // NOTE(review): presumably cleared so a later context in the same JVM picks a fresh port.
+      System.clearProperty("spark.driver.port")
+    } finally {
+      // Run stacked BeforeAndAfterAll cleanup even if stopping the context throws.
+      super.afterAll()
+    }
+  }
+}