You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hu...@apache.org on 2022/03/02 19:55:01 UTC

[spark] branch branch-3.1 updated: [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K

This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 357d3b2  [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K
357d3b2 is described below

commit 357d3b24173405cdf915be60f2cebe442fa31536
Author: Ruifeng Zheng <ru...@foxmail.com>
AuthorDate: Wed Mar 2 11:51:06 2022 -0800

    [SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K
    
    ### What changes were proposed in this pull request?
    
    SPARK-31007 introduced auxiliary statistics to speed up computation in KMeans.
    
    However, it needs an array of size `k * (k + 1) / 2`, which may cause overflow or OOM when k is too large.
    
    So we should skip this optimization in this case.
    
    ### Why are the changes needed?
    
    To avoid integer overflow or OOM when k is too large (e.g. 50,000).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing test suites.
    
    Closes #35457 from zhengruifeng/kmean_k_limit.
    
    Authored-by: Ruifeng Zheng <ru...@foxmail.com>
    Signed-off-by: huaxingao <hu...@apple.com>
    (cherry picked from commit ad5427ebe644fc01a9b4c19a48f902f584245edf)
    Signed-off-by: huaxingao <hu...@apple.com>
---
 .../spark/mllib/clustering/DistanceMeasure.scala   | 23 ++++++++++++++++++++++
 .../org/apache/spark/mllib/clustering/KMeans.scala | 15 ++++++++++----
 .../spark/mllib/clustering/KMeansModel.scala       | 11 +++++++++--
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
index 9ac473a..e4c29a7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -118,6 +118,24 @@ private[spark] abstract class DistanceMeasure extends Serializable {
   }
 
   /**
+   * @param centers the clustering centers
+   * @param statistics optional statistics to accelerate the computation, which should not
+   *                   change the result.
+   * @param point given point
+   * @return the index of the closest center to the given point, as well as the cost.
+   */
+  def findClosest(
+      centers: Array[VectorWithNorm],
+      statistics: Option[Array[Double]],
+      point: VectorWithNorm): (Int, Double) = {
+    if (statistics.nonEmpty) {
+      findClosest(centers, statistics.get, point)
+    } else {
+      findClosest(centers, point)
+    }
+  }
+
+  /**
    * @return the index of the closest center to the given point, as well as the cost.
    */
   def findClosest(
@@ -253,6 +271,11 @@ object DistanceMeasure {
       case _ => false
     }
   }
+
+  private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000
+
+  private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean =
+    k.toLong * k * numFeatures < 1000000
 }
 
 private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 76e2928..c140b1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -269,15 +269,22 @@ class KMeans private (
 
     instr.foreach(_.logNumFeatures(numFeatures))
 
-    val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
+    val shouldComputeStats =
+      DistanceMeasure.shouldComputeStatistics(centers.length)
+    val shouldComputeStatsLocally =
+      DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures)
 
     // Execute iterations of Lloyd's algorithm until converged
     while (iteration < maxIterations && !converged) {
       val bcCenters = sc.broadcast(centers)
-      val stats = if (shouldDistributed) {
-        distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
+      val stats = if (shouldComputeStats) {
+        if (shouldComputeStatsLocally) {
+          Some(distanceMeasureInstance.computeStatistics(centers))
+        } else {
+          Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters))
+        }
       } else {
-        distanceMeasureInstance.computeStatistics(centers)
+        None
       }
       val bcStats = sc.broadcast(stats)
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index a24493b..64b3521 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -50,9 +50,16 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],
 
   // TODO: computation of statistics may take seconds, so save it to KMeansModel in training
   @transient private lazy val statistics = if (clusterCenters == null) {
-    null
+    None
   } else {
-    distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)
+    val k = clusterCenters.length
+    val numFeatures = clusterCenters.head.size
+    if (DistanceMeasure.shouldComputeStatistics(k) &&
+        DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) {
+      Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm))
+    } else {
+      None
+    }
   }
 
   @Since("2.4.0")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org