You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2014/01/08 01:57:25 UTC

[7/9] git commit: Added predictAll Python function to MatrixFactorizationModel

Added predictAll Python function to MatrixFactorizationModel


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/754f5300
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/754f5300
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/754f5300

Branch: refs/heads/master
Commit: 754f5300a1e0a214b62cbd6db2398dea4dfbceb4
Parents: 04132ea
Author: Hossein Falaki <fa...@gmail.com>
Authored: Mon Jan 6 12:19:43 2014 -0800
Committer: Hossein Falaki <fa...@gmail.com>
Committed: Mon Jan 6 12:19:43 2014 -0800

----------------------------------------------------------------------
 python/pyspark/mllib/recommendation.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/754f5300/python/pyspark/mllib/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index c81b482..0eeb5bb 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -21,8 +21,7 @@ from pyspark.mllib._common import \
     _serialize_double_matrix, _deserialize_double_matrix, \
     _serialize_double_vector, _deserialize_double_vector, \
     _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
-    _serialize_tuple, _deserialize_rating
-from pyspark.serializers import BatchedSerializer
+    _serialize_tuple, RatingDeserializer
 from pyspark.rdd import RDD
 
 class MatrixFactorizationModel(object):
@@ -36,6 +35,9 @@ class MatrixFactorizationModel(object):
     >>> model = ALS.trainImplicit(sc, ratings, 1)
     >>> model.predict(2,2) is not None
     True
+    >>> testset = sc.parallelize([(1, 2), (1, 1)])
+    >>> model.predictAll(testset).count() == 2
+    True
     """
 
     def __init__(self, sc, java_model):
@@ -50,8 +52,8 @@ class MatrixFactorizationModel(object):
 
     def predictAll(self, usersProducts):
         usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
-        return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
-                   self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
+        return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
+                   self._context, RatingDeserializer())
 
 class ALS(object):
     @classmethod