You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/04/22 01:44:55 UTC

spark git commit: [SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark

Repository: spark
Updated Branches:
  refs/heads/master 7fe6142cd -> 686dd742e


[SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark

SchemaRDD works with ALS.train in 1.2, so we should continue to support DataFrames for compatibility. coderxiang

Author: Xiangrui Meng <me...@databricks.com>

Closes #5619 from mengxr/SPARK-7036 and squashes the following commits:

dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/686dd742
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/686dd742
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/686dd742

Branch: refs/heads/master
Commit: 686dd742e11f6ad0078b7ff9b30b83a118fd8109
Parents: 7fe6142
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Apr 21 16:44:52 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Apr 21 16:44:52 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++++--------
 1 file changed, 26 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/686dd742/python/pyspark/mllib/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 80e0a35..4b7d17d 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -22,6 +22,7 @@ from pyspark import SparkContext
 from pyspark.rdd import RDD
 from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
 from pyspark.mllib.util import JavaLoader, JavaSaveable
+from pyspark.sql import DataFrame
 
 __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
 
@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
     True
 
     >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
-    >>> model.predict(2,2)
+    >>> model.predict(2, 2)
+    3.8...
+
+    >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
+    >>> model = ALS.train(df, 1, nonnegative=True, seed=10)
+    >>> model.predict(2, 2)
     3.8...
 
     >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
-    >>> model.predict(2,2)
+    >>> model.predict(2, 2)
     0.4...
 
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> model.save(sc, path)
     >>> sameModel = MatrixFactorizationModel.load(sc, path)
-    >>> sameModel.predict(2,2)
+    >>> sameModel.predict(2, 2)
     0.4...
     >>> sameModel.predictAll(testset).collect()
     [Rating(...
@@ -125,13 +131,20 @@ class ALS(object):
 
     @classmethod
     def _prepare(cls, ratings):
-        assert isinstance(ratings, RDD), "ratings should be RDD"
+        if isinstance(ratings, RDD):
+            pass
+        elif isinstance(ratings, DataFrame):
+            ratings = ratings.rdd
+        else:
+            raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
+                            "but got %s." % type(ratings))
         first = ratings.first()
-        if not isinstance(first, Rating):
-            if isinstance(first, (tuple, list)):
-                ratings = ratings.map(lambda x: Rating(*x))
-            else:
-                raise ValueError("rating should be RDD of Rating or tuple/list")
+        if isinstance(first, Rating):
+            pass
+        elif isinstance(first, (tuple, list)):
+            ratings = ratings.map(lambda x: Rating(*x))
+        else:
+            raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
         return ratings
 
     @classmethod
@@ -152,8 +165,11 @@ class ALS(object):
 def _test():
     import doctest
     import pyspark.mllib.recommendation
+    from pyspark.sql import SQLContext
     globs = pyspark.mllib.recommendation.__dict__.copy()
-    globs['sc'] = SparkContext('local[4]', 'PythonTest')
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['sqlContext'] = SQLContext(sc)
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org