Posted to commits@spark.apache.org by me...@apache.org on 2015/01/22 06:22:27 UTC

spark git commit: [SPARK-3424][MLLIB] cache point distances during k-means|| init

Repository: spark
Updated Branches:
  refs/heads/master 27bccc5ea -> ca7910d6d


[SPARK-3424][MLLIB] cache point distances during k-means|| init

This PR ports the following feature implemented in #2634 by derrickburns:

* During k-means|| initialization, cache the costs (squared distances) computed in previous steps instead of recomputing them (see the sketch below).
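
For context, here is a minimal sketch of the caching idea, simplified to a single run and using a hypothetical `CostCacheSketch.updateCosts` helper over plain `Array[Double]` points; the actual patch keeps one cost per run in a dense vector:

```scala
import org.apache.spark.rdd.RDD

object CostCacheSketch extends Serializable {
  // Hypothetical single-run sketch: `cachedCosts` holds each point's squared
  // distance to the closest center chosen so far, so each step only computes
  // distances to the centers added in the previous step.
  def updateCosts(
      data: RDD[Array[Double]],
      cachedCosts: RDD[Double],
      newCenters: Array[Array[Double]]): RDD[Double] = {
    val bcNewCenters = data.context.broadcast(newCenters)
    data.zip(cachedCosts).map { case (point, cached) =>
      // Keep the smaller of the cached cost and the distance to any new center.
      bcNewCenters.value.foldLeft(cached) { (best, c) =>
        math.min(best, squaredDistance(point, c))
      }
    }.cache()
  }

  private def squaredDistance(a: Array[Double], b: Array[Double]): Double = {
    var s = 0.0
    var i = 0
    while (i < a.length) {
      val d = a(i) - b(i)
      s += d * d
      i += 1
    }
    s
  }
}
```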

It also contains the following optimizations:

* aggregate sumCosts directly instead of reducing and collecting per-(run, point) pairs (see the sketch after this list)
* run the `runs` local k-means++ instances in parallel
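
A rough sketch of the first optimization follows: sum per-run costs with a single `aggregate` pass rather than a reduceByKey/collectAsMap over (run, cost) pairs. `SumCostsSketch.sumCostsPerRun` is a hypothetical helper and uses a plain `Array[Double]` accumulator where the patch uses `Vectors.zeros` and BLAS `axpy`:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

object SumCostsSketch {
  // Hypothetical helper: sums the per-run costs of all points in one pass,
  // yielding one total cost per run, without shuffling (run, cost) pairs.
  def sumCostsPerRun(costs: RDD[Vector], runs: Int): Array[Double] = {
    costs.aggregate(Array.fill(runs)(0.0))(
      seqOp = (sum, v) => {
        // Add one point's per-run cost vector into the running totals.
        var r = 0
        while (r < runs) { sum(r) += v(r); r += 1 }
        sum
      },
      combOp = (s0, s1) => {
        // Merge partial sums from two partitions.
        var r = 0
        while (r < runs) { s0(r) += s1(r); r += 1 }
        s0
      })
  }
}
```

The second optimization corresponds to the `(0 until runs).par.map` near the end of the diff, which runs the per-run `LocalKMeans.kMeansPlusPlus` calls on a parallel collection instead of sequentially.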

I compared the performance locally on the mnist-digit dataset. Before this patch:

![before](https://cloud.githubusercontent.com/assets/829644/5845647/93080862-a172-11e4-9a35-044ec711afc4.png)

with this patch:

![after](https://cloud.githubusercontent.com/assets/829644/5845653/a47c29e8-a172-11e4-8e9f-08db57fe3502.png)

With this patch, each k-means|| initialization step takes about the same amount of time.

Authors:
  Derrick Burns <de...@gmail.com>
  Xiangrui Meng <me...@databricks.com>

Closes #4144 from mengxr/SPARK-3424-kmeans-parallel and squashes the following commits:

0a875ec [Xiangrui Meng] address comments
4341bb8 [Xiangrui Meng] do not re-compute point distances during k-means||


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca7910d6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca7910d6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca7910d6

Branch: refs/heads/master
Commit: ca7910d6dd7693be2a675a0d6a6fcc9eb0aaeb5d
Parents: 27bccc5
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Jan 21 21:20:31 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Jan 21 21:21:07 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/clustering/KMeans.scala  | 65 +++++++++++++++-----
 1 file changed, 50 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca7910d6/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 6b5c934..fc46da3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -279,45 +279,80 @@ class KMeans private (
    */
   private def initKMeansParallel(data: RDD[VectorWithNorm])
   : Array[Array[VectorWithNorm]] = {
-    // Initialize each run's center to a random point
+    // Initialize empty centers and point costs.
+    val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+    var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+    // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+    /** Merges new centers to centers. */
+    def mergeNewCenters(): Unit = {
+      var r = 0
+      while (r < runs) {
+        centers(r) ++= newCenters(r)
+        newCenters(r).clear()
+        r += 1
+      }
+    }
 
     // On each step, sample 2 * k points on average for each run with probability proportional
-    // to their squared distance from that run's current centers
+    // to their squared distance from that run's centers. Note that only distances between points
+    // and new centers are computed in each iteration.
     var step = 0
     while (step < initializationSteps) {
-      val bcCenters = data.context.broadcast(centers)
-      val sumCosts = data.flatMap { point =>
-        (0 until runs).map { r =>
-          (r, KMeans.pointCost(bcCenters.value(r), point))
-        }
-      }.reduceByKey(_ + _).collectAsMap()
-      val chosen = data.mapPartitionsWithIndex { (index, points) =>
+      val bcNewCenters = data.context.broadcast(newCenters)
+      val preCosts = costs
+      costs = data.zip(preCosts).map { case (point, cost) =>
+        Vectors.dense(
+          Array.tabulate(runs) { r =>
+            math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+          })
+      }.cache()
+      val sumCosts = costs
+        .aggregate(Vectors.zeros(runs))(
+          seqOp = (s, v) => {
+            // s += v
+            axpy(1.0, v, s)
+            s
+          },
+          combOp = (s0, s1) => {
+            // s0 += s1
+            axpy(1.0, s1, s0)
+            s0
+          }
+        )
+      preCosts.unpersist(blocking = false)
+      val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        points.flatMap { p =>
+        pointsWithCosts.flatMap { case (p, c) =>
           (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
           }.map((_, p))
         }
       }.collect()
+      mergeNewCenters()
       chosen.foreach { case (r, p) =>
-        centers(r) += p.toDense
+        newCenters(r) += p.toDense
       }
       step += 1
     }
 
+    mergeNewCenters()
+    costs.unpersist(blocking = false)
+
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
     val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
-      (0 until runs).map { r =>
+      Iterator.tabulate(runs) { r =>
         ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
-    val finalCenters = (0 until runs).map { r =>
+    val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
       LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)

