You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/09/29 22:40:01 UTC

spark git commit: [SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector

Repository: spark
Updated Branches:
  refs/heads/master 4ecc648ad -> 29396e7d1


[SPARK-17721][MLLIB][ML] Fix for multiplying transposed SparseMatrix with SparseVector

## What changes were proposed in this pull request?

* changes the implementation of gemv with transposed SparseMatrix and SparseVector both in mllib-local and mllib (identical)
* adds a test that was failing before this change, but succeeds with these changes.

The problem in the previous implementation was that it only incremented `i`, the index that enumerates the columns of a row in the SparseMatrix, when the row-index of the vector matched the column-index of the SparseMatrix. In cases where a particular row of the SparseMatrix has non-zero values at column-indices lower than the corresponding non-zero row-indices of the SparseVector, the non-zero values of the SparseVector are enumerated without ever matching the column-index at index `i`, and the remaining column-indices i+1,...,indEnd-1 are never attempted. The test cases in this PR illustrate this issue.

## How was this patch tested?

I have run the specific `gemv` tests in both mllib-local and mllib. I am currently still running `./dev/run-tests`.

## ___
As per instructions, I hereby state that this is my original work and that I license the work to the project (Apache Spark) under the project's open source license.

Mentioning dbtsai, viirya and brkyvz, who I can see have worked on or authored these parts before.

Author: Bjarne Fruergaard <bw...@gmail.com>

Closes #15296 from bwahlgreen/bugfix-spark-17721.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/29396e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/29396e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/29396e7d

Branch: refs/heads/master
Commit: 29396e7d1483d027960b9a1bed47008775c4253e
Parents: 4ecc648
Author: Bjarne Fruergaard <bw...@gmail.com>
Authored: Thu Sep 29 15:39:57 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Sep 29 15:39:57 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/linalg/BLAS.scala    |  8 ++++++--
 .../org/apache/spark/ml/linalg/BLASSuite.scala     | 17 +++++++++++++++++
 .../scala/org/apache/spark/mllib/linalg/BLAS.scala |  8 ++++++--
 .../org/apache/spark/mllib/linalg/BLASSuite.scala  | 17 +++++++++++++++++
 4 files changed, 46 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/29396e7d/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
index 41b0c6c..4ca19f3 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable {
         val indEnd = Arows(rowCounter + 1)
         var sum = 0.0
         var k = 0
-        while (k < xNnz && i < indEnd) {
+        while (i < indEnd && k < xNnz) {
           if (xIndices(k) == Acols(i)) {
             sum += Avals(i) * xValues(k)
+            k += 1
+            i += 1
+          } else if (xIndices(k) < Acols(i)) {
+            k += 1
+          } else {
             i += 1
           }
-          k += 1
         }
         yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
         rowCounter += 1

http://git-wip-us.apache.org/repos/asf/spark/blob/29396e7d/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
index 8a9f497..6e72a5f 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala
@@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite {
       }
     }
 
+    val y17 = new DenseVector(Array(0.0, 0.0))
+    val y18 = y17.copy
+
+    val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
+      .transpose
+    val sA4 =
+      new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
+    val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
+
+    val expected4 = new DenseVector(Array(5.0, 4.0))
+
+    gemv(1.0, sA3, sx3, 0.0, y17)
+    gemv(1.0, sA4, sx3, 0.0, y18)
+
+    assert(y17 ~== expected4 absTol 1e-15)
+    assert(y18 ~== expected4 absTol 1e-15)
+
     val dAT =
       new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
     val sAT =

http://git-wip-us.apache.org/repos/asf/spark/blob/29396e7d/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 6a85608..0cd68a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging {
         val indEnd = Arows(rowCounter + 1)
         var sum = 0.0
         var k = 0
-        while (k < xNnz && i < indEnd) {
+        while (i < indEnd && k < xNnz) {
           if (xIndices(k) == Acols(i)) {
             sum += Avals(i) * xValues(k)
+            k += 1
+            i += 1
+          } else if (xIndices(k) < Acols(i)) {
+            k += 1
+          } else {
             i += 1
           }
-          k += 1
         }
         yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter)
         rowCounter += 1

http://git-wip-us.apache.org/repos/asf/spark/blob/29396e7d/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 80da03c..6e68c1c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite {
       }
     }
 
+    val y17 = new DenseVector(Array(0.0, 0.0))
+    val y18 = y17.copy
+
+    val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0))
+      .transpose
+    val sA4 =
+      new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0))
+    val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0))
+
+    val expected4 = new DenseVector(Array(5.0, 4.0))
+
+    gemv(1.0, sA3, sx3, 0.0, y17)
+    gemv(1.0, sA4, sx3, 0.0, y18)
+
+    assert(y17 ~== expected4 absTol 1e-15)
+    assert(y18 ~== expected4 absTol 1e-15)
+
     val dAT =
       new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))
     val sAT =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org