You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/01/14 20:03:13 UTC

spark git commit: [SPARK-2909] [MLlib] [PySpark] SparseVector in pyspark now supports indexing

Repository: spark
Updated Branches:
  refs/heads/master 38bdc992a -> 5840f5464


[SPARK-2909] [MLlib] [PySpark] SparseVector in pyspark now supports indexing

This differs slightly from the Scala code, which converts the SparseVector into a DenseVector and then checks the index; here the index is looked up directly in the sparse vector's stored indices.

I also hope I've added tests in the right place.

Author: MechCoder <ma...@gmail.com>

Closes #4025 from MechCoder/spark-2909 and squashes the following commits:

07d0f26 [MechCoder] STY: Rename item to index
f02148b [MechCoder] [SPARK-2909] [Mlib] SparseVector in pyspark now supports indexing


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5840f546
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5840f546
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5840f546

Branch: refs/heads/master
Commit: 5840f5464bad8431810d459c97d6e4635eea175c
Parents: 38bdc99
Author: MechCoder <ma...@gmail.com>
Authored: Wed Jan 14 11:03:11 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jan 14 11:03:11 2015 -0800

----------------------------------------------------------------------
 python/pyspark/mllib/linalg.py | 17 +++++++++++++++++
 python/pyspark/mllib/tests.py  | 12 ++++++++++++
 2 files changed, 29 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5840f546/python/pyspark/mllib/linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 4f8491f..7f21190 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -510,6 +510,23 @@ class SparseVector(Vector):
                 and np.array_equal(other.indices, self.indices)
                 and np.array_equal(other.values, self.values))
 
+    def __getitem__(self, index):
+        inds = self.indices
+        vals = self.values
+        if not isinstance(index, int):
+            raise ValueError(
+                "Indices must be of type integer, got type %s" % type(index))
+        if index < 0:
+            index += self.size
+        if index >= self.size or index < 0:
+            raise ValueError("Index %d out of bounds." % index)
+
+        insert_index = np.searchsorted(inds, index)
+        row_ind = inds[insert_index]
+        if row_ind == index:
+            return vals[insert_index]
+        return 0.
+
     def __ne__(self, other):
         return not self.__eq__(other)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5840f546/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1f48bc1..140c22b 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -120,6 +120,18 @@ class VectorTests(PySparkTestCase):
         dv = DenseVector(v)
         self.assertTrue(dv.array.dtype == 'float64')
 
+    def test_sparse_vector_indexing(self):
+        sv = SparseVector(4, {1: 1, 3: 2})
+        self.assertEquals(sv[0], 0.)
+        self.assertEquals(sv[3], 2.)
+        self.assertEquals(sv[1], 1.)
+        self.assertEquals(sv[2], 0.)
+        self.assertEquals(sv[-1], 2)
+        self.assertEquals(sv[-2], 0)
+        self.assertEquals(sv[-4], 0)
+        for ind in [4, -5, 7.8]:
+            self.assertRaises(ValueError, sv.__getitem__, ind)
+
 
 class ListTests(PySparkTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org