Posted to commits@spark.apache.org by sr...@apache.org on 2018/02/12 02:15:33 UTC

spark git commit: [SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance

Repository: spark
Updated Branches:
  refs/heads/master 4bbd7443e -> c0c902aed


[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance

## What changes were proposed in this pull request?

In #19340, some comments argued that spherical KMeans should be used when the cosine distance measure is specified, as MATLAB does, rather than an implementation based on the behavior of other tools/libraries such as RapidMiner, NLTK and ELKI, i.e. one where the centroids are computed as the plain mean of all the points in the cluster.

This PR introduces the approach used in spherical KMeans: each point is normalized to unit length before being added to its cluster's sum, and the resulting centroid is re-normalized onto the unit sphere. This behavior has the nice property of minimizing the within-cluster cosine distance.
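
As an illustrative sketch of that approach (the object and helper names below are made up for this example, not part of the patch): points contribute only their directions to the sum, and the mean is re-normalized so every centroid ends up on the unit sphere.

```scala
// Sketch of the spherical KMeans centroid update; the patch implements the
// same logic via DistanceMeasure.updateClusterSum and DistanceMeasure.centroid.
object SphericalCentroidSketch {
  // Euclidean (L2) norm of a vector.
  private def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  // Sum the unit-normalized points, take the mean, then re-normalize so the
  // centroid itself has unit length, as spherical KMeans requires.
  def centroid(points: Seq[Array[Double]]): Array[Double] = {
    require(points.forall(norm(_) > 0), "Cosine distance is not defined for zero-length vectors.")
    val dim = points.head.length
    val sum = new Array[Double](dim)
    for (p <- points; n = norm(p); i <- 0 until dim) sum(i) += p(i) / n
    val mean = sum.map(_ / points.size)
    val meanNorm = norm(mean)
    mean.map(_ / meanNorm)
  }

  def main(args: Array[String]): Unit = {
    // (1, 0) and (10, 10) contribute only their directions, not their magnitudes.
    val c = centroid(Seq(Array(1.0, 0.0), Array(10.0, 10.0)))
    println(c.mkString("[", ", ", "]"))
    println(norm(c)) // ~1.0
  }
}
```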

## How was this patch tested?

Existing and improved unit tests.

Author: Marco Gaido <ma...@gmail.com>

Closes #20518 from mgaido91/SPARK-22119_followup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c0c902ae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c0c902ae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c0c902ae

Branch: refs/heads/master
Commit: c0c902aedcf9ed24e482d873d766a7df63b964cb
Parents: 4bbd744
Author: Marco Gaido <ma...@gmail.com>
Authored: Sun Feb 11 20:15:30 2018 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Feb 11 20:15:30 2018 -0600

----------------------------------------------------------------------
 .../apache/spark/mllib/clustering/KMeans.scala  | 54 +++++++++++++++++---
 .../spark/ml/clustering/KMeansSuite.scala       | 15 +++++-
 2 files changed, 62 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c0c902ae/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 607145c..3c4ba0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -310,8 +310,7 @@ class KMeans private (
         points.foreach { point =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
-          val sum = sums(bestCenter)
-          axpy(1.0, point.vector, sum)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           counts(bestCenter) += 1
         }
 
@@ -319,10 +318,9 @@ class KMeans private (
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
-      }.mapValues { case (sum, count) =>
-        scal(1.0 / count, sum)
-        new VectorWithNorm(sum)
-      }.collectAsMap()
+      }.collectAsMap().mapValues { case (sum, count) =>
+        distanceMeasureInstance.centroid(sum, count)
+      }
 
       bcCenters.destroy(blocking = false)
 
@@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
       v1: VectorWithNorm,
       v2: VectorWithNorm): Double
 
+  /**
+   * Updates the value of `sum` by adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum   the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
 }
 
 @Since("2.4.0")
@@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
    * @return the cosine distance between the two input vectors
    */
   override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
     1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
   }
+
+  /**
+   * Updates the value of `sum` by adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0 / point.norm, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum   the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)
+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
 }
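
Side note, not part of the diff: a short sketch of why this choice minimizes the within-cluster cosine distance. For a unit-norm centroid $c$ and cluster points $x_1, \dots, x_m$,

$$\sum_{i=1}^{m}\Big(1 - \frac{x_i \cdot c}{\|x_i\|\,\|c\|}\Big) = m - c \cdot s, \qquad s = \sum_{i=1}^{m}\frac{x_i}{\|x_i\|},$$

and by Cauchy-Schwarz the right-hand side is minimized over unit vectors by $c = s/\|s\|$. That is exactly what `updateClusterSum` (accumulating $x_i/\|x_i\|$) and `centroid` (re-normalizing the mean) compute; the division by `count` only rescales the sum and does not change its direction.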

http://git-wip-us.apache.org/repos/asf/spark/blob/c0c902ae/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e4506f2..32830b3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
       predictionsMap(Vectors.dense(-100.0, 90.0)))
 
+    assert(model.clusterCenters.forall(c => math.abs(Vectors.norm(c, 2) - 1.0) < 1e-8))
+  }
+
+  test("KMeans with cosine distance is not supported for 0-length vectors") {
+    val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5)
+    )).map(v => TestRow(v)))
+    val e = intercept[SparkException](model.fit(df))
+    assert(e.getCause.isInstanceOf[AssertionError])
+    assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
   }
 
   test("read/write") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org