You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/13 20:42:33 UTC

spark git commit: [SPARK-4378][MLLIB] make ALS more Java-friendly

Repository: spark
Updated Branches:
  refs/heads/master ce0333f9a -> ca26a212f


[SPARK-4378][MLLIB] make ALS more Java-friendly

Add Java-friendly versions of `run` and `predict`, and use bulk prediction in Java unit tests. The user guide update will come later (though we may not save many lines of code there). srowen

Author: Xiangrui Meng <me...@databricks.com>

Closes #3240 from mengxr/SPARK-4378 and squashes the following commits:

6581503 [Xiangrui Meng] check number of predictions
6c8bbd1 [Xiangrui Meng] make ALS more Java-friendly


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca26a212
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca26a212
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca26a212

Branch: refs/heads/master
Commit: ca26a212fda39a15fde09dfdb2fbe69580a717f6
Parents: ce0333f
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu Nov 13 11:42:27 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Nov 13 11:42:27 2014 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/recommendation/ALS.scala | 17 +++--
 .../MatrixFactorizationModel.scala              | 15 ++--
 .../mllib/recommendation/JavaALSSuite.java      | 74 ++++++++------------
 3 files changed, 53 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca26a212/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 84d192d..038edc3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.{abs, sqrt}
-import scala.util.Random
-import scala.util.Sorting
+import scala.util.{Random, Sorting}
 import scala.util.hashing.byteswap32
 
 import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
 
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.optimization.NNLS
 
 /**
  * Out-link information for a user or product block. This includes the original user/product IDs
@@ -326,6 +326,11 @@ class ALS private (
   }
 
   /**
+   * Java-friendly version of [[ALS.run]].
+   */
+  def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)
+
+  /**
    * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
    * for each user (or product), in a distributed fashion.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/ca26a212/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 66b58ba..969e23b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.mllib.recommendation
 
+import java.lang.{Integer => JavaInteger}
+
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.api.python.SerDe
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.rdd.RDD
 
 /**
  * Model representing the result of matrix factorization.
@@ -66,6 +66,13 @@ class MatrixFactorizationModel private[mllib] (
   }
 
   /**
+   * Java-friendly version of [[MatrixFactorizationModel.predict]].
+   */
+  def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
+    predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
+  }
+
+  /**
    * Recommends products to a user.
    *
    * @param user the user to recommend products to

http://git-wip-us.apache.org/repos/asf/spark/blob/ca26a212/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index f6ca964..af688c5 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -23,13 +23,14 @@ import java.util.List;
 import scala.Tuple2;
 import scala.Tuple3;
 
+import com.google.common.collect.Lists;
 import org.jblas.DoubleMatrix;
-
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 
@@ -47,61 +48,48 @@ public class JavaALSSuite implements Serializable {
     sc = null;
   }
 
-  static void validatePrediction(
+  void validatePrediction(
       MatrixFactorizationModel model,
       int users,
       int products,
-      int features,
       DoubleMatrix trueRatings,
       double matchThreshold,
       boolean implicitPrefs,
       DoubleMatrix truePrefs) {
-    DoubleMatrix predictedU = new DoubleMatrix(users, features);
-    List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect();
-    for (int i = 0; i < features; ++i) {
-      for (Tuple2<Object, double[]> userFeature : userFeatures) {
-        predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]);
-      }
-    }
-    DoubleMatrix predictedP = new DoubleMatrix(products, features);
-
-    List<Tuple2<Object, double[]>> productFeatures =
-      model.productFeatures().toJavaRDD().collect();
-    for (int i = 0; i < features; ++i) {
-      for (Tuple2<Object, double[]> productFeature : productFeatures) {
-        predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]);
+    List<Tuple2<Integer, Integer>> localUsersProducts =
+      Lists.newArrayListWithCapacity(users * products);
+    for (int u=0; u < users; ++u) {
+      for (int p=0; p < products; ++p) {
+        localUsersProducts.add(new Tuple2<Integer, Integer>(u, p));
       }
     }
-
-    DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose());
-
+    JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts);
+    List<Rating> predictedRatings = model.predict(usersProducts).collect();
+    Assert.assertEquals(users * products, predictedRatings.size());
     if (!implicitPrefs) {
-      for (int u = 0; u < users; ++u) {
-        for (int p = 0; p < products; ++p) {
-          double prediction = predictedRatings.get(u, p);
-          double correct = trueRatings.get(u, p);
-          Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
-                  prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
-        }
+      for (Rating r: predictedRatings) {
+        double prediction = r.rating();
+        double correct = trueRatings.get(r.user(), r.product());
+        Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f",
+          prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold);
       }
     } else {
       // For implicit prefs we use the confidence-weighted RMSE to test
       // (ref Mahout's implicit ALS tests)
       double sqErr = 0.0;
       double denom = 0.0;
-      for (int u = 0; u < users; ++u) {
-        for (int p = 0; p < products; ++p) {
-          double prediction = predictedRatings.get(u, p);
-          double truePref = truePrefs.get(u, p);
-          double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p));
-          double err = confidence * (truePref - prediction) * (truePref - prediction);
-          sqErr += err;
-          denom += confidence;
-        }
+      for (Rating r: predictedRatings) {
+        double prediction = r.rating();
+        double truePref = truePrefs.get(r.user(), r.product());
+        double confidence = 1.0 +
+          /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product()));
+        double err = confidence * (truePref - prediction) * (truePref - prediction);
+        sqErr += err;
+        denom += confidence;
       }
       double rmse = Math.sqrt(sqErr / denom);
       Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f",
-              rmse, matchThreshold), rmse < matchThreshold);
+        rmse, matchThreshold), rmse < matchThreshold);
     }
   }
 
@@ -116,7 +104,7 @@ public class JavaALSSuite implements Serializable {
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations);
-    validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
   }
 
   @Test
@@ -132,8 +120,8 @@ public class JavaALSSuite implements Serializable {
 
     MatrixFactorizationModel model = new ALS().setRank(features)
       .setIterations(iterations)
-      .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3());
+      .run(data);
+    validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3());
   }
 
   @Test
@@ -147,7 +135,7 @@ public class JavaALSSuite implements Serializable {
 
     JavaRDD<Rating> data = sc.parallelize(testData._1());
     MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations);
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test
@@ -165,7 +153,7 @@ public class JavaALSSuite implements Serializable {
       .setIterations(iterations)
       .setImplicitPrefs(true)
       .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test
@@ -183,7 +171,7 @@ public class JavaALSSuite implements Serializable {
       .setImplicitPrefs(true)
       .setSeed(8675309L)
       .run(data.rdd());
-    validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3());
+    validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3());
   }
 
   @Test


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org