Posted to commits@spark.apache.org by me...@apache.org on 2015/01/08 08:23:21 UTC

spark git commit: [SPARK-5116][MLlib] Add extractor for SparseVector and DenseVector

Repository: spark
Updated Branches:
  refs/heads/master 2b729d225 -> c66a97630


[SPARK-5116][MLlib] Add extractor for SparseVector and DenseVector

Add extractors for SparseVector and DenseVector in MLlib to reduce boilerplate when pattern matching on Vectors. For example, previously one had to write:

    vec match {
        case dv: DenseVector =>
          val values = dv.values
          ...
        case sv: SparseVector =>
          val indices = sv.indices
          val values = sv.values
          val size = sv.size
          ...
    }

With the extractors, this becomes:

    vec match {
        case DenseVector(values) =>
          ...
        case SparseVector(size, indices, values) =>
          ...
    }
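
Under the hood the extractors are just `unapply` methods on the companion objects, added to Vectors.scala in the diff below. A minimal, self-contained sketch of the mechanism, using hypothetical stand-in classes rather than the real org.apache.spark.mllib.linalg types:

    // Sketch only: stand-in classes mirroring the patch; the real types
    // live in org.apache.spark.mllib.linalg.
    trait Vector
    class DenseVector(val values: Array[Double]) extends Vector
    class SparseVector(val size: Int, val indices: Array[Int],
        val values: Array[Double]) extends Vector

    object DenseVector {
      // Enables `case DenseVector(values) => ...`
      def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
    }

    object SparseVector {
      // Enables `case SparseVector(size, indices, values) => ...`
      def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
        Some((sv.size, sv.indices, sv.values))
    }

    object ExtractorDemo extends App {
      val vecs: Seq[Vector] = Seq(
        new DenseVector(Array(1.0, 2.0)),
        new SparseVector(5, Array(0, 3), Array(0.5, 1.5)))
      vecs.foreach {
        case DenseVector(values) =>
          println("dense: " + values.mkString(","))
        case SparseVector(size, indices, _) =>
          println("sparse: size=" + size + " nnz=" + indices.length)
        case _ =>
          println("unsupported vector type")
      }
    }

Note that `unapply` returns the backing arrays directly, so matched bindings alias the vector's internal storage; code that mutates them must still clone first, as the Normalizer and StandardScaler hunks below do with `vs.clone()`.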

Author: Shuo Xiang <sh...@gmail.com>

Closes #3919 from coderxiang/extractor and squashes the following commits:

359e8d5 [Shuo Xiang] merge master
ca5fc3e [Shuo Xiang] merge master
0b1e190 [Shuo Xiang] use extractor for vectors in RowMatrix.scala
e961805 [Shuo Xiang] use extractor for vectors in StandardScaler.scala
c2bbdaf [Shuo Xiang] use extractor for vectors in IDF.scala
8433922 [Shuo Xiang] use extractor for vectors in NaiveBayes.scala and Normalizer.scala
d83c7ca [Shuo Xiang] use extractor for vectors in Vectors.scala
5523dad [Shuo Xiang] Add extractor for SparseVector and DenseVector


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c66a9763
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c66a9763
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c66a9763

Branch: refs/heads/master
Commit: c66a976300734b52d943d4ff811fc269c1bff2de
Parents: 2b729d2
Author: Shuo Xiang <sh...@gmail.com>
Authored: Wed Jan 7 23:22:37 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jan 7 23:22:37 2015 -0800

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala |  8 +++---
 .../org/apache/spark/mllib/feature/IDF.scala    | 26 ++++++++++----------
 .../apache/spark/mllib/feature/Normalizer.scala | 10 ++++----
 .../spark/mllib/feature/StandardScaler.scala    | 15 ++++++-----
 .../org/apache/spark/mllib/linalg/Vectors.scala | 25 +++++++++++++------
 .../mllib/linalg/distributed/RowMatrix.scala    | 24 +++++++++---------
 6 files changed, 57 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 8c8e4a1..a967df8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -93,10 +93,10 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
   def run(data: RDD[LabeledPoint]) = {
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case sv: SparseVector =>
-          sv.values
-        case dv: DenseVector =>
-          dv.values
+        case SparseVector(size, indices, values) =>
+          values
+        case DenseVector(values) =>
+          values
       }
       if (!values.forall(_ >= 0.0)) {
         throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")

http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 19120e1..3260f27 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -86,20 +86,20 @@ private object IDF {
         df = BDV.zeros(doc.size)
       }
       doc match {
-        case sv: SparseVector =>
-          val nnz = sv.indices.size
+        case SparseVector(size, indices, values) =>
+          val nnz = indices.size
           var k = 0
           while (k < nnz) {
-            if (sv.values(k) > 0) {
-              df(sv.indices(k)) += 1L
+            if (values(k) > 0) {
+              df(indices(k)) += 1L
             }
             k += 1
           }
-        case dv: DenseVector =>
-          val n = dv.size
+        case DenseVector(values) =>
+          val n = values.size
           var j = 0
           while (j < n) {
-            if (dv.values(j) > 0.0) {
+            if (values(j) > 0.0) {
               df(j) += 1L
             }
             j += 1
@@ -207,20 +207,20 @@ private object IDFModel {
   def transform(idf: Vector, v: Vector): Vector = {
     val n = v.size
     v match {
-      case sv: SparseVector =>
-        val nnz = sv.indices.size
+      case SparseVector(size, indices, values) =>
+        val nnz = indices.size
         val newValues = new Array[Double](nnz)
         var k = 0
         while (k < nnz) {
-          newValues(k) = sv.values(k) * idf(sv.indices(k))
+          newValues(k) = values(k) * idf(indices(k))
           k += 1
         }
-        Vectors.sparse(n, sv.indices, newValues)
-      case dv: DenseVector =>
+        Vectors.sparse(n, indices, newValues)
+      case DenseVector(values) =>
         val newValues = new Array[Double](n)
         var j = 0
         while (j < n) {
-          newValues(j) = dv.values(j) * idf(j)
+          newValues(j) = values(j) * idf(j)
           j += 1
         }
         Vectors.dense(newValues)

http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 1ced26a..32848e0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -52,8 +52,8 @@ class Normalizer(p: Double) extends VectorTransformer {
       // However, for sparse vector, the `index` array will not be changed,
       // so we can re-use it to save memory.
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           var i = 0
           while (i < size) {
@@ -61,15 +61,15 @@ class Normalizer(p: Double) extends VectorTransformer {
             i += 1
           }
           Vectors.dense(values)
-        case sv: SparseVector =>
-          val values = sv.values.clone()
+        case SparseVector(size, ids, vs) =>
+          val values = vs.clone()
           val nnz = values.size
           var i = 0
           while (i < nnz) {
             values(i) /= norm
             i += 1
           }
-          Vectors.sparse(sv.size, sv.indices, values)
+          Vectors.sparse(size, ids, values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 8c4c5db..3c20917 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -105,8 +105,8 @@ class StandardScalerModel private[mllib] (
       // This can be avoid by having a local reference of `shift`.
       val localShift = shift
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           if (withStd) {
             // Having a local reference of `factor` to avoid overhead as the comment before.
@@ -130,8 +130,8 @@ class StandardScalerModel private[mllib] (
       // Having a local reference of `factor` to avoid overhead as the comment before.
       val localFactor = factor
       vector match {
-        case dv: DenseVector =>
-          val values = dv.values.clone()
+        case DenseVector(vs) =>
+          val values = vs.clone()
           val size = values.size
           var i = 0
           while(i < size) {
@@ -139,18 +139,17 @@ class StandardScalerModel private[mllib] (
             i += 1
           }
           Vectors.dense(values)
-        case sv: SparseVector =>
+        case SparseVector(size, indices, vs) =>
           // For sparse vector, the `index` array inside sparse vector object will not be changed,
           // so we can re-use it to save memory.
-          val indices = sv.indices
-          val values = sv.values.clone()
+          val values = vs.clone()
           val nnz = values.size
           var i = 0
           while (i < nnz) {
             values(i) *= localFactor(indices(i))
             i += 1
           }
-          Vectors.sparse(sv.size, indices, values)
+          Vectors.sparse(size, indices, values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index d40f133..bf1faa2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -108,16 +108,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   override def serialize(obj: Any): Row = {
     val row = new GenericMutableRow(4)
     obj match {
-      case sv: SparseVector =>
+      case SparseVector(size, indices, values) =>
         row.setByte(0, 0)
-        row.setInt(1, sv.size)
-        row.update(2, sv.indices.toSeq)
-        row.update(3, sv.values.toSeq)
-      case dv: DenseVector =>
+        row.setInt(1, size)
+        row.update(2, indices.toSeq)
+        row.update(3, values.toSeq)
+      case DenseVector(values) =>
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, dv.values.toSeq)
+        row.update(3, values.toSeq)
     }
     row
   }
@@ -271,8 +271,8 @@ object Vectors {
   def norm(vector: Vector, p: Double): Double = {
     require(p >= 1.0)
     val values = vector match {
-      case dv: DenseVector => dv.values
-      case sv: SparseVector => sv.values
+      case DenseVector(vs) => vs
+      case SparseVector(n, ids, vs) => vs
       case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
     }
     val size = values.size
@@ -427,6 +427,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
   }
 }
 
+object DenseVector {
+  def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
+}
+
 /**
  * A sparse vector represented by an index array and an value array.
  *
@@ -474,3 +478,8 @@ class SparseVector(
     }
   }
 }
+
+object SparseVector {
+  def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
+    Some((sv.size, sv.indices, sv.values))
+}
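
With the companion-object extractors above in place, client code can deconstruct a Vector in a single pattern. A hedged usage sketch against the MLlib API as of this commit (the `l1` helper is illustrative, not part of the patch):

    import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}

    object L1NormDemo extends App {
      // Illustrative helper: sum of absolute values via the new extractors.
      def l1(v: Vector): Double = v match {
        case DenseVector(values)        => values.map(math.abs).sum
        case SparseVector(_, _, values) => values.map(math.abs).sum
        case other =>
          throw new IllegalArgumentException("Do not support vector type " + other.getClass)
      }

      println(l1(Vectors.dense(1.0, -2.0, 3.0)))                    // 6.0
      println(l1(Vectors.sparse(4, Array(1, 3), Array(-1.0, 2.0)))) // 3.0
    }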

http://git-wip-us.apache.org/repos/asf/spark/blob/c66a9763/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index a3fca53..fbd35e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -528,21 +528,21 @@ class RowMatrix(
       iter.flatMap { row =>
         val buf = new ListBuffer[((Int, Int), Double)]()
         row match {
-          case sv: SparseVector =>
-            val nnz = sv.indices.size
+          case SparseVector(size, indices, values) =>
+            val nnz = indices.size
             var k = 0
             while (k < nnz) {
-              scaled(k) = sv.values(k) / q(sv.indices(k))
+              scaled(k) = values(k) / q(indices(k))
               k += 1
             }
             k = 0
             while (k < nnz) {
-              val i = sv.indices(k)
+              val i = indices(k)
               val iVal = scaled(k)
               if (iVal != 0 && rand.nextDouble() < p(i)) {
                 var l = k + 1
                 while (l < nnz) {
-                  val j = sv.indices(l)
+                  val j = indices(l)
                   val jVal = scaled(l)
                   if (jVal != 0 && rand.nextDouble() < p(j)) {
                     buf += (((i, j), iVal * jVal))
@@ -552,11 +552,11 @@ class RowMatrix(
               }
               k += 1
             }
-          case dv: DenseVector =>
-            val n = dv.values.size
+          case DenseVector(values) =>
+            val n = values.size
             var i = 0
             while (i < n) {
-              scaled(i) = dv.values(i) / q(i)
+              scaled(i) = values(i) / q(i)
               i += 1
             }
             i = 0
@@ -620,11 +620,9 @@ object RowMatrix {
     // TODO: Find a better home (breeze?) for this method.
     val n = v.size
     v match {
-      case dv: DenseVector =>
-        blas.dspr("U", n, alpha, dv.values, 1, U)
-      case sv: SparseVector =>
-        val indices = sv.indices
-        val values = sv.values
+      case DenseVector(values) =>
+        blas.dspr("U", n, alpha, values, 1, U)
+      case SparseVector(size, indices, values) =>
         val nnz = indices.length
         var colStartIdx = 0
         var prevCol = 0

