Posted to commits@spark.apache.org by me...@apache.org on 2015/05/05 16:53:15 UTC

spark git commit: [SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe

Repository: spark
Updated Branches:
  refs/heads/master c6d1efba2 -> 5ab652cdb


[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe

Utilities for pickling and unpickling SparseMatrices using SerDe

Author: MechCoder <ma...@gmail.com>

Closes #5775 from MechCoder/spark-7202 and squashes the following commits:

7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ab652cd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ab652cd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ab652cd

Branch: refs/heads/master
Commit: 5ab652cdb8bef10214edd079502a7f49017579aa
Parents: c6d1efb
Author: MechCoder <ma...@gmail.com>
Authored: Tue May 5 07:53:11 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 5 07:53:11 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 56 ++++++++++++++++++++
 .../mllib/api/python/PythonMLLibAPISuite.scala  | 12 ++++-
 python/pyspark/mllib/linalg.py                  |  4 +-
 python/pyspark/mllib/tests.py                   |  3 ++
 4 files changed, 72 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6237b64..8e9a208 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
+  // Pickler for SparseMatrix
+  private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val s = obj.asInstanceOf[SparseMatrix]
+      val order = ByteOrder.nativeOrder()
+
+      val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+      val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+      val valuesBytes = new Array[Byte](8 * s.values.length)
+      val isTransposed = if (s.isTransposed) 1 else 0
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+      out.write(Opcodes.MARK)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numRows))
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numCols))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+      out.write(colPtrsBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+      out.write(indicesBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+      out.write(valuesBytes)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 6) {
+        throw new PickleException("should be 6")
+      }
+      val order = ByteOrder.nativeOrder()
+      val colPtrsBytes = getBytes(args(2))
+      val indicesBytes = getBytes(args(3))
+      val valuesBytes = getBytes(args(4))
+      val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+      val rowIndices = new Array[Int](indicesBytes.length / 4)
+      val values = new Array[Double](valuesBytes.length / 8)
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+      val isTransposed = args(5).asInstanceOf[Int] == 1
+      new SparseMatrix(
+        args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+        isTransposed)
+    }
+  }
+
   // Pickler for SparseVector
   private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
 
@@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
       if (!initialized) {
         new DenseVectorPickler().register()
         new DenseMatrixPickler().register()
+        new SparseMatrixPickler().register()
         new SparseVectorPickler().register()
         new LabeledPointPickler().register()
         new RatingPickler().register()
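
Note: the pickler above writes a six-field tuple (numRows, numCols, colPtrs bytes,
rowIndices bytes, values bytes, transpose flag), with the index arrays packed as
4-byte native-order ints and the values as 8-byte doubles. A minimal Python-side
sketch of that layout, using numpy purely for illustration (not part of the commit):

    import numpy as np

    num_rows, num_cols = 3, 2
    col_ptrs = np.array([0, 1, 3], dtype=np.int32)
    row_indices = np.array([1, 0, 2], dtype=np.int32)
    values = np.array([0.9, 1.2, 3.4], dtype=np.float64)

    # What saveState() emits: two ints, three byte strings, and a 0/1 flag.
    payload = (num_rows, num_cols,
               col_ptrs.tobytes(), row_indices.tobytes(), values.tobytes(), 0)

    # What construct() does with the byte strings: reinterpret them at the
    # matching width (length/4 ints, length/8 doubles) in native byte order.
    assert np.array_equal(np.frombuffer(payload[2], dtype=np.int32), col_ptrs)
    assert np.array_equal(np.frombuffer(payload[4], dtype=np.float64), values)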

http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index db8ed62..a629dba 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.recommendation.Rating
 
@@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
     val emptyMatrix = Matrices.dense(0, 0, empty)
     val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
     assert(emptyMatrix == ne)
+
+    val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+    val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+    assert(sm.toArray === nsm.toArray)
+
+    val smt = new SparseMatrix(
+      3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+      isTransposed=true)
+    val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+    assert(smt.toArray === nsmt.toArray)
   }
 
   test("pickle rating") {

http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/python/pyspark/mllib/linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index a57c0b3..9f3b0ba 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -755,7 +755,7 @@ class SparseMatrix(Matrix):
         return SparseMatrix, (
             self.numRows, self.numCols, self.colPtrs.tostring(),
             self.rowIndices.tostring(), self.values.tostring(),
-            self.isTransposed)
+            int(self.isTransposed))
 
     def __getitem__(self, indices):
         i, j = indices
@@ -801,7 +801,7 @@ class SparseMatrix(Matrix):
 
     # TODO: More efficient implementation:
     def __eq__(self, other):
-        return np.all(self.toArray == other.toArray)
+        return np.all(self.toArray() == other.toArray())
 
 
 class Matrices(object):
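
Note: two small fixes on the Python side are visible above. __reduce__ now sends
the transpose flag as int(...) because the JVM-side construct() reads args(5) as
an Int, and __eq__ now calls toArray() instead of comparing the bound toArray
methods themselves, so equality is based on the matrix contents. A quick local
illustration, assuming only a pyspark installation:

    from pyspark.mllib.linalg import SparseMatrix

    sm = SparseMatrix(3, 2, [0, 1, 3], [1, 0, 2], [0.9, 1.2, 3.4])

    # __reduce__ ends with the 0/1 flag the Scala pickler expects.
    cls, args = sm.__reduce__()
    assert args[-1] == 0

    # __eq__ now compares the materialized dense arrays, so two independently
    # constructed but identical matrices compare equal.
    assert sm == SparseMatrix(3, 2, [0, 1, 3], [1, 0, 2], [0.9, 1.2, 3.4])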

http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1b008b9..1d9c6eb 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -92,6 +92,9 @@ class VectorTests(MLlibTestCase):
         self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
         self._test_serialize(SparseVector(3, {}))
         self._test_serialize(DenseMatrix(2, 3, range(6)))
+        sm1 = SparseMatrix(
+            3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+        self._test_serialize(sm1)
 
     def test_dot(self):
         sv = SparseVector(4, {1: 1, 3: 2})
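
Note: the new test line round-trips a 3x4 matrix whose repeated colPtrs entries
(2, 2 and 4, 4) mark empty columns. A local-only sketch of the Python half of
that check, using pyspark's PickleSerializer (no SparkContext needed):

    from pyspark.mllib.linalg import SparseMatrix
    from pyspark.serializers import PickleSerializer

    ser = PickleSerializer()
    sm1 = SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
    # Round-trip through pickle and compare contents via the fixed __eq__.
    assert sm1 == ser.loads(ser.dumps(sm1))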

