You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/12/03 12:02:26 UTC

spark git commit: [SPARK-4708][MLLib] Make k-means run two/three times faster with dense/sparse samples

Repository: spark
Updated Branches:
  refs/heads/master 4ac215115 -> 7fc49ed91


[SPARK-4708][MLLib] Make k-means run two/three times faster with dense/sparse samples

Note that `breezeSquaredDistance`, as used in
`org.apache.spark.mllib.util.MLUtils.fastSquaredDistance`,
is on the critical path, and `breezeSquaredDistance` is slow.
We should replace it with our own implementation.

Here are the benchmark results on the mnist8m dataset.

Before
DenseVector: 70.04secs
SparseVector: 59.05secs

With this PR
DenseVector: 30.58secs
SparseVector: 21.14secs

Author: DB Tsai <db...@alpinenow.com>

Closes #3565 from dbtsai/kmean and squashes the following commits:

08bc068 [DB Tsai] restyle
de24662 [DB Tsai] address feedback
b185a77 [DB Tsai] cleanup
4554ddd [DB Tsai] first commit


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7fc49ed9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7fc49ed9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7fc49ed9

Branch: refs/heads/master
Commit: 7fc49ed91168999d24ae7b4cc46fbb4ec87febc1
Parents: 4ac2151
Author: DB Tsai <db...@alpinenow.com>
Authored: Wed Dec 3 19:01:56 2014 +0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Dec 3 19:01:56 2014 +0800

----------------------------------------------------------------------
 .../apache/spark/mllib/clustering/KMeans.scala  | 67 ++++++++++----------
 .../spark/mllib/clustering/KMeansModel.scala    | 10 +--
 .../spark/mllib/clustering/LocalKMeans.scala    | 22 +++----
 .../org/apache/spark/mllib/util/MLUtils.scala   | 26 ++++----
 .../apache/spark/mllib/util/MLUtilsSuite.scala  | 13 ++--
 5 files changed, 70 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7fc49ed9/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0f8dee5..54c301d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import breeze.linalg.{DenseVector => BDV, Vector => BV}
-
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -127,10 +126,10 @@ class KMeans private (
     // Compute squared norms and cache them.
     val norms = data.map(Vectors.norm(_, 2.0))
     norms.persist()
-    val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
-      new BreezeVectorWithNorm(v, norm)
+    val zippedData = data.zip(norms).map { case (v, norm) =>
+      new VectorWithNorm(v, norm)
     }
-    val model = runBreeze(breezeData)
+    val model = runAlgorithm(zippedData)
     norms.unpersist()
 
     // Warn at the end of the run as well, for increased visibility.
@@ -142,9 +141,9 @@ class KMeans private (
   }
 
   /**
-   * Implementation of K-Means using breeze.
+   * Implementation of K-Means algorithm.
    */
-  private def runBreeze(data: RDD[BreezeVectorWithNorm]): KMeansModel = {
+  private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {
 
     val sc = data.sparkContext
 
@@ -170,9 +169,10 @@ class KMeans private (
 
     // Execute iterations of Lloyd's algorithm until all runs have converged
     while (iteration < maxIterations && !activeRuns.isEmpty) {
-      type WeightedPoint = (BV[Double], Long)
-      def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-        (p1._1 += p2._1, p1._2 + p2._2)
+      type WeightedPoint = (Vector, Long)
+      def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
+        axpy(1.0, x._1, y._1)
+        (y._1, x._2 + y._2)
       }
 
       val activeCenters = activeRuns.map(r => centers(r)).toArray
@@ -185,16 +185,17 @@ class KMeans private (
         val thisActiveCenters = bcActiveCenters.value
         val runs = thisActiveCenters.length
         val k = thisActiveCenters(0).length
-        val dims = thisActiveCenters(0)(0).vector.length
+        val dims = thisActiveCenters(0)(0).vector.size
 
-        val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
+        val sums = Array.fill(runs, k)(Vectors.zeros(dims))
         val counts = Array.fill(runs, k)(0L)
 
         points.foreach { point =>
           (0 until runs).foreach { i =>
             val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
             costAccums(i) += cost
-            sums(i)(bestCenter) += point.vector
+            val sum = sums(i)(bestCenter)
+            axpy(1.0, point.vector, sum)
             counts(i)(bestCenter) += 1
           }
         }
@@ -212,8 +213,8 @@ class KMeans private (
         while (j < k) {
           val (sum, count) = totalContribs((i, j))
           if (count != 0) {
-            sum /= count.toDouble
-            val newCenter = new BreezeVectorWithNorm(sum)
+            scal(1.0 / count, sum)
+            val newCenter = new VectorWithNorm(sum)
             if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
               changed = true
             }
@@ -245,18 +246,18 @@ class KMeans private (
 
     logInfo(s"The cost for the best run is $minCost.")
 
-    new KMeansModel(centers(bestRun).map(c => Vectors.fromBreeze(c.vector)))
+    new KMeansModel(centers(bestRun).map(_.vector))
   }
 
   /**
    * Initialize `runs` sets of cluster centers at random.
    */
-  private def initRandom(data: RDD[BreezeVectorWithNorm])
-  : Array[Array[BreezeVectorWithNorm]] = {
+  private def initRandom(data: RDD[VectorWithNorm])
+  : Array[Array[VectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
     val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
-      new BreezeVectorWithNorm(v.vector.toDenseVector, v.norm)
+      new VectorWithNorm(Vectors.dense(v.vector.toArray), v.norm)
     }.toArray)
   }
 
@@ -269,8 +270,8 @@ class KMeans private (
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private def initKMeansParallel(data: RDD[BreezeVectorWithNorm])
-  : Array[Array[BreezeVectorWithNorm]] = {
+  private def initKMeansParallel(data: RDD[VectorWithNorm])
+  : Array[Array[VectorWithNorm]] = {
     // Initialize each run's center to a random point
     val seed = new XORShiftRandom().nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
@@ -376,8 +377,8 @@ object KMeans {
    * Returns the index of the closest center to the given point, as well as the squared distance.
    */
   private[mllib] def findClosest(
-      centers: TraversableOnce[BreezeVectorWithNorm],
-      point: BreezeVectorWithNorm): (Int, Double) = {
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
     var i = 0
@@ -402,8 +403,8 @@ object KMeans {
    * Returns the K-means cost of a given point against the given cluster centers.
    */
   private[mllib] def pointCost(
-      centers: TraversableOnce[BreezeVectorWithNorm],
-      point: BreezeVectorWithNorm): Double =
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): Double =
     findClosest(centers, point)._2
 
   /**
@@ -411,26 +412,24 @@ object KMeans {
    * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
    */
   private[clustering] def fastSquaredDistance(
-      v1: BreezeVectorWithNorm,
-      v2: BreezeVectorWithNorm): Double = {
+      v1: VectorWithNorm,
+      v2: VectorWithNorm): Double = {
     MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
 }
 
 /**
- * A breeze vector with its norm for fast distance computation.
+ * A vector with its norm for fast distance computation.
  *
  * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
  */
 private[clustering]
-class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {
-
-  def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))
+class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {
 
-  def this(array: Array[Double]) = this(new BDV[Double](array))
+  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
 
-  def this(v: Vector) = this(v.toBreeze)
+  def this(array: Array[Double]) = this(Vectors.dense(array))
 
   /** Converts the vector to a dense vector. */
-  def toDense = new BreezeVectorWithNorm(vector.toDenseVector, norm)
+  def toDense = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7fc49ed9/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 12a3d91..3b95a9e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -32,14 +32,14 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
 
   /** Returns the cluster index that a given point belongs to. */
   def predict(point: Vector): Int = {
-    KMeans.findClosest(clusterCentersWithNorm, new BreezeVectorWithNorm(point))._1
+    KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
   }
 
   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
     val centersWithNorm = clusterCentersWithNorm
     val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
-    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
+    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
   /** Maps given points to their cluster indices. */
@@ -53,9 +53,9 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
   def computeCost(data: RDD[Vector]): Double = {
     val centersWithNorm = clusterCentersWithNorm
     val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
-    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
+    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
-    clusterCenters.map(new BreezeVectorWithNorm(_))
+  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
+    clusterCenters.map(new VectorWithNorm(_))
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7fc49ed9/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
index f0722d7..b2f140e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.clustering
 
 import scala.util.Random
 
-import breeze.linalg.{Vector => BV, DenseVector => BDV, norm => breezeNorm}
-
 import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 
 /**
  * An utility object to run K-means locally. This is private to the ML package because it's used
@@ -35,14 +35,14 @@ private[mllib] object LocalKMeans extends Logging {
    */
   def kMeansPlusPlus(
       seed: Int,
-      points: Array[BreezeVectorWithNorm],
+      points: Array[VectorWithNorm],
       weights: Array[Double],
       k: Int,
       maxIterations: Int
-  ): Array[BreezeVectorWithNorm] = {
+  ): Array[VectorWithNorm] = {
     val rand = new Random(seed)
-    val dimensions = points(0).vector.length
-    val centers = new Array[BreezeVectorWithNorm](k)
+    val dimensions = points(0).vector.size
+    val centers = new Array[VectorWithNorm](k)
 
     // Initialize centers by sampling using the k-means++ procedure.
     centers(0) = pickWeighted(rand, points, weights).toDense
@@ -75,14 +75,12 @@ private[mllib] object LocalKMeans extends Logging {
     while (moved && iteration < maxIterations) {
       moved = false
       val counts = Array.fill(k)(0.0)
-      val sums = Array.fill(k)(
-        BDV.zeros[Double](dimensions).asInstanceOf[BV[Double]]
-      )
+      val sums = Array.fill(k)(Vectors.zeros(dimensions))
       var i = 0
       while (i < points.length) {
         val p = points(i)
         val index = KMeans.findClosest(centers, p)._1
-        breeze.linalg.axpy(weights(i), p.vector, sums(index))
+        axpy(weights(i), p.vector, sums(index))
         counts(index) += weights(i)
         if (index != oldClosest(i)) {
           moved = true
@@ -97,8 +95,8 @@ private[mllib] object LocalKMeans extends Logging {
           // Assign center to a random point
           centers(j) = points(rand.nextInt(points.length)).toDense
         } else {
-          sums(j) /= counts(j)
-          centers(j) = new BreezeVectorWithNorm(sums(j))
+          scal(1.0 / counts(j), sums(j))
+          centers(j) = new VectorWithNorm(sums(j))
         }
         j += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/7fc49ed9/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 9353351..b0d05ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.util
 
 import scala.reflect.ClassTag
 
-import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV,
   squaredDistance => breezeSquaredDistance}
 
 import org.apache.spark.annotation.Experimental
@@ -28,7 +28,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PartitionwiseSampledRDD
 import org.apache.spark.util.random.BernoulliCellSampler
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.dstream.DStream
@@ -281,9 +282,9 @@ object MLUtils {
    * @return squared distance between v1 and v2 within the specified precision
    */
   private[mllib] def fastSquaredDistance(
-      v1: BV[Double],
+      v1: Vector,
       norm1: Double,
-      v2: BV[Double],
+      v2: Vector,
       norm2: Double,
       precision: Double = 1e-6): Double = {
     val n = v1.size
@@ -306,16 +307,19 @@ object MLUtils {
      */
     val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
     if (precisionBound1 < precision) {
-      sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
-    } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) {
-      val dot = v1.dot(v2)
-      sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0)
-      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON)
+      sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
+    } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
+      val dotValue = dot(v1, v2)
+      sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
+      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
+        (sqDist + EPSILON)
       if (precisionBound2 > precision) {
-        sqDist = breezeSquaredDistance(v1, v2)
+        // TODO: breezeSquaredDistance is slow,
+        // so we should replace it with our own implementation.
+        sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
       }
     } else {
-      sqDist = breezeSquaredDistance(v1, v2)
+      sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
     }
     sqDist
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/7fc49ed9/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 88bc49c..df07987 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -44,18 +44,19 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
   test("fast squared distance") {
     val a = (30 to 0 by -1).map(math.pow(2.0, _)).toArray
     val n = a.length
-    val v1 = new BDV[Double](a)
-    val norm1 = breezeNorm(v1, 2.0)
+    val v1 = Vectors.dense(a)
+    val norm1 = Vectors.norm(v1, 2.0)
     val precision = 1e-6
     for (m <- 0 until n) {
       val indices = (0 to m).toArray
       val values = indices.map(i => a(i))
-      val v2 = new BSV[Double](indices, values, n)
-      val norm2 = breezeNorm(v2, 2.0)
-      val squaredDist = breezeSquaredDistance(v1, v2)
+      val v2 = Vectors.sparse(n, indices, values)
+      val norm2 = Vectors.norm(v2, 2.0)
+      val squaredDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
       val fastSquaredDist1 = fastSquaredDistance(v1, norm1, v2, norm2, precision)
       assert((fastSquaredDist1 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
-      val fastSquaredDist2 = fastSquaredDistance(v1, norm1, v2.toDenseVector, norm2, precision)
+      val fastSquaredDist2 =
+        fastSquaredDistance(v1, norm1, Vectors.dense(v2.toArray), norm2, precision)
       assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org