Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/24 03:53:13 UTC

spark git commit: [SPARK-9122] [MLLIB] [PySpark] spark.mllib regression support batch predict

Repository: spark
Updated Branches:
  refs/heads/master 8a94eb23d -> 52de3acca


[SPARK-9122] [MLLIB] [PySpark] spark.mllib regression support batch predict

spark.mllib now supports batch prediction for LinearRegressionModel, RidgeRegressionModel and LassoModel: predict() accepts an RDD of vectors in addition to a single vector.

Author: Yanbo Liang <yb...@gmail.com>

Closes #7614 from yanboliang/spark-9122 and squashes the following commits:

4e610c0 [Yanbo Liang] spark.mllib regression support batch predict


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52de3acc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52de3acc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52de3acc

Branch: refs/heads/master
Commit: 52de3acca4ce8c36fd4c9ce162473a091701bbc7
Parents: 8a94eb2
Author: Yanbo Liang <yb...@gmail.com>
Authored: Thu Jul 23 18:53:07 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 23 18:53:07 2015 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/regression.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52de3acc/python/pyspark/mllib/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 8e90ade..5b7afc1 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel):
 
     def predict(self, x):
         """
-        Predict the value of the dependent variable given a vector x
-        containing values for the independent variables.
+        Predict the value of the dependent variable given a vector or
+        an RDD of vectors containing values for the independent variables.
         """
+        if isinstance(x, RDD):
+            return x.map(self.predict)
         x = _convert_to_vector(x)
         return self.weights.dot(x) + self.intercept
 
@@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+    True
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> lrm.save(sc, path)
@@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+    True
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> lrm.save(sc, path)
@@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase):
     True
     >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
     True
+    >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5
+    True
     >>> import os, tempfile
     >>> path = tempfile.mkdtemp()
     >>> lrm.save(sc, path)
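
The hunk in LinearRegressionModelBase.predict dispatches on the argument type. An RDD is handled by mapping the single-vector predict over its elements. Anything else is converted to a vector and scored as weights . x + intercept. The following is a minimal pure-Python sketch of that pattern and rests on stated assumptions. FakeRDD and LinearModelSketch are hypothetical names. FakeRDD stands in for pyspark.RDD, implements only map() and collect(), and evaluates eagerly in memory. Plain lists replace _convert_to_vector and DenseVector. It illustrates the dispatch idea and is not the PySpark implementation.

```python
# Sketch of the single-vector vs. batch dispatch added in this commit.
# Assumption: FakeRDD is a hypothetical in-memory stand-in for pyspark.RDD
# so the example runs without Spark; it is NOT the real RDD API surface.
from typing import Callable, List, Sequence, Union


class FakeRDD:
    """Hypothetical stand-in for pyspark.RDD supporting only map/collect."""

    def __init__(self, data: Sequence):
        self._data = list(data)

    def map(self, f: Callable) -> "FakeRDD":
        # A real RDD.map is lazy and distributed; this applies f eagerly.
        return FakeRDD([f(item) for item in self._data])

    def collect(self) -> List:
        return list(self._data)


class LinearModelSketch:
    """Hypothetical model mirroring LinearRegressionModelBase.predict."""

    def __init__(self, weights: Sequence[float], intercept: float):
        self.weights = list(weights)
        self.intercept = intercept

    def predict(self, x: Union[FakeRDD, Sequence[float]]):
        # Batch case: reuse the single-vector predict on every element.
        if isinstance(x, FakeRDD):
            return x.map(self.predict)
        # Single-vector case: plain lists stand in for _convert_to_vector.
        if len(x) != len(self.weights):
            raise ValueError("feature vector length does not match weights")
        return sum(w * xi for w, xi in zip(self.weights, x)) + self.intercept


model = LinearModelSketch([2.0], 0.5)
print(model.predict([1.0]))                               # 2.5
print(model.predict(FakeRDD([[1.0], [2.0]])).collect())   # [2.5, 4.5]
```

Because the batch branch calls self.predict, the single-vector and RDD code paths cannot diverge. In real PySpark the mapped bound method, together with the model it belongs to, is serialized and shipped to executors. The in-memory stand-in does not model that step.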

