Posted to commits@spark.apache.org by sr...@apache.org on 2016/12/30 10:40:22 UTC

spark git commit: [SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

Repository: spark
Updated Branches:
  refs/heads/master 63036aee2 -> 56d3a7eb8


[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

## What changes were proposed in this pull request?

mllib.KMeansModel.clusterCentersWithNorm is a method that ends up being called every time `predict` is invoked on a single vector, which is bad news for how the ml.KMeansModel Transformer works, since it necessarily transforms one vector at a time.

This change makes the model store the vectors with their norms up front. The extra norm is small compared to the vectors themselves, and precomputing it avoids this form of overhead on this and other code paths.
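
To make the change concrete, here is a minimal, hypothetical sketch of the def-to-val pattern (simplified names and types, not the actual Spark source; the real classes use `VectorWithNorm` and `KMeans.findClosest`, as in the diff below):

```scala
// Simplified, illustrative sketch of the def -> val change; these are not
// the actual Spark classes.
case class CenterWithNorm(vector: Array[Double], norm: Double)

class Model(val clusterCenters: Array[Array[Double]]) {

  private def withNorm(c: Array[Double]): CenterWithNorm =
    CenterWithNorm(c, math.sqrt(c.map(x => x * x).sum))

  // Before: a def, so the norms were rebuilt on every predict() call --
  // costly when the ml.KMeansModel Transformer predicts one row at a time.
  // private def clusterCentersWithNorm = clusterCenters.map(withNorm)

  // After: a val, computed once at construction and reused by every call.
  private val clusterCentersWithNorm: Array[CenterWithNorm] =
    clusterCenters.map(withNorm)

  // Each single-point predict now reuses the precomputed norms.
  def predict(point: Array[Double]): Int =
    clusterCentersWithNorm.indices.minBy { i =>
      val c = clusterCentersWithNorm(i).vector
      c.zip(point).map { case (a, b) => (a - b) * (a - b) }.sum
    }
}
```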

## How was this patch tested?

Existing tests.

Author: Sean Owen <so...@cloudera.com>

Closes #16328 from srowen/SPARK-18808.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56d3a7eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56d3a7eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56d3a7eb

Branch: refs/heads/master
Commit: 56d3a7eb83f9c91d06dab2c91e10569723eeb105
Parents: 63036ae
Author: Sean Owen <so...@cloudera.com>
Authored: Fri Dec 30 10:40:17 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Dec 30 10:40:17 2016 +0000

----------------------------------------------------------------------
 .../spark/mllib/clustering/KMeansModel.scala       | 17 ++++++++---------
 .../spark/mllib/clustering/StreamingKMeans.scala   |  2 +-
 2 files changed, 9 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56d3a7eb/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149..df2a9c0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56d3a7eb/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c4..3ca75e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
 

