You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/01/09 19:27:37 UTC

spark git commit: [SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM

Repository: spark
Updated Branches:
  refs/heads/master b6aa55730 -> e9ca16ec9


[SPARK-5145][Mllib] Add BLAS.dsyr and use it in GaussianMixtureEM

This PR uses BLAS.dsyr to replace a few implementations in GaussianMixtureEM.

Author: Liang-Chi Hsieh <vi...@gmail.com>

Closes #3949 from viirya/blas_dsyr and squashes the following commits:

4e4d6cf [Liang-Chi Hsieh] Add unit test. Rename function name, modify doc and style.
3f57fd2 [Liang-Chi Hsieh] Add BLAS.dsyr and use it in GaussianMixtureEM.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9ca16ec
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9ca16ec
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9ca16ec

Branch: refs/heads/master
Commit: e9ca16ec943b9553056482d0c085eacb6046821e
Parents: b6aa557
Author: Liang-Chi Hsieh <vi...@gmail.com>
Authored: Fri Jan 9 10:27:33 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Jan 9 10:27:33 2015 -0800

----------------------------------------------------------------------
 .../mllib/clustering/GaussianMixtureEM.scala    | 10 +++--
 .../org/apache/spark/mllib/linalg/BLAS.scala    | 26 +++++++++++++
 .../apache/spark/mllib/linalg/BLASSuite.scala   | 41 ++++++++++++++++++++
 3 files changed, 73 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9ca16ec/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index bdf984a..3a6c0e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.IndexedSeq
 
 import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
 import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.mllib.util.MLUtils
 
@@ -151,9 +151,10 @@ class GaussianMixtureEM private (
       var i = 0
       while (i < k) {
         val mu = sums.means(i) / sums.weights(i)
-        val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
+        BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
+          Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
         weights(i) = sums.weights(i) / sumWeights
-        gaussians(i) = new MultivariateGaussian(mu, sigma)
+        gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
         i = i + 1
       }
    
@@ -211,7 +212,8 @@ private object ExpectationSum {
       p(i) /= pSum
       sums.weights(i) += p(i)
       sums.means(i) += x * p(i)
-      sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
+      BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
+        Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
       i = i + 1
     }
     sums

http://git-wip-us.apache.org/repos/asf/spark/blob/e9ca16ec/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 9fed513..3414dac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -228,6 +228,32 @@ private[spark] object BLAS extends Serializable with Logging {
     }
     _nativeBLAS
   }
+ 
+  /**
+   * Symmetric rank-1 update, in place on A: A := alpha * x * x^T^ + A
+   * @param alpha a real scalar that scales x * x^T^.
+   * @param x the dense vector x with n elements.
+   * @param A the n x n symmetric matrix A; its values are overwritten with the result.
+   */
+  def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
+    val mA = A.numRows
+    val nA = A.numCols
+    require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
+    require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+
+    nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
+
+    // dsyr with "U" only updates the upper triangle; mirror it into the lower triangle of A
+    var i = 0
+    while (i < mA) {
+      var j = i + 1
+      while (j < nA) {
+        A(j, i) = A(i, j)
+        j += 1
+      }
+      i += 1
+    }
+  }
 
   /**
    * C := alpha * A * B + beta * C

http://git-wip-us.apache.org/repos/asf/spark/blob/e9ca16ec/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index 5d70c91..771878e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -127,6 +127,47 @@ class BLASSuite extends FunSuite {
     }
   }
 
+  test("syr") {
+    val dA = new DenseMatrix(4, 4,
+      Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
+    val x = new DenseVector(Array(0.0, 2.7, 3.5, 2.1))
+    val alpha = 0.15
+
+    val expected = new DenseMatrix(4, 4,
+      Array(0.0, 1.2, 2.2, 3.1, 1.2, 4.2935, 6.7175, 5.4505, 2.2, 6.7175, 3.6375, 4.1025, 3.1,
+        5.4505, 4.1025, 1.4615))
+    // syr writes its result into dA, so dA itself is compared against the expected matrix.
+    syr(alpha, x, dA)
+
+    assert(dA ~== expected absTol 1e-15)
+    // dB is 3 x 4, so the require that A has as many rows as columns must fail.
+    val dB =
+      new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0))
+
+    withClue("Matrix A must be a symmetric Matrix") {
+      intercept[IllegalArgumentException] {
+        syr(alpha, x, dB)
+      }
+    }
+    // dC is 3 x 3 while x has 4 elements, so the size check against x must fail.
+    val dC =
+      new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8))
+
+    withClue("Size of vector must match the rank of matrix") {
+      intercept[IllegalArgumentException] {
+        syr(alpha, x, dC)
+      }
+    }
+    // y has 5 elements while dA is 4 x 4, so the size check against y must fail.
+    val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5))
+
+    withClue("Size of vector must match the rank of matrix") {
+      intercept[IllegalArgumentException] {
+        syr(alpha, y, dA)
+      }
+    }
+  }
+
   test("gemm") {
 
     val dA =


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org