You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/19 06:32:41 UTC

spark git commit: [SPARK-7681] [MLLIB] Add SparseVector support for gemv

Repository: spark
Updated Branches:
  refs/heads/master 3a6003866 -> d03638cc2


[SPARK-7681] [MLLIB] Add SparseVector support for gemv

JIRA: https://issues.apache.org/jira/browse/SPARK-7681

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #6209 from viirya/sparsevector_gemv and squashes the following commits:

ce0bb8b [Liang-Chi Hsieh] Still need to scal y when beta is 0.0 because it clears out y.
b890e63 [Liang-Chi Hsieh] Do not delete multiply for DenseVector.
57a8c1e [Liang-Chi Hsieh] Add MimaExcludes for v1.4.
458d1ae [Liang-Chi Hsieh] List DenseMatrix.multiply and SparseMatrix.multiply to MimaExcludes too.
054f05d [Liang-Chi Hsieh] Fix scala style.
410381a [Liang-Chi Hsieh] Address comments. Make Matrix.multiply more generalized.
4616696 [Liang-Chi Hsieh] Add support for SparseVector with SparseMatrix.
5d6d07a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into sparsevector_gemv
c069507 [Liang-Chi Hsieh] Add SparseVector support for gemv with DenseMatrix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d03638cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d03638cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d03638cc

Branch: refs/heads/master
Commit: d03638cc2d414cee9ac7481084672e454495dfc1
Parents: 3a60038
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Mon May 18 21:32:36 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon May 18 21:32:36 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/mllib/linalg/BLAS.scala    | 152 +++++++++++++++++--
 .../apache/spark/mllib/linalg/Matrices.scala    |   7 +-
 .../apache/spark/mllib/linalg/BLASSuite.scala   |  96 ++++++++++--
 project/MimaExcludes.scala                      |  18 ++-
 4 files changed, 240 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d03638cc/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 87052e1..ec38529 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging {
   def gemv(
       alpha: Double,
       A: Matrix,
-      x: DenseVector,
+      x: Vector,
       beta: Double,
       y: DenseVector): Unit = {
     require(A.numCols == x.size,
@@ -473,44 +473,169 @@ private[spark] object BLAS extends Serializable with Logging {
     if (alpha == 0.0) {
       logDebug("gemv: alpha is equal to 0. Returning y.")
     } else {
-      A match {
-        case sparse: SparseMatrix =>
-          gemv(alpha, sparse, x, beta, y)
-        case dense: DenseMatrix =>
-          gemv(alpha, dense, x, beta, y)
+      (A, x) match {
+        case (smA: SparseMatrix, dvx: DenseVector) =>
+          gemv(alpha, smA, dvx, beta, y)
+        case (smA: SparseMatrix, svx: SparseVector) =>
+          gemv(alpha, smA, svx, beta, y)
+        case (dmA: DenseMatrix, dvx: DenseVector) =>
+          gemv(alpha, dmA, dvx, beta, y)
+        case (dmA: DenseMatrix, svx: SparseVector) =>
+          gemv(alpha, dmA, svx, beta, y)
         case _ =>
-          throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.")
+          throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
+            s"${A.getClass} and vector type ${x.getClass}.")
       }
     }
   }
 
   /**
    * y := alpha * A * x + beta * y
-   * For `DenseMatrix` A.
+   * For `DenseMatrix` A and `DenseVector` x.
    */
   private def gemv(
       alpha: Double,
       A: DenseMatrix,
       x: DenseVector,
       beta: Double,
-      y: DenseVector): Unit =  {
+      y: DenseVector): Unit = {
     val tStrA = if (A.isTransposed) "T" else "N"
     val mA = if (!A.isTransposed) A.numRows else A.numCols
     val nA = if (!A.isTransposed) A.numCols else A.numRows
     nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta,
       y.values, 1)
   }
+ 
+  /**
+   * y := alpha * A * x + beta * y
+   * For `DenseMatrix` A and `SparseVector` x.
+   */
+  private def gemv(
+      alpha: Double,
+      A: DenseMatrix,
+      x: SparseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+    val numRowsA: Int = A.numRows
+    val numColsA: Int = A.numCols
+    val aData = A.values
+
+    // Only the explicitly stored entries of x take part in the products.
+    val activeIdx = x.indices
+    val activeVals = x.values
+    val activeCount = activeIdx.length
+    val out = y.values
 
+    if (alpha == 0.0) {
+      // The A * x term is multiplied by zero, so only the beta * y scaling is left.
+      scal(beta, y)
+    } else {
+      // Entry (row, col) of A is read from aData(row * rowStride + col * colStride).
+      // A transposed matrix is addressed row by row, a non-transposed one column by column.
+      val rowStride = if (A.isTransposed) numColsA else 1
+      val colStride = if (A.isTransposed) 1 else numRowsA
+
+      var row = 0
+      while (row < numRowsA) {
+        val rowOffset = row * rowStride
+        var acc = 0.0
+        var j = 0
+        while (j < activeCount) {
+          acc += activeVals(j) * aData(activeIdx(j) * colStride + rowOffset)
+          j += 1
+        }
+        out(row) = acc * alpha + beta * out(row)
+        row += 1
+      }
+    }
+  }
+ 
   /**
    * y := alpha * A * x + beta * y
-   * For `SparseMatrix` A.
+   * For `SparseMatrix` A and `SparseVector` x.
+   */
+  private def gemv(
+      alpha: Double,
+      A: SparseMatrix,
+      x: SparseVector,
+      beta: Double,
+      y: DenseVector): Unit = {
+    val xValues = x.values
+    val xIndices = x.indices
+    val xNnz = xIndices.length
+
+    val yValues = y.values
+
+    val mA: Int = A.numRows
+    val nA: Int = A.numCols
+
+    val Avals = A.values
+    // When A is transposed, `A.colPtrs` is used as per-row pointers and `A.rowIndices` as
+    // column indices; otherwise they keep their literal meaning.
+    val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs
+    val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices
+
+    if (alpha == 0.0) {
+      // The A * x term is multiplied by zero, so only the beta * y scaling is left.
+      scal(beta, y)
+    } else if (A.isTransposed) {
+      var rowCounter = 0
+      while (rowCounter < mA) {
+        var i = Arows(rowCounter)
+        val indEnd = Arows(rowCounter + 1)
+        var sum = 0.0
+        var k = 0
+        // Two-pointer merge of the row's column indices with x's indices. Whichever side
+        // holds the smaller index is advanced, so a column of A that is absent from x (or
+        // an index of x absent from the row) cannot stall the scan and hide later matches.
+        // NOTE(review): assumes both index arrays are sorted in ascending order -- confirm.
+        while (k < xNnz && i < indEnd) {
+          val xIndex = xIndices(k)
+          val aCol = Acols(i)
+          if (xIndex == aCol) {
+            sum += Avals(i) * xValues(k)
+            i += 1
+            k += 1
+          } else if (xIndex < aCol) {
+            k += 1
+          } else {
+            i += 1
+          }
+        }
+        yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
+        rowCounter += 1
+      }
+    } else {
+      // Scale y by beta first; the loop below then accumulates alpha * A * x into y.
+      scal(beta, y)
+
+      var colCounterForA = 0
+      var k = 0
+      // Walk the columns of A and the stored entries of x together; only columns that have
+      // a stored entry in x contribute to y.
+      while (colCounterForA < nA && k < xNnz) {
+        if (xIndices(k) == colCounterForA) {
+          var i = Acols(colCounterForA)
+          val indEnd = Acols(colCounterForA + 1)
+
+          val xTemp = xValues(k) * alpha
+          while (i < indEnd) {
+            yValues(Arows(i)) += Avals(i) * xTemp
+            i += 1
+          }
+          k += 1
+        }
+        colCounterForA += 1
+      }
+    }
+  }
+
+  /**
+   * y := alpha * A * x + beta * y
+   * For `SparseMatrix` A and `DenseVector` x.
    */
   private def gemv(
       alpha: Double,
       A: SparseMatrix,
       x: DenseVector,
       beta: Double,
-      y: DenseVector): Unit =  {
+      y: DenseVector): Unit = {
     val xValues = x.values
     val yValues = y.values
     val mA: Int = A.numRows
@@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging {
         rowCounter += 1
       }
     } else {
-      // Scale vector first if `beta` is not equal to 0.0
-      if (beta != 0.0) {
-        scal(beta, y)
-      }
+      scal(beta, y)
       // Perform matrix-vector multiplication and add to y
       var colCounterForA = 0
       while (colCounterForA < nA) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d03638cc/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index a609674..9584da8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable {
     C
   }
 
-  /** Convenience method for `Matrix`-`DenseVector` multiplication. */
+  /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */
   def multiply(y: DenseVector): DenseVector = {
+    // Widen the static type with an ascription (an upcast needs no runtime cast) so that
+    // overload resolution selects `multiply(y: Vector)` rather than this method again.
+    multiply(y: Vector)
+  }
+
+  /** Convenience method for `Matrix`-`Vector` multiplication. */
+  def multiply(y: Vector): DenseVector = {
     val output = new DenseVector(new Array[Double](numRows))
     BLAS.gemv(1.0, this, y, 0.0, output)
     output

http://git-wip-us.apache.org/repos/asf/spark/blob/d03638cc/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 002cb25..64ecd12 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -256,42 +256,108 @@ class BLASSuite extends FunSuite {
     val dA =
       new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0))
     val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0))
-
-    val x = new DenseVector(Array(1.0, 2.0, 3.0))
+ 
+    val dA2 =
+      new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true)
+    val sA2 =
+      new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0),
+        true)
+ 
+    val dx = new DenseVector(Array(1.0, 2.0, 3.0))
+    val sx = dx.toSparse
     val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0))
 
-    assert(dA.multiply(x) ~== expected absTol 1e-15)
-    assert(sA.multiply(x) ~== expected absTol 1e-15)
-
+    assert(dA.multiply(dx) ~== expected absTol 1e-15)
+    assert(sA.multiply(dx) ~== expected absTol 1e-15)
+    assert(dA.multiply(sx) ~== expected absTol 1e-15)
+    assert(sA.multiply(sx) ~== expected absTol 1e-15)
+ 
     val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0))
     val y2 = y1.copy
     val y3 = y1.copy
     val y4 = y1.copy
+    val y5 = y1.copy
+    val y6 = y1.copy
+    val y7 = y1.copy
+    val y8 = y1.copy
+    val y9 = y1.copy
+    val y10 = y1.copy
+    val y11 = y1.copy
+    val y12 = y1.copy
+    val y13 = y1.copy
+    val y14 = y1.copy
+    val y15 = y1.copy
+    val y16 = y1.copy
+ 
     val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0))
     val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0))
 
-    gemv(1.0, dA, x, 2.0, y1)
-    gemv(1.0, sA, x, 2.0, y2)
-    gemv(2.0, dA, x, 2.0, y3)
-    gemv(2.0, sA, x, 2.0, y4)
+    gemv(1.0, dA, dx, 2.0, y1)
+    gemv(1.0, sA, dx, 2.0, y2)
+    gemv(1.0, dA, sx, 2.0, y3)
+    gemv(1.0, sA, sx, 2.0, y4)
+ 
+    gemv(1.0, dA2, dx, 2.0, y5)
+    gemv(1.0, sA2, dx, 2.0, y6)
+    gemv(1.0, dA2, sx, 2.0, y7)
+    gemv(1.0, sA2, sx, 2.0, y8)
+ 
+    gemv(2.0, dA, dx, 2.0, y9)
+    gemv(2.0, sA, dx, 2.0, y10)
+    gemv(2.0, dA, sx, 2.0, y11)
+    gemv(2.0, sA, sx, 2.0, y12)
+ 
+    gemv(2.0, dA2, dx, 2.0, y13)
+    gemv(2.0, sA2, dx, 2.0, y14)
+    gemv(2.0, dA2, sx, 2.0, y15)
+    gemv(2.0, sA2, sx, 2.0, y16)
+ 
     assert(y1 ~== expected2 absTol 1e-15)
     assert(y2 ~== expected2 absTol 1e-15)
-    assert(y3 ~== expected3 absTol 1e-15)
-    assert(y4 ~== expected3 absTol 1e-15)
+    assert(y3 ~== expected2 absTol 1e-15)
+    assert(y4 ~== expected2 absTol 1e-15)
+ 
+    assert(y5 ~== expected2 absTol 1e-15)
+    assert(y6 ~== expected2 absTol 1e-15)
+    assert(y7 ~== expected2 absTol 1e-15)
+    assert(y8 ~== expected2 absTol 1e-15)
+ 
+    assert(y9 ~== expected3 absTol 1e-15)
+    assert(y10 ~== expected3 absTol 1e-15)
+    assert(y11 ~== expected3 absTol 1e-15)
+    assert(y12 ~== expected3 absTol 1e-15)
+ 
+    assert(y13 ~== expected3 absTol 1e-15)
+    assert(y14 ~== expected3 absTol 1e-15)
+    assert(y15 ~== expected3 absTol 1e-15)
+    assert(y16 ~== expected3 absTol 1e-15)
+ 
     withClue("columns of A don't match the rows of B") {
       intercept[Exception] {
-        gemv(1.0, dA.transpose, x, 2.0, y1)
+        gemv(1.0, dA.transpose, dx, 2.0, y1)
+      }
+      intercept[Exception] {
+        gemv(1.0, sA.transpose, dx, 2.0, y1)
+      }
+      intercept[Exception] {
+        gemv(1.0, dA.transpose, sx, 2.0, y1)
+      }
+      intercept[Exception] {
+        gemv(1.0, sA.transpose, sx, 2.0, y1)
       }
     }
+ 
     val dAT =
       new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
     val sAT =
       new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0))
-
+ 
     val dATT = dAT.transpose
     val sATT = sAT.transpose
 
-    assert(dATT.multiply(x) ~== expected absTol 1e-15)
-    assert(sATT.multiply(x) ~== expected absTol 1e-15)
+    assert(dATT.multiply(dx) ~== expected absTol 1e-15)
+    assert(sATT.multiply(dx) ~== expected absTol 1e-15)
+    assert(dATT.multiply(sx) ~== expected absTol 1e-15)
+    assert(sATT.multiply(sx) ~== expected absTol 1e-15)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d03638cc/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 513bbaf..f8d0160 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -87,7 +87,14 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.mllib.linalg.Vector.toSparse"),
             ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.mllib.linalg.Vector.numActives")
+              "org.apache.spark.mllib.linalg.Vector.numActives"),
+            // SPARK-7681 add SparseVector support for gemv
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrix.multiply"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
           ) ++ Seq(
             // Execution should never be included as its always internal.
             MimaBuild.excludeSparkPackage("sql.execution"),
@@ -180,7 +187,14 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
             ProblemFilters.exclude[MissingMethodProblem](
-              "org.apache.spark.mllib.linalg.Matrix.foreachActive")
+              "org.apache.spark.mllib.linalg.Matrix.foreachActive"),
+            // SPARK-7681 add SparseVector support for gemv
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.Matrix.multiply"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.DenseMatrix.multiply"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.mllib.linalg.SparseMatrix.multiply")
           ) ++ Seq(
             // SPARK-5540
             ProblemFilters.exclude[MissingMethodProblem](


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org