You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/03/12 19:53:18 UTC

spark git commit: [SPARK-23412][ML] Add cosine distance to BisectingKMeans

Repository: spark
Updated Branches:
  refs/heads/master d5b41aea6 -> 567bd31e0


[SPARK-23412][ML] Add cosine distance to BisectingKMeans

## What changes were proposed in this pull request?

This PR adds an option to specify the distance measure used by BisectingKMeans, and adds support for the cosine distance measure.

## How was this patch tested?

Added new unit tests; existing unit tests also pass.

Author: Marco Gaido <ma...@gmail.com>

Closes #20600 from mgaido91/SPARK-23412.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/567bd31e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/567bd31e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/567bd31e

Branch: refs/heads/master
Commit: 567bd31e0ae8b632357baa93e1469b666fb06f3d
Parents: d5b41ae
Author: Marco Gaido <ma...@gmail.com>
Authored: Mon Mar 12 14:53:15 2018 -0500
Committer: Sean Owen <sr...@gmail.com>
Committed: Mon Mar 12 14:53:15 2018 -0500

----------------------------------------------------------------------
 .../spark/ml/clustering/BisectingKMeans.scala   |  16 +-
 .../org/apache/spark/ml/clustering/KMeans.scala |  11 +-
 .../ml/param/shared/SharedParamsCodeGen.scala   |   6 +-
 .../spark/ml/param/shared/sharedParams.scala    |  19 ++
 .../mllib/clustering/BisectingKMeans.scala      | 139 +++++----
 .../mllib/clustering/BisectingKMeansModel.scala | 115 +++++--
 .../mllib/clustering/DistanceMeasure.scala      | 303 +++++++++++++++++++
 .../apache/spark/mllib/clustering/KMeans.scala  | 196 +-----------
 .../ml/clustering/BisectingKMeansSuite.scala    |  44 ++-
 project/MimaExcludes.scala                      |   6 +
 10 files changed, 557 insertions(+), 298 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 4c20e65..f7c422d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
+import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
+  BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
@@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType}
 /**
  * Common params for BisectingKMeans and BisectingKMeansModel
  */
-private[clustering] trait BisectingKMeansParams extends Params
-  with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
+  with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure {
 
   /**
    * The desired number of leaf clusters. Must be &gt; 1. Default: 4.
@@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] (
   @Since("2.1.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") (
   @Since("2.0.0")
   def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
     transformSchema(dataset.schema, logging = true)
@@ -263,6 +272,7 @@ class BisectingKMeans @Since("2.0.0") (
       .setMaxIterations($(maxIter))
       .setMinDivisibleClusterSize($(minDivisibleClusterSize))
       .setSeed($(seed))
+      .setDistanceMeasure($(distanceMeasure))
     val parentModel = bkm.run(rdd)
     val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
     val summary = new BisectingKMeansSummary(

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index c8145de..987a428 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
 
   /**
    * The number of clusters to create (k). Must be &gt; 1. Note that it is possible for fewer than
@@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)
 
-  @Since("2.4.0")
-  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
-    "Supported options: 'euclidean' and 'cosine'.",
-    (value: String) => MLlibKMeans.validateDistanceMeasure(value))
-
-  /** @group expertGetParam */
-  @Since("2.4.0")
-  def getDistanceMeasure: String = $(distanceMeasure)
-
   /**
    * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    * setting -- the default of 2 is almost always enough. Must be &gt; 0. Default: 2.

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 6ad44af..b9c3170 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen {
         "after fitting. If set to true, then all sub-models will be available. Warning: For " +
         "large models, collecting all sub-models can cause OOMs on the Spark driver",
         Some("false"), isExpertParam = true),
-      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false)
+      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false),
+      ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
+        " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
+        isValid = "(value: String) => " +
+        "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)")
     )
 
     val code = genSharedParams(params)

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index be8b2f2..282ea6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -504,4 +504,23 @@ trait HasLoss extends Params {
   /** @group getParam */
   final def getLoss: String = $(loss)
 }
+
+/**
+ * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasDistanceMeasure extends Params {
+
+  /**
+   * Param for The distance measure. Supported options: 'euclidean' and 'cosine'.
+   * @group param
+   */
+  final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value))
+
+  setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN)
+
+  /** @group getParam */
+  final def getDistanceMeasure: String = $(distanceMeasure)
+}
 // scalastyle:on

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 2221f4c..98af487 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -57,7 +57,8 @@ class BisectingKMeans private (
     private var k: Int,
     private var maxIterations: Int,
     private var minDivisibleClusterSize: Double,
-    private var seed: Long) extends Logging {
+    private var seed: Long,
+    private var distanceMeasure: String) extends Logging {
 
   import BisectingKMeans._
 
@@ -65,7 +66,7 @@ class BisectingKMeans private (
    * Constructs with the default configuration
    */
   @Since("1.6.0")
-  def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##)
+  def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##, DistanceMeasure.EUCLIDEAN)
 
   /**
    * Sets the desired number of leaf clusters (default: 4).
@@ -135,6 +136,22 @@ class BisectingKMeans private (
   def getSeed: Long = this.seed
 
   /**
+   * The distance measure used by the algorithm.
+   */
+  @Since("2.4.0")
+  def getDistanceMeasure: String = distanceMeasure
+
+  /** Sets the distance measure used by the algorithm; unsupported values are rejected. */
+  @Since("2.4.0")
+  def setDistanceMeasure(distanceMeasure: String): this.type = {
+    // validateDistanceMeasure returns a Boolean, so its result must be checked explicitly.
+    require(DistanceMeasure.validateDistanceMeasure(distanceMeasure),
+      s"Unsupported distance measure: $distanceMeasure")
+    this.distanceMeasure = distanceMeasure
+    this
+  }
+
+  /**
    * Runs the bisecting k-means algorithm.
    * @param input RDD of vectors
    * @return model for the bisecting kmeans
@@ -147,11 +164,13 @@ class BisectingKMeans private (
     }
     val d = input.map(_.size).first()
     logInfo(s"Feature dimension: $d.")
+
+    val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
     // Compute and cache vector norms for fast distance computation.
     val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK)
     val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) }
     var assignments = vectors.map(v => (ROOT_INDEX, v))
-    var activeClusters = summarize(d, assignments)
+    var activeClusters = summarize(d, assignments, dMeasure)
     val rootSummary = activeClusters(ROOT_INDEX)
     val n = rootSummary.size
     logInfo(s"Number of points: $n.")
@@ -184,24 +203,25 @@ class BisectingKMeans private (
         val divisibleIndices = divisibleClusters.keys.toSet
         logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.")
         var newClusterCenters = divisibleClusters.flatMap { case (index, summary) =>
-          val (left, right) = splitCenter(summary.center, random)
+          val (left, right) = splitCenter(summary.center, random, dMeasure)
           Iterator((leftChildIndex(index), left), (rightChildIndex(index), right))
         }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map
         var newClusters: Map[Long, ClusterSummary] = null
         var newAssignments: RDD[(Long, VectorWithNorm)] = null
         for (iter <- 0 until maxIterations) {
-          newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters)
+          newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters,
+              dMeasure)
             .filter { case (index, _) =>
             divisibleIndices.contains(parentIndex(index))
           }
-          newClusters = summarize(d, newAssignments)
+          newClusters = summarize(d, newAssignments, dMeasure)
           newClusterCenters = newClusters.mapValues(_.center).map(identity)
         }
         if (preIndices != null) {
           preIndices.unpersist(false)
         }
         preIndices = indices
-        indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
+        indices = updateAssignments(assignments, divisibleIndices, newClusterCenters, dMeasure).keys
           .persist(StorageLevel.MEMORY_AND_DISK)
         assignments = indices.zip(vectors)
         inactiveClusters ++= activeClusters
@@ -222,8 +242,8 @@ class BisectingKMeans private (
     }
     norms.unpersist(false)
     val clusters = activeClusters ++ inactiveClusters
-    val root = buildTree(clusters)
-    new BisectingKMeansModel(root)
+    val root = buildTree(clusters, dMeasure)
+    new BisectingKMeansModel(root, this.distanceMeasure)
   }
 
   /**
@@ -266,8 +286,9 @@ private object BisectingKMeans extends Serializable {
    */
   private def summarize(
       d: Int,
-      assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = {
-    assignments.aggregateByKey(new ClusterSummaryAggregator(d))(
+      assignments: RDD[(Long, VectorWithNorm)],
+      distanceMeasure: DistanceMeasure): Map[Long, ClusterSummary] = {
+    assignments.aggregateByKey(new ClusterSummaryAggregator(d, distanceMeasure))(
         seqOp = (agg, v) => agg.add(v),
         combOp = (agg1, agg2) => agg1.merge(agg2)
       ).mapValues(_.summary)
@@ -278,7 +299,8 @@ private object BisectingKMeans extends Serializable {
    * Cluster summary aggregator.
    * @param d feature dimension
    */
-  private class ClusterSummaryAggregator(val d: Int) extends Serializable {
+  private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure)
+      extends Serializable {
     private var n: Long = 0L
     private val sum: Vector = Vectors.zeros(d)
     private var sumSq: Double = 0.0
@@ -288,7 +310,7 @@ private object BisectingKMeans extends Serializable {
       n += 1L
       // TODO: use a numerically stable approach to estimate cost
       sumSq += v.norm * v.norm
-      BLAS.axpy(1.0, v.vector, sum)
+      distanceMeasure.updateClusterSum(v, sum)
       this
     }
 
@@ -296,19 +318,15 @@ private object BisectingKMeans extends Serializable {
     def merge(other: ClusterSummaryAggregator): this.type = {
       n += other.n
       sumSq += other.sumSq
-      BLAS.axpy(1.0, other.sum, sum)
+      distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum)
       this
     }
 
     /** Returns the summary. */
     def summary: ClusterSummary = {
-      val mean = sum.copy
-      if (n > 0L) {
-        BLAS.scal(1.0 / n, mean)
-      }
-      val center = new VectorWithNorm(mean)
-      val cost = math.max(sumSq - n * center.norm * center.norm, 0.0)
-      new ClusterSummary(n, center, cost)
+      val center = distanceMeasure.centroid(sum.copy, n)
+      val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq)
+      ClusterSummary(n, center, cost)
     }
   }
 
@@ -321,16 +339,13 @@ private object BisectingKMeans extends Serializable {
    */
   private def splitCenter(
       center: VectorWithNorm,
-      random: Random): (VectorWithNorm, VectorWithNorm) = {
+      random: Random,
+      distanceMeasure: DistanceMeasure): (VectorWithNorm, VectorWithNorm) = {
     val d = center.vector.size
     val norm = center.norm
     val level = 1e-4 * norm
     val noise = Vectors.dense(Array.fill(d)(random.nextDouble()))
-    val left = center.vector.copy
-    BLAS.axpy(-level, noise, left)
-    val right = center.vector.copy
-    BLAS.axpy(level, noise, right)
-    (new VectorWithNorm(left), new VectorWithNorm(right))
+    distanceMeasure.symmetricCentroids(level, noise, center.vector)
   }
 
   /**
@@ -343,16 +358,20 @@ private object BisectingKMeans extends Serializable {
   private def updateAssignments(
       assignments: RDD[(Long, VectorWithNorm)],
       divisibleIndices: Set[Long],
-      newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = {
+      newClusterCenters: Map[Long, VectorWithNorm],
+      distanceMeasure: DistanceMeasure): RDD[(Long, VectorWithNorm)] = {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        val newClusterChildren = children.filter(newClusterCenters.contains)
+        // Centers in the same order as newClusterChildren, so the index returned by
+        // findClosest identifies the child directly (no per-point Map needed).
+        val newClusterChildrenCenters = newClusterChildren.map(newClusterCenters).toArray
         if (newClusterChildren.nonEmpty) {
-          val selected = newClusterChildren.minBy { child =>
-            EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
-          }
-          (selected, v)
+          // findClosest returns (index of the closest center, its distance).
+          val selected = distanceMeasure.findClosest(newClusterChildrenCenters, v)._1
+          val id = newClusterChildren(selected)
+          (id, v)
         } else {
           (index, v)
         }
@@ -367,7 +386,9 @@ private object BisectingKMeans extends Serializable {
    * @param clusters a map from cluster indices to corresponding cluster summaries
    * @return the root node of the clustering tree
    */
-  private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = {
+  private def buildTree(
+      clusters: Map[Long, ClusterSummary],
+      distanceMeasure: DistanceMeasure): ClusteringTreeNode = {
     var leafIndex = 0
     var internalIndex = -1
 
@@ -385,11 +406,11 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
-        val height = math.sqrt(indexes.map { childIndex =>
-          EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
-        }.max)
-        val children = indexes.map(buildSubTree(_)).toArray
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains)
+        val height = indexes.map { childIndex =>
+          distanceMeasure.distance(center, clusters(childIndex).center)
+        }.max
+        val children = indexes.map(buildSubTree).toArray
         new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
@@ -441,42 +462,45 @@ private[clustering] class ClusteringTreeNode private[clustering] (
   def center: Vector = centerWithNorm.vector
 
   /** Predicts the leaf cluster node index that the input point belongs to. */
-  def predict(point: Vector): Int = {
-    val (index, _) = predict(new VectorWithNorm(point))
+  def predict(point: Vector, distanceMeasure: DistanceMeasure): Int = {
+    val (index, _) = predict(new VectorWithNorm(point), distanceMeasure)
     index
   }
 
   /** Returns the full prediction path from root to leaf. */
-  def predictPath(point: Vector): Array[ClusteringTreeNode] = {
-    predictPath(new VectorWithNorm(point)).toArray
+  def predictPath(point: Vector, distanceMeasure: DistanceMeasure): Array[ClusteringTreeNode] = {
+    predictPath(new VectorWithNorm(point), distanceMeasure).toArray
   }
 
   /** Returns the full prediction path from root to leaf. */
-  private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = {
+  private def predictPath(
+      pointWithNorm: VectorWithNorm,
+      distanceMeasure: DistanceMeasure): List[ClusteringTreeNode] = {
     if (isLeaf) {
       this :: Nil
     } else {
       val selected = children.minBy { child =>
-        EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
+        distanceMeasure.distance(child.centerWithNorm, pointWithNorm)
       }
-      selected :: selected.predictPath(pointWithNorm)
+      selected :: selected.predictPath(pointWithNorm, distanceMeasure)
     }
   }
 
   /**
-   * Computes the cost (squared distance to the predicted leaf cluster center) of the input point.
+   * Computes the cost of the input point.
    */
-  def computeCost(point: Vector): Double = {
-    val (_, cost) = predict(new VectorWithNorm(point))
+  def computeCost(point: Vector, distanceMeasure: DistanceMeasure): Double = {
+    val (_, cost) = predict(new VectorWithNorm(point), distanceMeasure)
     cost
   }
 
   /**
    * Predicts the cluster index and the cost of the input point.
    */
-  private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
-    predict(pointWithNorm,
-      EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
+  private def predict(
+      pointWithNorm: VectorWithNorm,
+      distanceMeasure: DistanceMeasure): (Int, Double) = {
+    predict(pointWithNorm, distanceMeasure.cost(centerWithNorm, pointWithNorm), distanceMeasure)
   }
 
   /**
@@ -486,14 +510,17 @@ private[clustering] class ClusteringTreeNode private[clustering] (
    * @return (predicted leaf cluster index, cost)
    */
   @tailrec
-  private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = {
+  private def predict(
+      pointWithNorm: VectorWithNorm,
+      cost: Double,
+      distanceMeasure: DistanceMeasure): (Int, Double) = {
     if (isLeaf) {
       (index, cost)
     } else {
       val (selectedChild, minCost) = children.map { child =>
-        (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
+        (child, distanceMeasure.cost(child.centerWithNorm, pointWithNorm))
       }.minBy(_._2)
-      selectedChild.predict(pointWithNorm, minCost)
+      selectedChild.predict(pointWithNorm, minCost, distanceMeasure)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 633bda6..9d115af 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -40,9 +40,16 @@ import org.apache.spark.sql.{Row, SparkSession}
  */
 @Since("1.6.0")
 class BisectingKMeansModel private[clustering] (
-    private[clustering] val root: ClusteringTreeNode
+    private[clustering] val root: ClusteringTreeNode,
+    @Since("2.4.0") val distanceMeasure: String
   ) extends Serializable with Saveable with Logging {
 
+  @Since("1.6.0")
+  def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN)
+
+  private val distanceMeasureInstance: DistanceMeasure =
+    DistanceMeasure.decodeFromString(distanceMeasure)
+
   /**
    * Leaf cluster centers.
    */
@@ -59,7 +66,7 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def predict(point: Vector): Int = {
-    root.predict(point)
+    root.predict(point, distanceMeasureInstance)
   }
 
   /**
@@ -67,7 +74,7 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    points.map { p => root.predict(p) }
+    points.map { p => root.predict(p, distanceMeasureInstance) }
   }
 
   /**
@@ -82,7 +89,7 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def computeCost(point: Vector): Double = {
-    root.computeCost(point)
+    root.computeCost(point, distanceMeasureInstance)
   }
 
   /**
@@ -91,7 +98,7 @@ class BisectingKMeansModel private[clustering] (
    */
   @Since("1.6.0")
   def computeCost(data: RDD[Vector]): Double = {
-    data.map(root.computeCost).sum()
+    data.map(root.computeCost(_, distanceMeasureInstance)).sum()
   }
 
   /**
@@ -113,18 +120,19 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
 
   @Since("2.0.0")
   override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
-    val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
-    implicit val formats = DefaultFormats
-    val rootId = (metadata \ "rootId").extract[Int]
-    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    val (loadedClassName, formatVersion, _) = Loader.loadMetadata(sc, path)
     (loadedClassName, formatVersion) match {
-      case (classNameV1_0, "1.0") =>
-        val model = SaveLoadV1_0.load(sc, path, rootId)
+      case (SaveLoadV1_0.thisClassName, SaveLoadV1_0.thisFormatVersion) =>
+        val model = SaveLoadV1_0.load(sc, path)
+        model
+      case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) =>
+        val model = SaveLoadV2_0.load(sc, path)  // v2.0 also restores distanceMeasure
         model
       case _ => throw new Exception(
         s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
           s"($loadedClassName, $formatVersion).  Supported:\n" +
-          s"  ($classNameV1_0, 1.0)")
+          s"  (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisFormatVersion})\n" +
+          s"  (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisFormatVersion})")
     }
   }
 
@@ -136,8 +144,28 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
       r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
   }
 
+  private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+    if (node.children.isEmpty) {
+      Array(node)
+    } else {
+      node.children.flatMap(getNodes) ++ Array(node)
+    }
+  }
+
+  private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+    val root = nodes(rootId)
+    if (root.children.isEmpty) {
+      new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+        root.cost, root.height, new Array[ClusteringTreeNode](0))
+    } else {
+      val children = root.children.map(c => buildTree(c, nodes))
+      new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+        root.cost, root.height, children.toArray)
+    }
+  }
+
   private[clustering] object SaveLoadV1_0 {
-    private val thisFormatVersion = "1.0"
+    private[clustering] val thisFormatVersion = "1.0"
 
     private[clustering]
     val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
@@ -155,34 +183,55 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
       spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
     }
 
-    private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
-      if (node.children.isEmpty) {
-        Array(node)
-      } else {
-        node.children.flatMap(getNodes(_)) ++ Array(node)
-      }
-    }
-
-    def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+    def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+      implicit val formats: DefaultFormats = DefaultFormats
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val rootId = (metadata \ "rootId").extract[Int]
       val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
       val rows = spark.read.parquet(Loader.dataPath(path))
       Loader.checkSchema[Data](rows.schema)
       val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
       val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
       val rootNode = buildTree(rootId, nodes)
-      new BisectingKMeansModel(rootNode)
+      new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN)
     }
+  }
+
+  private[clustering] object SaveLoadV2_0 {
+    private[clustering] val thisFormatVersion = "2.0"
 
-    private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
-      val root = nodes.get(rootId).get
-      if (root.children.isEmpty) {
-        new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
-          root.cost, root.height, new Array[ClusteringTreeNode](0))
-      } else {
-        val children = root.children.map(c => buildTree(c, nodes))
-        new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
-          root.cost, root.height, children.toArray)
-      }
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+    def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+          ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      val data = getNodes(model.root).map(node => Data(node.index, node.size,
+        node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+        node.children.map(_.index)))
+      spark.createDataFrame(data).write.parquet(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+      implicit val formats: DefaultFormats = DefaultFormats
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val rootId = (metadata \ "rootId").extract[Int]
+      val distanceMeasure = (metadata \ "distanceMeasure").extract[String]
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      val rows = spark.read.parquet(Loader.dataPath(path))
+      Loader.checkSchema[Data](rows.schema)
+      val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+      val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+      val rootNode = buildTree(rootId, nodes)
+      new BisectingKMeansModel(rootNode, distanceMeasure)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
new file mode 100644
index 0000000..683360e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils
+
+private[spark] abstract class DistanceMeasure extends Serializable {
+  // Base class for distance measures; also defines how clusters are summed, averaged and costed.
+  /**
+   * @return (index of the closest center, its distance); (0, Infinity) if `centers` is empty.
+   */
+  def findClosest(
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = Double.PositiveInfinity
+    var bestIndex = 0
+    var i = 0
+    centers.foreach { center =>
+      val currentDistance = distance(center, point)
+      if (currentDistance < bestDistance) {
+        bestDistance = currentDistance
+        bestIndex = i
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
+  /**
+   * @return the cost of `point`, i.e. the value `findClosest` reports for its closest center.
+   */
+  def pointCost(
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): Double = {
+    findClosest(centers, point)._2
+  }
+
+  /**
+   * @return true if `distance(oldCenter, newCenter)` is at most `epsilon`.
+   */
+  def isCenterConverged(
+      oldCenter: VectorWithNorm,
+      newCenter: VectorWithNorm,
+      epsilon: Double): Boolean = {
+    distance(oldCenter, newCenter) <= epsilon
+  }
+
+  /**
+   * @return the distance between `v1` and `v2` under this measure.
+   */
+  def distance(
+      v1: VectorWithNorm,
+      v2: VectorWithNorm): Double
+
+  /**
+   * @return the total cost of a cluster, computed from its aggregated statistics.
+   */
+  def clusterCost(
+      centroid: VectorWithNorm,
+      pointsSum: VectorWithNorm,
+      numberOfPoints: Long,
+      pointsSquaredNorm: Double): Double
+
+  /**
+   * Adds `point` to the running `sum` of a cluster, modifying `sum` in place.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   * Note: `sum` is scaled in place and reused as the centroid's vector.
+   * @param sum   the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
+
+  /**
+   * Returns two new centroids symmetric to `centroid`: `centroid - level * noise` and
+   * `centroid + level * noise`, in that order.
+   *
+   * @param level the scale factor applied to `noise`.
+   * @param noise a noise vector
+   * @param centroid the parent centroid
+   * @return a left and right centroid symmetric to `centroid`
+   */
+  def symmetricCentroids(
+      level: Double,
+      noise: Vector,
+      centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
+    val left = centroid.copy
+    axpy(-level, noise, left)
+    val right = centroid.copy
+    axpy(level, noise, right)
+    (new VectorWithNorm(left), new VectorWithNorm(right))
+  }
+
+  /**
+   * @return the cost of assigning `point` to `centroid`; by default, their distance.
+   */
+  def cost(
+      point: VectorWithNorm,
+      centroid: VectorWithNorm): Double = distance(point, centroid)
+}
+
+@Since("2.4.0")
+object DistanceMeasure {
+
+  @Since("2.4.0")
+  val EUCLIDEAN = "euclidean"
+  @Since("2.4.0")
+  val COSINE = "cosine"
+
+  /** Instantiates the measure named `distanceMeasure`, or throws IllegalArgumentException. */
+  private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure = {
+    if (distanceMeasure == EUCLIDEAN) {
+      new EuclideanDistanceMeasure
+    } else if (distanceMeasure == COSINE) {
+      new CosineDistanceMeasure
+    } else {
+      throw new IllegalArgumentException("distanceMeasure must be one of: " +
+        s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
+    }
+  }
+
+  /** @return true iff `distanceMeasure` names a supported measure. */
+  private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean =
+    distanceMeasure == EUCLIDEAN || distanceMeasure == COSINE
+}
+
+private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
+  /**
+   * @return the index of the closest center to the given point, as well as the squared distance.
+   */
+  override def findClosest(
+      centers: TraversableOnce[VectorWithNorm],
+      point: VectorWithNorm): (Int, Double) = {
+    var bestDistance = Double.PositiveInfinity
+    var bestIndex = 0
+    var i = 0
+    centers.foreach { center =>
+      // Since `\|a - b\| \geq |\|a\| - \|b\||`, (\|a\| - \|b\|)^2 lower-bounds the squared
+      // distance: skip the exact computation for centers that cannot beat the best so far.
+      var lowerBoundOfSqDist = center.norm - point.norm
+      lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
+      if (lowerBoundOfSqDist < bestDistance) {
+        val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
+        if (distance < bestDistance) {
+          bestDistance = distance
+          bestIndex = i
+        }
+      }
+      i += 1
+    }
+    (bestIndex, bestDistance)
+  }
+
+  /**
+   * @return true if the squared distance between the centers is at most `epsilon * epsilon`.
+   */
+  override def isCenterConverged(
+      oldCenter: VectorWithNorm,
+      newCenter: VectorWithNorm,
+      epsilon: Double): Boolean = {
+    EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
+  }
+
+  /**
+   * @param v1 first vector
+   * @param v2 second vector
+   * @return the Euclidean distance between the two input vectors
+   */
+  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
+  }
+
+  /**
+   * @return `pointsSquaredNorm - numberOfPoints * centroid.norm^2`, clamped at 0.0.
+   */
+  override def clusterCost(
+      centroid: VectorWithNorm,
+      pointsSum: VectorWithNorm,
+      numberOfPoints: Long,
+      pointsSquaredNorm: Double): Double = {
+    math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0)
+  }
+
+  /**
+   * @return the squared Euclidean distance between `point` and `centroid`.
+   */
+  override def cost(
+      point: VectorWithNorm,
+      centroid: VectorWithNorm): Double = {
+    EuclideanDistanceMeasure.fastSquaredDistance(point, centroid)
+  }
+}
+
+
+private[spark] object EuclideanDistanceMeasure {
+  /**
+   * Squared Euclidean distance between `v1` and `v2`.
+   *
+   * Delegates to [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]], passing along
+   * the norms already cached in each `VectorWithNorm`.
+   */
+  private[clustering] def fastSquaredDistance(
+      v1: VectorWithNorm, v2: VectorWithNorm): Double =
+    MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
+}
+
+private[spark] class CosineDistanceMeasure extends DistanceMeasure {
+  /**
+   * @param v1 first vector
+   * @param v2 second vector
+   * @return the cosine distance `1 - cos(v1, v2)` between the two input vectors
+   */
+  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
+    1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
+  }
+
+  /**
+   * Adds `point`, scaled to unit length, to the running `sum` of a cluster (in place).
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
+    axpy(1.0 / point.norm, point.vector, sum)
+  }
+
+  /**
+   * Returns a unit-length cluster centroid given its `sum` vector and `count` of points.
+   * Note: `sum` is scaled in place and reused as the centroid's vector.
+   * @param sum   the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)
+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
+
+  /**
+   * @return `numberOfPoints - dot(centroid, pointsSum) / ||centroid||`, clamped at 0.0
+   */
+  override def clusterCost(
+      centroid: VectorWithNorm,
+      pointsSum: VectorWithNorm,
+      numberOfPoints: Long,
+      pointsSquaredNorm: Double): Double = {
+    // Sum of (1 - cos(x, centroid)) over the cluster, assuming `pointsSum` sums unit vectors.
+    math.max(numberOfPoints - dot(centroid.vector, pointsSum.vector) / centroid.norm, 0.0)
+  }
+
+  /**
+   * Returns two new unit-length centroids, obtained by subtracting and adding `noise` scaled
+   * by `level` to `centroid` and then normalizing each result.
+   *
+   * @param level the level of `noise` to apply to the given centroid.
+   * @param noise a noise vector
+   * @param centroid the parent centroid
+   * @return a left and right centroid symmetric to `centroid`
+   */
+  override def symmetricCentroids(
+      level: Double,
+      noise: Vector,
+      centroid: Vector): (VectorWithNorm, VectorWithNorm) = {
+    val (left, right) = super.symmetricCentroids(level, noise, centroid)
+    val leftVector = left.vector
+    val rightVector = right.vector
+    scal(1.0 / left.norm, leftVector)
+    scal(1.0 / right.norm, rightVector)
+    (new VectorWithNorm(leftVector, 1.0), new VectorWithNorm(rightVector, 1.0))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 3c4ba0b..b5b1be3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -25,8 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.clustering.{KMeans => NewKMeans}
 import org.apache.spark.ml.util.Instrumentation
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
-import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.linalg.BLAS.axpy
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -204,7 +203,7 @@ class KMeans private (
    */
   @Since("2.4.0")
   def setDistanceMeasure(distanceMeasure: String): this.type = {
-    KMeans.validateDistanceMeasure(distanceMeasure)
+    DistanceMeasure.decodeFromString(distanceMeasure) // fail fast on unsupported names
     this.distanceMeasure = distanceMeasure
     this
   }
@@ -582,14 +581,6 @@ object KMeans {
       case _ => false
     }
   }
-
-  private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
-    distanceMeasure match {
-      case DistanceMeasure.EUCLIDEAN => true
-      case DistanceMeasure.COSINE => true
-      case _ => false
-    }
-  }
 }
 
 /**
@@ -605,186 +596,3 @@ private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
   /** Converts the vector to a dense vector. */
   def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
 }
-
-
-private[spark] abstract class DistanceMeasure extends Serializable {
-
-  /**
-   * @return the index of the closest center to the given point, as well as the cost.
-   */
-  def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
-      point: VectorWithNorm): (Int, Double) = {
-    var bestDistance = Double.PositiveInfinity
-    var bestIndex = 0
-    var i = 0
-    centers.foreach { center =>
-      val currentDistance = distance(center, point)
-      if (currentDistance < bestDistance) {
-        bestDistance = currentDistance
-        bestIndex = i
-      }
-      i += 1
-    }
-    (bestIndex, bestDistance)
-  }
-
-  /**
-   * @return the K-means cost of a given point against the given cluster centers.
-   */
-  def pointCost(
-      centers: TraversableOnce[VectorWithNorm],
-      point: VectorWithNorm): Double = {
-    findClosest(centers, point)._2
-  }
-
-  /**
-   * @return whether a center converged or not, given the epsilon parameter.
-   */
-  def isCenterConverged(
-      oldCenter: VectorWithNorm,
-      newCenter: VectorWithNorm,
-      epsilon: Double): Boolean = {
-    distance(oldCenter, newCenter) <= epsilon
-  }
-
-  /**
-   * @return the cosine distance between two points.
-   */
-  def distance(
-      v1: VectorWithNorm,
-      v2: VectorWithNorm): Double
-
-  /**
-   * Updates the value of `sum` adding the `point` vector.
-   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
-   * @param sum the `sum` for a cluster to be updated
-   */
-  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
-    axpy(1.0, point.vector, sum)
-  }
-
-  /**
-   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
-   *
-   * @param sum   the `sum` for a cluster
-   * @param count the number of points in the cluster
-   * @return the centroid of the cluster
-   */
-  def centroid(sum: Vector, count: Long): VectorWithNorm = {
-    scal(1.0 / count, sum)
-    new VectorWithNorm(sum)
-  }
-}
-
-@Since("2.4.0")
-object DistanceMeasure {
-
-  @Since("2.4.0")
-  val EUCLIDEAN = "euclidean"
-  @Since("2.4.0")
-  val COSINE = "cosine"
-
-  private[spark] def decodeFromString(distanceMeasure: String): DistanceMeasure =
-    distanceMeasure match {
-      case EUCLIDEAN => new EuclideanDistanceMeasure
-      case COSINE => new CosineDistanceMeasure
-      case _ => throw new IllegalArgumentException(s"distanceMeasure must be one of: " +
-        s"$EUCLIDEAN, $COSINE. $distanceMeasure provided.")
-    }
-}
-
-private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
-  /**
-   * @return the index of the closest center to the given point, as well as the squared distance.
-   */
-  override def findClosest(
-      centers: TraversableOnce[VectorWithNorm],
-      point: VectorWithNorm): (Int, Double) = {
-    var bestDistance = Double.PositiveInfinity
-    var bestIndex = 0
-    var i = 0
-    centers.foreach { center =>
-      // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
-      // distance computation.
-      var lowerBoundOfSqDist = center.norm - point.norm
-      lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
-      if (lowerBoundOfSqDist < bestDistance) {
-        val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
-        if (distance < bestDistance) {
-          bestDistance = distance
-          bestIndex = i
-        }
-      }
-      i += 1
-    }
-    (bestIndex, bestDistance)
-  }
-
-  /**
-   * @return whether a center converged or not, given the epsilon parameter.
-   */
-  override def isCenterConverged(
-      oldCenter: VectorWithNorm,
-      newCenter: VectorWithNorm,
-      epsilon: Double): Boolean = {
-    EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
-  }
-
-  /**
-   * @param v1: first vector
-   * @param v2: second vector
-   * @return the Euclidean distance between the two input vectors
-   */
-  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
-    Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(v1, v2))
-  }
-}
-
-
-private[spark] object EuclideanDistanceMeasure {
-  /**
-   * @return the squared Euclidean distance between two vectors computed by
-   * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
-   */
-  private[clustering] def fastSquaredDistance(
-      v1: VectorWithNorm,
-      v2: VectorWithNorm): Double = {
-    MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
-  }
-}
-
-private[spark] class CosineDistanceMeasure extends DistanceMeasure {
-  /**
-   * @param v1: first vector
-   * @param v2: second vector
-   * @return the cosine distance between the two input vectors
-   */
-  override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
-    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
-    1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
-  }
-
-  /**
-   * Updates the value of `sum` adding the `point` vector.
-   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
-   * @param sum the `sum` for a cluster to be updated
-   */
-  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
-    axpy(1.0 / point.norm, point.vector, sum)
-  }
-
-  /**
-   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
-   *
-   * @param sum   the `sum` for a cluster
-   * @param count the number of points in the cluster
-   * @return the centroid of the cluster
-   */
-  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
-    scal(1.0 / count, sum)
-    val norm = Vectors.norm(sum, 2)
-    scal(1.0 / norm, sum)
-    new VectorWithNorm(sum, 1)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fa7471f..02880f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.ml.clustering
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.mllib.clustering.DistanceMeasure
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
 
@@ -140,6 +142,46 @@ class BisectingKMeansSuite
     testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
       BisectingKMeansSuite.allParamSettings, checkModelData)
   }
+
+  test("BisectingKMeans with cosine distance is not supported for 0-length vectors") {
+    // The all-zero first point has norm 0, for which cosine distance is undefined.
+    val points = Array(Vectors.dense(0.0, 0.0), Vectors.dense(10.0, 10.0), Vectors.dense(1.0, 0.5))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(points).map(TestRow(_)))
+    val estimator = new BisectingKMeans().setK(2).setSeed(1)
+      .setDistanceMeasure(DistanceMeasure.COSINE)
+    // `fit` surfaces the underlying AssertionError as the cause of a SparkException.
+    val thrown = intercept[SparkException](estimator.fit(df))
+    assert(thrown.getCause.isInstanceOf[AssertionError])
+    assert(thrown.getCause.getMessage.contains("Cosine distance is not defined"))
+  }
+
+  test("BisectingKMeans with cosine distance") {
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(1.0, 1.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5),
+      Vectors.dense(10.0, 4.4),
+      Vectors.dense(-1.0, 1.0),
+      Vectors.dense(-100.0, 90.0)
+    )).map(v => TestRow(v)))
+    val model = new BisectingKMeans()
+      .setK(3)
+      .setDistanceMeasure(DistanceMeasure.COSINE)
+      .setSeed(1)
+      .fit(df)
+    val predictionDf = model.transform(df)
+    assert(predictionDf.select("prediction").distinct().count() == 3)
+    val predictionsMap = predictionDf.collect().map(row =>
+      row.getAs[Vector]("features") -> row.getAs[Int]("prediction")).toMap
+    assert(predictionsMap(Vectors.dense(1.0, 1.0)) ==
+      predictionsMap(Vectors.dense(10.0, 10.0)))
+    assert(predictionsMap(Vectors.dense(1.0, 0.5)) ==
+      predictionsMap(Vectors.dense(10.0, 4.4)))
+    assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
+      predictionsMap(Vectors.dense(-100.0, 90.0)))
+    // Cosine centroids are normalized, so every cluster center must have unit L2 norm.
+    assert(model.clusterCenters.forall(center => math.abs(Vectors.norm(center, 2) - 1.0) < 1e-6))
+  }
 }
 
 object BisectingKMeansSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/567bd31e/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 381f7b5..1b6d1de 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,12 @@ object MimaExcludes {
 
   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-23412][ML] Add cosine distance measure to BisectingKMeans
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"),
+    ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"),
+    
     // [SPARK-20659] Remove StorageStatus, or make it private
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org