You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2019/05/31 23:27:59 UTC

[spark] branch branch-2.4 updated: [SPARK-27896][ML] Fix definition of clustering silhouette coefficient for 1-element clusters

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 16f2ceb  [SPARK-27896][ML] Fix definition of clustering silhouette coefficient for 1-element clusters
16f2ceb is described below

commit 16f2ceb0cb8407023e7bd14575221b0a718e50de
Author: Sean Owen <se...@databricks.com>
AuthorDate: Fri May 31 16:27:20 2019 -0700

    [SPARK-27896][ML] Fix definition of clustering silhouette coefficient for 1-element clusters
    
    ## What changes were proposed in this pull request?
    
    Single-point clusters should have a silhouette score of 0, according to the original paper and the scikit-learn implementation.
    
    ## How was this patch tested?
    
    Existing test suite + new test case.
    
    Closes #24756 from srowen/SPARK-27896.
    
    Authored-by: Sean Owen <se...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit aec0869fb2ae1ace93056ee1f9ea50b1bdbae9ad)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/ml/evaluation/ClusteringEvaluator.scala  | 38 +++++++++++-----------
 .../ml/evaluation/ClusteringEvaluatorSuite.scala   | 11 +++++++
 2 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 5c1d1ae..4c915e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -146,27 +146,27 @@ private[evaluation] abstract class Silhouette {
       pointClusterId: Double,
       pointClusterNumOfPoints: Long,
       averageDistanceToCluster: (Double) => Double): Double = {
-    // Here we compute the average dissimilarity of the current point to any cluster of which the
-    // point is not a member.
-    // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current
-    // point - is said to be the "neighboring cluster".
-    val otherClusterIds = clusterIds.filter(_ != pointClusterId)
-    val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min
-
-    // adjustment for excluding the node itself from the computation of the average dissimilarity
-    val currentClusterDissimilarity = if (pointClusterNumOfPoints == 1) {
+    if (pointClusterNumOfPoints == 1) {
+      // Single-element clusters have silhouette 0
       0.0
     } else {
-      averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints /
-        (pointClusterNumOfPoints - 1)
-    }
-
-    if (currentClusterDissimilarity < neighboringClusterDissimilarity) {
-      1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
-    } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) {
-      (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1
-    } else {
-      0.0
+      // Here we compute the average dissimilarity of the current point to any cluster of which the
+      // point is not a member.
+      // The cluster with the lowest average dissimilarity - i.e. the nearest cluster to the current
+      // point - is said to be the "neighboring cluster".
+      val otherClusterIds = clusterIds.filter(_ != pointClusterId)
+      val neighboringClusterDissimilarity = otherClusterIds.map(averageDistanceToCluster).min
+      // adjustment for excluding the node itself from the computation of the average dissimilarity
+      val currentClusterDissimilarity =
+        averageDistanceToCluster(pointClusterId) * pointClusterNumOfPoints /
+          (pointClusterNumOfPoints - 1)
+      if (currentClusterDissimilarity < neighboringClusterDissimilarity) {
+        1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
+      } else if (currentClusterDissimilarity > neighboringClusterDissimilarity) {
+        (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1
+      } else {
+        0.0
+      }
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index e2d7756..a47d0c5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -134,4 +134,15 @@ class ClusteringEvaluatorSuite
     // with wrong metadata the evaluator throws an Exception
     intercept[SparkException](evaluator.evaluate(dfWrong))
   }
+
+  test("SPARK-27896: single-element clusters should have silhouette score of 0") {
+    val twoSingleItemClusters =
+      irisDataset.where($"label" === 0.0).limit(1).union(
+        irisDataset.where($"label" === 1.0).limit(1))
+    val evaluator = new ClusteringEvaluator()
+      .setFeaturesCol("features")
+      .setPredictionCol("label")
+    assert(evaluator.evaluate(twoSingleItemClusters) === 0.0)
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org