You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@spark.apache.org by me...@apache.org on 2014/10/21 18:29:53 UTC

git commit: [SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector

Repository: spark
Updated Branches:
  refs/heads/master 5a8f64f33 -> 857081683


[SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector

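Convert the input RDD into an RDD of Vectors (via `_convert_to_vector`) before passing it to the JVM, so that rows given as NumPy arrays, Python lists, or `array.array` objects are all accepted.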

cc mengxr

Author: Davies Liu <da...@databricks.com>

Closes #2870 from davies/fix4023 and squashes the following commits:

1eac767 [Davies Liu] address comments
0871576 [Davies Liu] convert rdd into RDD of Vector


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/85708168
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/85708168
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/85708168

Branch: refs/heads/master
Commit: 85708168341a9406c451df20af3374c0850ce166
Parents: 5a8f64f
Author: Davies Liu <da...@databricks.com>
Authored: Tue Oct 21 09:29:45 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Oct 21 09:29:45 2014 -0700

----------------------------------------------------------------------
 python/pyspark/mllib/stat.py  |  9 +++++----
 python/pyspark/mllib/tests.py | 19 +++++++++++++++++++
 2 files changed, 24 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/85708168/python/pyspark/mllib/stat.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index a6019da..84baf12 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,7 +22,7 @@ Python package for statistical functions in MLlib.
 from functools import wraps
 
 from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
 
 
 __all__ = ['MultivariateStatisticalSummary', 'Statistics']
@@ -107,7 +107,7 @@ class Statistics(object):
         array([ 2.,  0.,  0., -2.])
         """
         sc = rdd.ctx
-        jrdd = _to_java_object_rdd(rdd)
+        jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
         cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
         return MultivariateStatisticalSummary(sc, cStats)
 
@@ -163,14 +163,15 @@ class Statistics(object):
         if type(y) == str:
             raise TypeError("Use 'method=' to specify method name.")
 
-        jx = _to_java_object_rdd(x)
         if not y:
+            jx = _to_java_object_rdd(x.map(_convert_to_vector))
             resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
             bytes = sc._jvm.SerDe.dumps(resultMat)
             ser = PickleSerializer()
             return ser.loads(str(bytes)).toArray()
         else:
-            jy = _to_java_object_rdd(y)
+            jx = _to_java_object_rdd(x.map(float))
+            jy = _to_java_object_rdd(y.map(float))
             return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/85708168/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 463faf7..d6fb87b 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -36,6 +36,8 @@ else:
 from pyspark.serializers import PickleSerializer
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.random import RandomRDDs
+from pyspark.mllib.stat import Statistics
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 
 
@@ -202,6 +204,23 @@ class ListTests(PySparkTestCase):
         self.assertTrue(dt_model.predict(features[3]) > 0)
 
 
+class StatTests(PySparkTestCase):
+    # SPARK-4023
+    def test_col_with_different_rdds(self):
+        # numpy
+        data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+        summary = Statistics.colStats(data)
+        self.assertEqual(1000, summary.count())
+        # array
+        data = self.sc.parallelize([range(10)] * 10)
+        summary = Statistics.colStats(data)
+        self.assertEqual(10, summary.count())
+        # array
+        data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
+        summary = Statistics.colStats(data)
+        self.assertEqual(10, summary.count())
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org