You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/11/30 04:51:51 UTC
spark git commit: [SPARK-15819][PYSPARK][ML] Add KMeansSummary in
KMeans of PySpark
Repository: spark
Updated Branches:
refs/heads/master 489845f3a -> 4c82ca86d
[SPARK-15819][PYSPARK][ML] Add KMeansSummary in KMeans of PySpark
## What changes were proposed in this pull request?
Add a Python API for KMeansSummary.
## How was this patch tested?
A unit test was added.
Author: Jeff Zhang <zj...@apache.org>
Closes #13557 from zjffdu/SPARK-15819.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4c82ca86
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4c82ca86
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4c82ca86
Branch: refs/heads/master
Commit: 4c82ca86d979e5526a15666683eef3c79c37dc68
Parents: 489845f
Author: Jeff Zhang <zj...@apache.org>
Authored: Tue Nov 29 20:51:27 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Tue Nov 29 20:51:27 2016 -0800
----------------------------------------------------------------------
python/pyspark/ml/clustering.py | 41 ++++++++++++++++++++++++++++++++++++
python/pyspark/ml/tests.py | 15 +++++++++++++
2 files changed, 56 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/4c82ca86/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7f8d845..35d0aef 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -292,6 +292,17 @@ class GaussianMixtureSummary(ClusteringSummary):
return self._call_java("probability")
+class KMeansSummary(ClusteringSummary):
+ """
+ .. note:: Experimental
+
+ Summary of KMeans.
+
+ .. versionadded:: 2.1.0
+ """
+ pass
+
+
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -312,6 +323,27 @@ class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("computeCost", dataset)
+ @property
+ @since("2.1.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model instance.
+ """
+ return self._call_java("hasSummary")
+
+ @property
+ @since("2.1.0")
+ def summary(self):
+ """
+ Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+ training set. An exception is thrown if no summary exists.
+ """
+ if self.hasSummary:
+ return KMeansSummary(self._call_java("summary"))
+ else:
+ raise RuntimeError("No training summary available for this %s" %
+ self.__class__.__name__)
+
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
@@ -337,6 +369,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
+ >>> model.hasSummary
+ True
+ >>> summary = model.summary
+ >>> summary.k
+ 2
+ >>> summary.clusterSizes
+ [2, 2]
>>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
@@ -345,6 +384,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
>>> model_path = temp_path + "/kmeans_model"
>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
+ >>> model2.hasSummary
+ False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
http://git-wip-us.apache.org/repos/asf/spark/blob/4c82ca86/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c0f0d40..a0c288a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1129,6 +1129,21 @@ class TrainingSummaryTest(SparkSessionTestCase):
self.assertEqual(len(s.clusterSizes), 2)
self.assertEqual(s.k, 2)
+ def test_kmeans_summary(self):
+ data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+ (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+ df = self.spark.createDataFrame(data, ["features"])
+ kmeans = KMeans(k=2, seed=1)
+ model = kmeans.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.featuresCol, "features")
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertTrue(isinstance(s.cluster, DataFrame))
+ self.assertEqual(len(s.clusterSizes), 2)
+ self.assertEqual(s.k, 2)
+
class OneVsRestTests(SparkSessionTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org