You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/07/21 02:03:50 UTC

spark git commit: [SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0 and beta!=1

Repository: spark
Updated Branches:
  refs/heads/master a5d05819a -> ff3c72dba


[SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0 and beta!=1

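Fix BLAS.gemm so that matrix C is still updated (C := beta * C) when alpha == 0 and beta != 1.
Previously gemm returned early whenever alpha == 0, leaving C unchanged even if beta != 1;
it now skips the computation only when alpha == 0 and beta == 1.
Also include unit tests to verify the fix.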

mengxr brkyvz

Author: Meihua Wu <me...@umich.edu>

Closes #7503 from rotationsymmetry/fix_BLAS_gemm and squashes the following commits:

fce199c [Meihua Wu] Fix BLAS.gemm to update C when alpha==0 and beta!=1


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff3c72db
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff3c72db
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff3c72db

Branch: refs/heads/master
Commit: ff3c72dbafa16c6158fc36619f3c38344c452ba0
Parents: a5d0581
Author: Meihua Wu <me...@umich.edu>
Authored: Mon Jul 20 17:03:46 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Jul 20 17:03:46 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/linalg/BLAS.scala  |  4 ++--
 .../org/apache/spark/mllib/linalg/BLASSuite.scala   | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ff3c72db/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 3523f18..9029093 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -303,8 +303,8 @@ private[spark] object BLAS extends Serializable with Logging {
       C: DenseMatrix): Unit = {
     require(!C.isTransposed,
       "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
-    if (alpha == 0.0) {
-      logDebug("gemm: alpha is equal to 0. Returning C.")
+    if (alpha == 0.0 && beta == 1.0) {
+      logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
     } else {
       A match {
         case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)

http://git-wip-us.apache.org/repos/asf/spark/blob/ff3c72db/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0f3f71..d119e0b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -200,8 +200,14 @@ class BLASSuite extends SparkFunSuite {
     val C10 = C1.copy
     val C11 = C1.copy
     val C12 = C1.copy
+    val C13 = C1.copy
+    val C14 = C1.copy
+    val C15 = C1.copy
+    val C16 = C1.copy
     val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
     val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
+    val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
+    val expected5 = C1.copy
 
     gemm(1.0, dA, B, 2.0, C1)
     gemm(1.0, sA, B, 2.0, C2)
@@ -248,6 +254,16 @@ class BLASSuite extends SparkFunSuite {
     assert(C10 ~== expected2 absTol 1e-15)
     assert(C11 ~== expected3 absTol 1e-15)
     assert(C12 ~== expected3 absTol 1e-15)
+
+    gemm(0, dA, B, 5, C13)
+    gemm(0, sA, B, 5, C14)
+    gemm(0, dA, B, 1, C15)
+    gemm(0, sA, B, 1, C16)
+    assert(C13 ~== expected4 absTol 1e-15)
+    assert(C14 ~== expected4 absTol 1e-15)
+    assert(C15 ~== expected5 absTol 1e-15)
+    assert(C16 ~== expected5 absTol 1e-15)
+
   }
 
   test("gemv") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org