You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this rendering.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/13 22:23:14 UTC
spark git commit: [SPARK-14375][ML] Unit test for spark.ml KMeansSummary
Repository: spark
Updated Branches:
refs/heads/master 0d17593b3 -> a91aaf5a8
[SPARK-14375][ML] Unit test for spark.ml KMeansSummary
## What changes were proposed in this pull request?
* Modify the `KMeansSummary.clusterSizes` method to make it robust to empty clusters.
* Add a unit test for spark.ml `KMeansSummary`.
* Add `@Since` annotations.
## How was this patch tested?
Unit tests (extended in `KMeansSuite`).
cc jkbradley
Author: Yanbo Liang <yb...@gmail.com>
Closes #12254 from yanboliang/spark-14375.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a91aaf5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a91aaf5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a91aaf5a
Branch: refs/heads/master
Commit: a91aaf5a8cca18811c0cccc20f4e77f36231b344
Parents: 0d17593
Author: Yanbo Liang <yb...@gmail.com>
Authored: Wed Apr 13 13:23:10 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Apr 13 13:23:10 2016 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/clustering/KMeans.scala | 35 ++++++++++++++++----
.../org/apache/spark/ml/r/KMeansWrapper.scala | 2 +-
.../spark/ml/clustering/KMeansSuite.scala | 18 +++++++++-
3 files changed, 47 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index d716bc6..b324196 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -144,6 +144,12 @@ class KMeansModel private[ml] (
}
/**
+ * Return true if there exists summary of model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
- val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+ val summary = new KMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
}
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
- @Since("2.0.0") val featuresCol: String) extends Serializable {
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
/**
* Cluster centers of the transformed data.
@@ -296,11 +315,15 @@ class KMeansSummary private[clustering] (
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
/**
- * Size of each cluster.
+ * Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
- lazy val clusterSizes: Array[Int] = cluster.rdd.map {
- case Row(clusterIdx: Int) => (clusterIdx, 1)
- }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ee51357..9e2b81e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -37,7 +37,7 @@ private[r] class KMeansWrapper private (
lazy val k: Int = kMeansModel.getK
- lazy val size: Array[Int] = kMeansModel.summary.clusterSizes
+ lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
lazy val cluster: DataFrame = kMeansModel.summary.cluster
http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2076c74..2ca386e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org