You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by srowen <gi...@git.apache.org> on 2018/01/13 19:20:07 UTC

[GitHub] spark pull request #19340: [SPARK-22119][ML] Add cosine distance to KMeans

Github user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19340#discussion_r161379528
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala ---
    @@ -546,10 +577,109 @@ object KMeans {
           .run(data)
       }
     
    +  /**
    +   * Checks whether the given initialization mode is one of the supported values,
    +   * `KMeans.RANDOM` or `KMeans.K_MEANS_PARALLEL` (exact string match).
    +   *
    +   * @param initMode the initialization mode name to validate
    +   * @return true if `initMode` is a supported mode, false otherwise
    +   */
    +  private[spark] def validateInitMode(initMode: String): Boolean = {
    +    initMode match {
    +      case KMeans.RANDOM => true
    +      case KMeans.K_MEANS_PARALLEL => true
    +      case _ => false
    +    }
    +  }
    +  /**
    +   * Checks whether the given distance measure name is one of the supported values,
    +   * `DistanceMeasure.EUCLIDEAN` or `DistanceMeasure.COSINE` (exact string match).
    +   *
    +   * @param distanceMeasure the distance measure name to validate
    +   * @return true if `distanceMeasure` is a supported measure, false otherwise
    +   */
    +  private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
    +    distanceMeasure match {
    +      case DistanceMeasure.EUCLIDEAN => true
    +      case DistanceMeasure.COSINE => true
    +      case _ => false
    +    }
    +  }
    +}
    +
    +/**
    + * A vector paired with its precomputed norm, so that distance computations can reuse
    + * the norm instead of recomputing it.
    + *
    + * @param vector the underlying vector
    + * @param norm the norm of `vector`. The auxiliary constructors compute the L2 norm; when this
    + *             primary constructor is called directly, the caller presumably must pass the
    + *             L2 norm as well (it is not checked here).
    + * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
    + */
    +private[clustering]
    +class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {
    +
    +  /** Wraps `vector`, computing its L2 norm with `Vectors.norm(vector, 2.0)`. */
    +  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
    +
    +  /** Wraps `array` as a dense vector and computes its L2 norm. */
    +  def this(array: Array[Double]) = this(Vectors.dense(array))
    +
    +  /**
    +   * Returns a new VectorWithNorm backed by a dense vector built from `vector.toArray`,
    +   * reusing the stored norm rather than recomputing it.
    +   */
    +  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
    +}
    +
    +
    +/**
    + * Base class for the distance measures used by KMeans. Concrete subclasses implement
    + * [[distance]]; `findClosest`, `pointCost` and `isCenterConverged` are all built on it.
    + */
    +private[spark] abstract class DistanceMeasure extends Serializable {
    +
       /**
        * Returns the index of the closest center to the given point, as well as the distance to
        * that center as computed by [[distance]] for this measure.
        *
        * `centers` is traversed once. On ties the earliest center wins (strict `<` comparison).
        * If `centers` is empty, returns `(0, Double.PositiveInfinity)`.
        */
    -  private[mllib] def findClosest(
    +  def findClosest(
    +     centers: TraversableOnce[VectorWithNorm],
    +     point: VectorWithNorm): (Int, Double) = {
    +    var bestDistance = Double.PositiveInfinity
    +    var bestIndex = 0
    +    var i = 0
    +    centers.foreach { center =>
    +      val currentDistance = distance(center, point)
    +      if (currentDistance < bestDistance) {
    +        bestDistance = currentDistance
    +        bestIndex = i
    +      }
    +      i += 1
    +    }
    +    (bestIndex, bestDistance)
    +  }
    +
    +  /**
    +   * Returns the K-means cost of a given point against the given cluster centers, i.e. the
    +   * [[distance]] from `point` to its closest center (the second element of `findClosest`).
    +   * Returns `Double.PositiveInfinity` if `centers` is empty.
    +   */
    +  def pointCost(
    +      centers: TraversableOnce[VectorWithNorm],
    +      point: VectorWithNorm): Double =
    +    findClosest(centers, point)._2
    +
    +  /**
    +   * Returns whether a center has converged: true when the [[distance]] between `oldCenter`
    +   * and `newCenter` is at most `epsilon`.
    +   */
    +  def isCenterConverged(
    +      oldCenter: VectorWithNorm,
    +      newCenter: VectorWithNorm,
    +      epsilon: Double): Boolean =
    +    distance(oldCenter, newCenter) <= epsilon
    +
    +  /**
    +   * Computes the distance between two points under this particular measure.
    +   * Implemented by each concrete subclass (e.g. the Euclidean or cosine measure).
    +   */
    +  def distance(
    +      v1: VectorWithNorm,
    +      v2: VectorWithNorm): Double
    +
    +}
    +
    +@Since("2.3.0")
    --- End diff --
    
    All the `@Since("2.3.0")` annotations would likely have to change; I don't know whether this will make it into 2.3.0.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org