You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by srowen <gi...@git.apache.org> on 2018/01/13 19:20:07 UTC
[GitHub] spark pull request #19340: [SPARK-22119][ML] Add cosine distance to KMeans
Github user srowen commented on a diff in the pull request:
https://github.com/apache/spark/pull/19340#discussion_r161379528
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala ---
@@ -546,10 +577,109 @@ object KMeans {
.run(data)
}
+ private[spark] def validateInitMode(initMode: String): Boolean = {
+ initMode match {
+ case KMeans.RANDOM => true
+ case KMeans.K_MEANS_PARALLEL => true
+ case _ => false
+ }
+ }
+ private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
+ distanceMeasure match {
+ case DistanceMeasure.EUCLIDEAN => true
+ case DistanceMeasure.COSINE => true
+ case _ => false
+ }
+ }
+}
+
+/**
+ * A vector with its norm for fast distance computation.
+ *
+ * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
+ */
+private[clustering]
+class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {
+
+ def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
+
+ def this(array: Array[Double]) = this(Vectors.dense(array))
+
+ /** Converts the vector to a dense vector. */
+ def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+}
+
+
+private[spark] abstract class DistanceMeasure extends Serializable {
+
/**
* Returns the index of the closest center to the given point, as well as the squared distance.
*/
- private[mllib] def findClosest(
+ def findClosest(
+ centers: TraversableOnce[VectorWithNorm],
+ point: VectorWithNorm): (Int, Double) = {
+ var bestDistance = Double.PositiveInfinity
+ var bestIndex = 0
+ var i = 0
+ centers.foreach { center =>
+ val currentDistance = distance(center, point)
+ if (currentDistance < bestDistance) {
+ bestDistance = currentDistance
+ bestIndex = i
+ }
+ i += 1
+ }
+ (bestIndex, bestDistance)
+ }
+
+ /**
+ * Returns the K-means cost of a given point against the given cluster centers.
+ */
+ def pointCost(
+ centers: TraversableOnce[VectorWithNorm],
+ point: VectorWithNorm): Double =
+ findClosest(centers, point)._2
+
+ /**
+ * Returns whether a center converged or not, given the epsilon parameter.
+ */
+ def isCenterConverged(
+ oldCenter: VectorWithNorm,
+ newCenter: VectorWithNorm,
+ epsilon: Double): Boolean =
+ distance(oldCenter, newCenter) <= epsilon
+
+ /**
+ * Computes the distance between two points.
+ */
+ def distance(
+ v1: VectorWithNorm,
+ v2: VectorWithNorm): Double
+
+}
+
+@Since("2.3.0")
--- End diff --
All the `@Since("2.3.0")` annotations would likely have to change; I don't know whether this will get in for 2.3.0.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org