You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2016/11/25 13:19:31 UTC

spark git commit: [SPARK-18356][ML] Improve MLKmeans Performance

Repository: spark
Updated Branches:
  refs/heads/master 5ecdc7c5c -> 445d4d9e1


[SPARK-18356][ML] Improve MLKmeans Performance

## What changes were proposed in this pull request?

Spark KMeans fit() doesn't cache the RDD, which generates many warnings such as:
 WARN KMeans: The input data is not directly cached, which may hurt performance if its parent RDDs are also uncached.
KMeans should therefore cache the internal RDD before calling the MLlib KMeans algorithm; doing so improved Spark KMeans performance by 14%.

https://github.com/ZakariaHili/spark/commit/a9cf905cf7dbd50eeb9a8b4f891f2f41ea672472

hhbyyh
## How was this patch tested?
Passes the KMeans tests and the other existing tests.

Author: Zakaria_Hili <za...@gmail.com>
Author: HILI Zakaria <za...@gmail.com>

Closes #15965 from ZakariaHili/zakbranch.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/445d4d9e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/445d4d9e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/445d4d9e

Branch: refs/heads/master
Commit: 445d4d9e13ebaee9eceea6135fe7ee47812d97de
Parents: 5ecdc7c
Author: Zakaria_Hili <za...@gmail.com>
Authored: Fri Nov 25 13:19:26 2016 +0000
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Nov 25 13:19:26 2016 +0000

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/445d4d9e/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e124eb..ad4f79a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
@@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
+    val handlePersistence = dataset.storageLevel == StorageLevel.NONE  // honors Dataset.cache()
+    fit(dataset, handlePersistence)
+  }
+
+  @Since("2.2.0")
+  protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = {
     transformSchema(dataset.schema, logging = true)
-    val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
+    val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }
-
-    val instr = Instrumentation.create(this, rdd)
+    if (handlePersistence) {
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+    val instr = Instrumentation.create(this, instances)
     instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
 
     val algo = new MLlibKMeans()
@@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") (
       .setMaxIterations($(maxIter))
       .setSeed($(seed))
       .setEpsilon($(tol))
-    val parentModel = algo.run(rdd, Option(instr))
+    val parentModel = algo.run(instances, Option(instr))
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
     val summary = new KMeansSummary(
       model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(Some(summary))
     instr.logSuccess(model)
+    if (handlePersistence) {
+      instances.unpersist()
+    }
     model
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org