You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/05 16:53:15 UTC
spark git commit: [SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Repository: spark
Updated Branches:
refs/heads/master c6d1efba2 -> 5ab652cdb
[SPARK-7202] [MLLIB] [PYSPARK] Add SparseMatrixPickler to SerDe
Utilities for pickling and unpickling SparseMatrices using SerDe
Author: MechCoder <ma...@gmail.com>
Closes #5775 from MechCoder/spark-7202 and squashes the following commits:
7e689dc [MechCoder] [SPARK-7202] Add SparseMatrixPickler to SerDe
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ab652cd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ab652cd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ab652cd
Branch: refs/heads/master
Commit: 5ab652cdb8bef10214edd079502a7f49017579aa
Parents: c6d1efb
Author: MechCoder <ma...@gmail.com>
Authored: Tue May 5 07:53:11 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue May 5 07:53:11 2015 -0700
----------------------------------------------------------------------
.../spark/mllib/api/python/PythonMLLibAPI.scala | 56 ++++++++++++++++++++
.../mllib/api/python/PythonMLLibAPISuite.scala | 12 ++++-
python/pyspark/mllib/linalg.py | 4 +-
python/pyspark/mllib/tests.py | 3 ++
4 files changed, 72 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 6237b64..8e9a208 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1015,6 +1015,61 @@ private[spark] object SerDe extends Serializable {
}
}
+ // Pickler for SparseMatrix
+ // Serializes a SparseMatrix as a 6-element pickle tuple
+ // (numRows, numCols, colPtrs bytes, rowIndices bytes, values bytes, isTransposed)
+ // and rebuilds a SparseMatrix from such a tuple.
+ private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+ // Writes the state of `obj` (cast to SparseMatrix) to `out` as raw pickle opcodes.
+ // The `pickler` argument is not used in this method.
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ val s = obj.asInstanceOf[SparseMatrix]
+ val order = ByteOrder.nativeOrder()
+
+ // Byte buffers sized for the raw arrays: 4 bytes per Int, 8 bytes per Double.
+ val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+ val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+ val valuesBytes = new Array[Byte](8 * s.values.length)
+ // The transposed flag is written as an integer (1 or 0), not as a boolean.
+ val isTransposed = if (s.isTransposed) 1 else 0
+ // Copy the arrays into the byte buffers in the platform's native byte order;
+ // construct() decodes them with the same order.
+ // NOTE(review): this presumably relies on the reader running on a machine with
+ // the same byte order -- confirm.
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+ // MARK ... TUPLE brackets the six fields below so they form a single tuple.
+ out.write(Opcodes.MARK)
+ // Fields 0 and 1: numRows and numCols as integers.
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(s.numCols))
+ // Fields 2-4: each BINSTRING is the byte length followed by the raw bytes.
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+ out.write(colPtrsBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+ out.write(indicesBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+ out.write(valuesBytes)
+ // Field 5: the transposed flag (1 or 0).
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(isTransposed))
+ out.write(Opcodes.TUPLE)
+ }
+
+ // Rebuilds a SparseMatrix from the unpickled arguments, which are read in the same
+ // order saveState writes them. Throws PickleException unless exactly 6 arguments
+ // are supplied.
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 6) {
+ throw new PickleException("should be 6")
+ }
+ val order = ByteOrder.nativeOrder()
+ // getBytes is defined outside this block; presumably it converts the unpickled
+ // string/bytes object into an Array[Byte] -- confirm against its definition.
+ val colPtrsBytes = getBytes(args(2))
+ val indicesBytes = getBytes(args(3))
+ val valuesBytes = getBytes(args(4))
+ // Integer division: trailing bytes that do not make up a whole element are ignored.
+ val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+ val rowIndices = new Array[Int](indicesBytes.length / 4)
+ val values = new Array[Double](valuesBytes.length / 8)
+ ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+ ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+ ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+ // Only the value 1 counts as transposed. The cast assumes args(5) arrives as an
+ // integer (not a boolean); anything else fails the cast.
+ val isTransposed = args(5).asInstanceOf[Int] == 1
+ new SparseMatrix(
+ args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+ isTransposed)
+ }
+ }
+
// Pickler for SparseVector
private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
@@ -1099,6 +1154,7 @@ private[spark] object SerDe extends Serializable {
if (!initialized) {
new DenseVectorPickler().register()
new DenseMatrixPickler().register()
+ new SparseMatrixPickler().register()
new SparseVectorPickler().register()
new LabeledPointPickler().register()
new RatingPickler().register()
http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index db8ed62..a629dba 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.api.python
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors}
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.recommendation.Rating
@@ -77,6 +77,16 @@ class PythonMLLibAPISuite extends FunSuite {
val emptyMatrix = Matrices.dense(0, 0, empty)
val ne = SerDe.loads(SerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
assert(emptyMatrix == ne)
+
+ val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+ val nsm = SerDe.loads(SerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+ assert(sm.toArray === nsm.toArray)
+
+ val smt = new SparseMatrix(
+ 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+ isTransposed=true)
+ val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+ assert(smt.toArray === nsmt.toArray)
}
test("pickle rating") {
http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/python/pyspark/mllib/linalg.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index a57c0b3..9f3b0ba 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -755,7 +755,7 @@ class SparseMatrix(Matrix):
return SparseMatrix, (
self.numRows, self.numCols, self.colPtrs.tostring(),
self.rowIndices.tostring(), self.values.tostring(),
- self.isTransposed)
+ int(self.isTransposed))
def __getitem__(self, indices):
i, j = indices
@@ -801,7 +801,7 @@ class SparseMatrix(Matrix):
# TODO: More efficient implementation:
def __eq__(self, other):
- return np.all(self.toArray == other.toArray)
+ return np.all(self.toArray() == other.toArray())
class Matrices(object):
http://git-wip-us.apache.org/repos/asf/spark/blob/5ab652cd/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 1b008b9..1d9c6eb 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -92,6 +92,9 @@ class VectorTests(MLlibTestCase):
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
self._test_serialize(SparseVector(3, {}))
self._test_serialize(DenseMatrix(2, 3, range(6)))
+ sm1 = SparseMatrix(
+ 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0])
+ self._test_serialize(sm1)
def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org