You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/01/21 14:51:16 UTC

spark git commit: [SPARK-22119][ML] Add cosine distance to KMeans

Repository: spark
Updated Branches:
  refs/heads/master 121dc96f0 -> 4f43d27c9


[SPARK-22119][ML] Add cosine distance to KMeans

## What changes were proposed in this pull request?

Currently, KMeans assumes that the only possible distance measure is the Euclidean distance. This PR adds support for the cosine distance to the KMeans algorithm.

## How was this patch tested?

Existing and newly added unit tests.

Author: Marco Gaido <ma...@gmail.com>
Author: Marco Gaido <mg...@hortonworks.com>

Closes #19340 from mgaido91/SPARK-22119.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f43d27c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f43d27c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f43d27c

Branch: refs/heads/master
Commit: 4f43d27c9e97be8605b120b3d7c11c7c61e3ca6f
Parents: 121dc96
Author: Marco Gaido <ma...@gmail.com>
Authored: Sun Jan 21 08:51:12 2018 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Jan 21 08:51:12 2018 -0600

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala |  22 +-
 .../mllib/clustering/BisectingKMeans.scala      |  11 +-
 .../apache/spark/mllib/clustering/KMeans.scala  | 216 +++++++++++++++----
 .../spark/mllib/clustering/KMeansModel.scala    |  74 ++++++-
 .../spark/mllib/clustering/LocalKMeans.scala    |  10 +-
 .../spark/ml/clustering/KMeansSuite.scala       |  42 +++-
 .../spark/mllib/clustering/KMeansSuite.scala    |   6 +-
 7 files changed, 315 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f2af7fe..c8145de 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
@@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)
 
+  @Since("2.4.0")
+  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
+    "Supported options: 'euclidean' and 'cosine'.",
+    (value: String) => MLlibKMeans.validateDistanceMeasure(value))
+
+  /** @group expertGetParam */
+  @Since("2.4.0")
+  def getDistanceMeasure: String = $(distanceMeasure)
+
   /**
    * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    * setting -- the default of 2 is almost always enough. Must be &gt; 0. Default: 2.
@@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") (
     maxIter -> 20,
     initMode -> MLlibKMeans.K_MEANS_PARALLEL,
     initSteps -> 2,
-    tol -> 1e-4)
+    tol -> 1e-4,
+    distanceMeasure -> DistanceMeasure.EUCLIDEAN)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
@@ -285,6 +295,10 @@ class KMeans @Since("1.5.0") (
   def setInitMode(value: String): this.type = set(initMode, value)
 
   /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
+  /** @group expertSetParam */
   @Since("1.5.0")
   def setInitSteps(value: Int): this.type = set(initSteps, value)
 
@@ -314,7 +328,8 @@ class KMeans @Since("1.5.0") (
     }
 
     val instr = Instrumentation.create(this, instances)
-    instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
+    instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
+      maxIter, seed, tol)
     val algo = new MLlibKMeans()
       .setK($(k))
       .setInitializationMode($(initMode))
@@ -322,6 +337,7 @@ class KMeans @Since("1.5.0") (
       .setMaxIterations($(maxIter))
       .setSeed($(seed))
       .setEpsilon($(tol))
+      .setDistanceMeasure($(distanceMeasure))
     val parentModel = algo.run(instances, Option(instr))
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
     val summary = new KMeansSummary(

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 9b9c70c..2221f4c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -350,7 +350,7 @@ private object BisectingKMeans extends Serializable {
         val newClusterChildren = children.filter(newClusterCenters.contains(_))
         if (newClusterChildren.nonEmpty) {
           val selected = newClusterChildren.minBy { child =>
-            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+            EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
           }
           (selected, v)
         } else {
@@ -387,7 +387,7 @@ private object BisectingKMeans extends Serializable {
         val rightIndex = rightChildIndex(rawIndex)
         val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
         val height = math.sqrt(indexes.map { childIndex =>
-          KMeans.fastSquaredDistance(center, clusters(childIndex).center)
+          EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
         val children = indexes.map(buildSubTree(_)).toArray
         new ClusteringTreeNode(index, size, center, cost, height, children)
@@ -457,7 +457,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
       this :: Nil
     } else {
       val selected = children.minBy { child =>
-        KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
+        EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
       }
       selected :: selected.predictPath(pointWithNorm)
     }
@@ -475,7 +475,8 @@ private[clustering] class ClusteringTreeNode private[clustering] (
    * Predicts the cluster index and the cost of the input point.
    */
   private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
-    predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm))
+    predict(pointWithNorm,
+      EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
   }
 
   /**
@@ -490,7 +491,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
       (index, cost)
     } else {
       val (selectedChild, minCost) = children.map { child =>
-        (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
+        (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
       }.minBy(_._2)
       selectedChild.predict(pointWithNorm, minCost)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 49043b5..607145c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
 import org.apache.spark.ml.util.Instrumentation
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -46,14 +46,23 @@ class KMeans private (
     private var initializationMode: String,
     private var initializationSteps: Int,
     private var epsilon: Double,
-    private var seed: Long) extends Serializable with Logging {
+    private var seed: Long,
+    private var distanceMeasure: String) extends Serializable with Logging {
+
+  @Since("0.8.0")
+  private def this(k: Int, maxIterations: Int, initializationMode: String, initializationSteps: Int,
+      epsilon: Double, seed: Long) =
+    this(k, maxIterations, initializationMode, initializationSteps,
+      epsilon, seed, DistanceMeasure.EUCLIDEAN)
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20,
-   * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random}.
+   * initializationMode: "k-means||", initializationSteps: 2, epsilon: 1e-4, seed: random,
+   * distanceMeasure: "euclidean"}.
    */
   @Since("0.8.0")
-  def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong())
+  def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong(),
+    DistanceMeasure.EUCLIDEAN)
 
   /**
    * Number of clusters to create (k).
@@ -184,6 +193,22 @@ class KMeans private (
     this
   }
 
+  /**
+   * The distance measure used by the algorithm.
+   */
+  @Since("2.4.0")
+  def getDistanceMeasure: String = distanceMeasure
+
+  /** Set the distance measure used by the algorithm; must be "euclidean" or "cosine". */
+  @Since("2.4.0")
+  def setDistanceMeasure(distanceMeasure: String): this.type = {
+    // validateDistanceMeasure only returns a Boolean, so its result must be enforced here.
+    require(KMeans.validateDistanceMeasure(distanceMeasure),
+      s"distanceMeasure must be 'euclidean' or 'cosine' but got: $distanceMeasure")
+    this.distanceMeasure = distanceMeasure
+    this
+  }
+
   // Initial cluster centers can be provided as a KMeansModel object rather than using the
   // random or k-means|| initializationMode
   private var initialModel: Option[KMeansModel] = None
@@ -246,6 +271,8 @@ class KMeans private (
 
     val initStartTime = System.nanoTime()
 
+    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
+
     val centers = initialModel match {
       case Some(kMeansCenters) =>
         kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
@@ -253,7 +280,7 @@ class KMeans private (
         if (initializationMode == KMeans.RANDOM) {
           initRandom(data)
         } else {
-          initKMeansParallel(data)
+          initKMeansParallel(data, distanceMeasureInstance)
         }
     }
     val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@@ -281,7 +308,7 @@ class KMeans private (
         val counts = Array.fill(thisCenters.length)(0L)
 
         points.foreach { point =>
-          val (bestCenter, cost) = KMeans.findClosest(thisCenters, point)
+          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
           val sum = sums(bestCenter)
           axpy(1.0, point.vector, sum)
@@ -302,7 +329,8 @@ class KMeans private (
       // Update the cluster centers and costs
       converged = true
       newCenters.foreach { case (j, newCenter) =>
-        if (converged && KMeans.fastSquaredDistance(newCenter, centers(j)) > epsilon * epsilon) {
+        if (converged &&
+          !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
           converged = false
         }
         centers(j) = newCenter
@@ -323,7 +351,7 @@ class KMeans private (
 
     logInfo(s"The cost is $cost.")
 
-    new KMeansModel(centers.map(_.vector))
+    new KMeansModel(centers.map(_.vector), distanceMeasure)
   }
 
   /**
@@ -345,7 +373,8 @@ class KMeans private (
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
+  private[clustering] def initKMeansParallel(data: RDD[VectorWithNorm],
+      distanceMeasureInstance: DistanceMeasure): Array[VectorWithNorm] = {
     // Initialize empty centers and point costs.
     var costs = data.map(_ => Double.PositiveInfinity)
 
@@ -369,7 +398,7 @@ class KMeans private (
       bcNewCentersList += bcNewCenters
       val preCosts = costs
       costs = data.zip(preCosts).map { case (point, cost) =>
-        math.min(KMeans.pointCost(bcNewCenters.value, point), cost)
+        math.min(distanceMeasureInstance.pointCost(bcNewCenters.value, point), cost)
       }.persist(StorageLevel.MEMORY_AND_DISK)
       val sumCosts = costs.sum()
 
@@ -397,7 +426,9 @@ class KMeans private (
       // candidate by the number of points in the dataset mapping to it and run a local k-means++
       // on the weighted centers to pick k of them
       val bcCenters = data.context.broadcast(distinctCenters)
-      val countMap = data.map(KMeans.findClosest(bcCenters.value, _)._1).countByValue()
+      val countMap = data
+        .map(distanceMeasureInstance.findClosest(bcCenters.value, _)._1)
+        .countByValue()
 
       bcCenters.destroy(blocking = false)
 
@@ -546,10 +577,110 @@ object KMeans {
       .run(data)
   }
 
+  private[spark] def validateInitMode(initMode: String): Boolean = {
+    initMode match {
+      case KMeans.RANDOM => true
+      case KMeans.K_MEANS_PARALLEL => true
+      case _ => false
+    }
+  }
+
+  private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
+    distanceMeasure match {
+      case DistanceMeasure.EUCLIDEAN => true
+      case DistanceMeasure.COSINE => true
+      case _ => false
+    }
+  }
+}
+
+/**
+ * A vector with its norm for fast distance computation.
+ */
+private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
+    extends Serializable {
+
+  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
+
+  def this(array: Array[Double]) = this(Vectors.dense(array))
+
+  /** Converts the vector to a dense vector. */
+  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+}
+
+
+private[spark] abstract class DistanceMeasure extends Serializable {
+
+  /**
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = Double.PositiveInfinity
+    var bestIndex = 0
+    var i = 0
+    centers.foreach { center =>
+      val currentDistance = distance(center, point)
+      if (currentDistance < bestDistance) {
+        bestDistance = currentDistance
+        bestIndex = i
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
   /**
-   * Returns the index of the closest center to the given point, as well as the squared distance.
+   * @return the K-means cost of a given point against the given cluster centers.
    */
-  private[mllib] def findClosest(
+  def pointCost(
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): Double = {
+    findClosest(centers, point)._2
+  }
+
+  /**
+   * @return whether a center converged or not, given the epsilon parameter.
+   */
+  def isCenterConverged(
+      oldCenter: VectorWithNorm,
+      newCenter: VectorWithNorm,
+      epsilon: Double): Boolean = {
+    distance(oldCenter, newCenter) <= epsilon
+  }
+
+  /**
+   * @return the distance between two points, as defined by the concrete measure.
+   */
+  def distance(
+      v1: VectorWithNorm,
+      v2: VectorWithNorm): Double
+
+}
+
+@Since("2.4.0")
+object DistanceMeasure {
+
+  @Since("2.4.0")
+  val EUCLIDEAN = "euclidean"
+  @Since("2.4.0")
+  val COSINE = "cosine"
+
+  private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
+    distanceMeasure match {
+      case EUCLIDEAN => new EuclideanDistanceMeasure
+      case COSINE => new CosineDistanceMeasure
+      case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
+        s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
+    }
+}
+
+private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
+  /**
+   * @return the index of the closest center to the given point, as well as the squared distance.
+   */
+  override def findClosest(
       centers: TraversableOnce[VectorWithNorm],
       point: VectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
@@ -561,7 +692,7 @@ object KMeans {
       var lowerBoundOfSqDist = center.norm - point.norm
       lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
       if (lowerBoundOfSqDist < bestDistance) {
-        val distance: Double = fastSquaredDistance(center, point)
+        val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
         if (distance < bestDistance) {
           bestDistance = distance
           bestIndex = i
@@ -573,15 +704,29 @@ object KMeans {
   }
 
   /**
-   * Returns the K-means cost of a given point against the given cluster centers.
+   * @return whether a center converged or not, given the epsilon parameter.
    */
-  private[mllib] def pointCost(
-      centers: TraversableOnce[VectorWithNorm],
-      point: VectorWithNorm): Double =
-    findClosest(centers, point)._2
+  override def isCenterConverged(
+      oldCenter: VectorWithNorm,
+      newCenter: VectorWithNorm,
+      epsilon: Double): Boolean = {
+    EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
+  }
+
+  /**
+   * @param v1: first vector
+   * @param v2: second vector
+   * @return the Euclidean distance between the two input vectors
+   */
+  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
+  }
+}
+
 
+private[spark] object EuclideanDistanceMeasure {
   /**
-   * Returns the squared Euclidean distance between two vectors computed by
+   * @return the squared Euclidean distance between two vectors computed by
    * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
    */
   private[clustering] def fastSquaredDistance(
@@ -589,28 +734,15 @@ object KMeans {
       v2: VectorWithNorm): Double = {
     MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
-
-  private[spark] def validateInitMode(initMode: String): Boolean = {
-    initMode match {
-      case KMeans.RANDOM => true
-      case KMeans.K_MEANS_PARALLEL => true
-      case _ => false
-    }
-  }
 }
 
-/**
- * A vector with its norm for fast distance computation.
- *
- * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
- */
-private[clustering]
-class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable {
-
-  def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))
-
-  def this(array: Array[Double]) = this(Vectors.dense(array))
-
-  /** Converts the vector to a dense vector. */
-  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+private[spark] class CosineDistanceMeasure extends DistanceMeasure {
+  /**
+   * Cosine distance, i.e. 1 minus the cosine of the angle between `v1` and `v2`.
+   * @return the cosine distance between the two input vectors
+   * @throws IllegalArgumentException if either norm is zero (the result would be NaN)
+   */
+  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    require(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
+    1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 3ad08c4..a78c21e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -36,12 +36,20 @@ import org.apache.spark.sql.{Row, SparkSession}
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
 @Since("0.8.0")
-class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
+class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector],
+  @Since("2.4.0") val distanceMeasure: String)
   extends Saveable with Serializable with PMMLExportable {
 
+  private val distanceMeasureInstance: DistanceMeasure =
+    DistanceMeasure.decodeFromString(distanceMeasure)
+
   private val clusterCentersWithNorm =
     if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
 
+  @Since("1.1.0")
+  def this(clusterCenters: Array[Vector]) =
+    this(clusterCenters: Array[Vector], DistanceMeasure.EUCLIDEAN)
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -59,7 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def predict(point: Vector): Int = {
-    KMeans.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
+    distanceMeasureInstance.findClosest(clusterCentersWithNorm, new VectorWithNorm(point))._1
   }
 
   /**
@@ -68,7 +76,8 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
     val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
-    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
+    points.map(p =>
+      distanceMeasureInstance.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
   /**
@@ -85,8 +94,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
     val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
-    val cost = data
-      .map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
+    val cost = data.map(p =>
+      distanceMeasureInstance.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p)))
+      .sum()
     bcCentersWithNorm.destroy(blocking = false)
     cost
   }
@@ -94,7 +104,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
-    KMeansModel.SaveLoadV1_0.save(sc, this, path)
+    KMeansModel.SaveLoadV2_0.save(sc, this, path)
   }
 
   override protected def formatVersion: String = "1.0"
@@ -105,7 +115,20 @@ object KMeansModel extends Loader[KMeansModel] {
 
   @Since("1.4.0")
   override def load(sc: SparkContext, path: String): KMeansModel = {
-    KMeansModel.SaveLoadV1_0.load(sc, path)
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    val classNameV2_0 = SaveLoadV2_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path)
+      case (className, "2.0") if className == classNameV2_0 =>
+        SaveLoadV2_0.load(sc, path)
+      case _ => throw new Exception(
+        s"KMeansModel.load did not recognize model with (className, format version):" +
+          s"($loadedClassName, $version).  Supported:\n" +
+          s"  ($classNameV1_0, 1.0)\n" +
+          s"  ($classNameV2_0, 2.0)")
+    }
   }
 
   private case class Cluster(id: Int, point: Vector)
@@ -116,8 +139,7 @@ object KMeansModel extends Loader[KMeansModel] {
     }
   }
 
-  private[clustering]
-  object SaveLoadV1_0 {
+  private[clustering] object SaveLoadV1_0 {
 
     private val thisFormatVersion = "1.0"
 
@@ -149,4 +171,38 @@ object KMeansModel extends Loader[KMeansModel] {
       new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
     }
   }
+
+  private[clustering] object SaveLoadV2_0 {
+
+    private val thisFormatVersion = "2.0"
+
+    private[clustering] val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+         ~ ("k" -> model.k) ~ ("distanceMeasure" -> model.distanceMeasure)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
+      }
+      spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): KMeansModel = {
+      implicit val formats = DefaultFormats
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val k = (metadata \ "k").extract[Int]
+      val centroids = spark.read.parquet(Loader.dataPath(path))
+      Loader.checkSchema[Cluster](centroids.schema)
+      val localCentroids = centroids.rdd.map(Cluster.apply).collect()
+      assert(k == localCentroids.length)
+      val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
+      new KMeansModel(localCentroids.sortBy(_.id).map(_.point), distanceMeasure)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
index 5358767..4a08c0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -46,7 +46,7 @@ private[mllib] object LocalKMeans extends Logging {
 
     // Initialize centers by sampling using the k-means++ procedure.
     centers(0) = pickWeighted(rand, points, weights).toDense
-    val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0)))
+    val costArray = points.map(EuclideanDistanceMeasure.fastSquaredDistance(_, centers(0)))
 
     for (i <- 1 until k) {
       val sum = costArray.zip(weights).map(p => p._1 * p._2).sum
@@ -67,11 +67,15 @@ private[mllib] object LocalKMeans extends Logging {
 
       // update costArray
       for (p <- points.indices) {
-        costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p))
+        costArray(p) = math.min(
+          EuclideanDistanceMeasure.fastSquaredDistance(points(p), centers(i)),
+          costArray(p))
       }
 
     }
 
+    val distanceMeasureInstance = new EuclideanDistanceMeasure
+
     // Run up to maxIterations iterations of Lloyd's algorithm
     val oldClosest = Array.fill(points.length)(-1)
     var iteration = 0
@@ -83,7 +87,7 @@ private[mllib] object LocalKMeans extends Logging {
       var i = 0
       while (i < points.length) {
         val p = points(i)
-        val index = KMeans.findClosest(centers, p)._1
+        val index = distanceMeasureInstance.findClosest(centers, p)._1
         axpy(weights(i), p.vector, sums(index))
         counts(index) += weights(i)
         if (index != oldClosest(i)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 119fe1d..e4506f2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 
@@ -50,6 +50,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
     assert(kmeans.getInitSteps === 2)
     assert(kmeans.getTol === 1e-4)
+    assert(kmeans.getDistanceMeasure === DistanceMeasure.EUCLIDEAN)
     val model = kmeans.setMaxIter(1).fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(kmeans, model)
@@ -68,6 +69,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       .setInitSteps(3)
       .setSeed(123)
       .setTol(1e-3)
+      .setDistanceMeasure(DistanceMeasure.COSINE)
 
     assert(kmeans.getK === 9)
     assert(kmeans.getFeaturesCol === "test_feature")
@@ -77,6 +79,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(kmeans.getInitSteps === 3)
     assert(kmeans.getSeed === 123)
     assert(kmeans.getTol === 1e-3)
+    assert(kmeans.getDistanceMeasure === DistanceMeasure.COSINE)
   }
 
   test("parameters validation") {
@@ -89,6 +92,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     intercept[IllegalArgumentException] {
       new KMeans().setInitSteps(0)
     }
+    intercept[IllegalArgumentException] {
+      new KMeans().setDistanceMeasure("no_such_a_measure")
+    }
   }
 
   test("fit, transform and summary") {
@@ -144,6 +150,37 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(model.getPredictionCol == predictionColName)
   }
 
+  test("KMeans using cosine distance") {
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(1.0, 1.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5),
+      Vectors.dense(10.0, 4.4),
+      Vectors.dense(-1.0, 1.0),
+      Vectors.dense(-100.0, 90.0)
+    )).map(v => TestRow(v)))
+
+    val model = new KMeans()
+      .setK(3)
+      .setSeed(1)
+      .setInitMode(MLlibKMeans.RANDOM)
+      .setTol(1e-6)
+      .setDistanceMeasure(DistanceMeasure.COSINE)
+      .fit(df)
+
+    val predictionDf = model.transform(df)
+    assert(predictionDf.select("prediction").distinct().count() == 3)
+    val predictionsMap = predictionDf.collect().map(row =>
+      row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
+    assert(predictionsMap(Vectors.dense(1.0, 1.0)) ==
+      predictionsMap(Vectors.dense(10.0, 10.0)))
+    assert(predictionsMap(Vectors.dense(1.0, 0.5)) ==
+      predictionsMap(Vectors.dense(10.0, 4.4)))
+    assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
+      predictionsMap(Vectors.dense(-100.0, 90.0)))
+
+  }
+
   test("read/write") {
     def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
       assert(model.clusterCenters === model2.clusterCenters)
@@ -182,6 +219,7 @@ object KMeansSuite {
     "predictionCol" -> "myPrediction",
     "k" -> 3,
     "maxIter" -> 2,
-    "tol" -> 0.01
+    "tol" -> 0.01,
+    "distanceMeasure" -> DistanceMeasure.EUCLIDEAN
   )
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f43d27c/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 00d7e2f..1b98250 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -89,7 +89,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setInitializationMode("k-means||")
       .setInitializationSteps(10)
       .setSeed(seed)
-    val initialCenters = km.initKMeansParallel(normedData).map(_.vector)
+
+    val distanceMeasureInstance = new EuclideanDistanceMeasure
+    val initialCenters = km.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector)
     assert(initialCenters.length === initialCenters.distinct.length)
     assert(initialCenters.length <= numDistinctPoints)
 
@@ -104,7 +106,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       .setInitializationMode("k-means||")
       .setInitializationSteps(10)
       .setSeed(seed)
-    val initialCenters2 = km2.initKMeansParallel(normedData).map(_.vector)
+    val initialCenters2 = km2.initKMeansParallel(normedData, distanceMeasureInstance).map(_.vector)
     assert(initialCenters2.length === initialCenters2.distinct.length)
     assert(initialCenters2.length === k)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org